From 8696a45b096dd7fedb6e9683bef4de99220c976d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 7 Mar 2012 22:36:22 -0500 Subject: [PATCH] - [bug] Fixed event registration bug which would primarily show up as events not being registered with sessionmaker() instances created after the event was associated with the Session class. [ticket:2424] --- CHANGES | 7 +++ lib/sqlalchemy/event.py | 35 +++++++++--- lib/sqlalchemy/orm/session.py | 2 +- test/base/test_events.py | 104 ++++++++++++++++++++++++++++++++++ test/orm/test_events.py | 16 ++++-- 5 files changed, 148 insertions(+), 16 deletions(-) diff --git a/CHANGES b/CHANGES index 74c6df6532..c98b80f569 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,13 @@ CHANGES 0.7.6 ===== - orm + - [bug] Fixed event registration bug + which would primarily show up as + events not being registered with + sessionmaker() instances created + after the event was associated + with the Session class. [ticket:2424] + - [feature] Added "no_autoflush" context manager to Session, used with with: will temporarily disable autoflush. diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index 9cc3139afc..cd70b3a7c4 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -13,12 +13,12 @@ NO_RETVAL = util.symbol('NO_RETVAL') def listen(target, identifier, fn, *args, **kw): """Register a listener function for the given target. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( table.name, @@ -41,12 +41,12 @@ def listen(target, identifier, fn, *args, **kw): def listens_for(target, identifier, *args, **kw): """Decorate a function as a listener for the given target + identifier. - + e.g.:: - + from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - + @event.listens_for(UniqueConstraint, "after_parent_attach") def unique_constraint_name(const, table): const.name = "uq_%s_%s" % ( @@ -205,12 +205,14 @@ class _DispatchDescriptor(object): def insert(self, obj, target, propagate): assert isinstance(target, type), \ "Class-level Event targets must be classes." - stack = [target] while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].insert(0, obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].insert(0, obj) def append(self, obj, target, propagate): assert isinstance(target, type), \ @@ -220,7 +222,20 @@ class _DispatchDescriptor(object): while stack: cls = stack.pop(0) stack.extend(cls.__subclasses__()) - self._clslevel[cls].append(obj) + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + self._clslevel[cls].append(obj) + + def update_subclass(self, target): + clslevel = self._clslevel[target] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend([ + fn for fn + in self._clslevel[cls] + if fn not in clslevel + ]) def remove(self, obj, target): stack = [target] @@ -252,6 +267,8 @@ class _ListenerCollection(object): _exec_once = False def __init__(self, parent, target_cls): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) self.parent_listeners = parent._clslevel[target_cls] self.name = parent.__name__ self.listeners = [] diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d01c1598ae..14778705d0 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -99,7 +99,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, kwargs.update(new_kwargs) - return type("Session", (Sess, class_), {}) + return type("SessionMaker", (Sess, class_), {}) class SessionTransaction(object): diff --git a/test/base/test_events.py b/test/base/test_events.py index 94d3dad855..3ec0f99531 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -117,6 +117,110 @@ class TestEvents(fixtures.TestBase): [listen_two] ) +class TestClsLevelListen(fixtures.TestBase): + def setUp(self): + class TargetEventsOne(event.Events): + def event_one(self, x, y): + pass + class TargetOne(object): + dispatch = event.dispatcher(TargetEventsOne) + self.TargetOne = TargetOne + + def tearDown(self): + event._remove_dispatcher( + self.TargetOne.__dict__['dispatch'].events) + + def test_lis_subcalss_lis(self): + @event.listens_for(self.TargetOne, "event_one") + def handler1(x, y): + print 'handler1' + + class SubTarget(self.TargetOne): + pass + + @event.listens_for(self.TargetOne, "event_one") + def handler2(x, y): + pass + + eq_( + len(SubTarget().dispatch.event_one), + 2 + ) + + def test_lis_multisub_lis(self): + @event.listens_for(self.TargetOne, "event_one") + def handler1(x, y): + print 'handler1' + + class SubTarget(self.TargetOne): + pass + + class SubSubTarget(SubTarget): + pass + + @event.listens_for(self.TargetOne, "event_one") + def handler2(x, y): + pass + + eq_( + len(SubTarget().dispatch.event_one), + 2 + ) + eq_( + len(SubSubTarget().dispatch.event_one), + 2 + ) + + def test_two_sub_lis(self): + class SubTarget1(self.TargetOne): + pass + class SubTarget2(self.TargetOne): + pass + + @event.listens_for(self.TargetOne, "event_one") + def handler1(x, y): + pass + @event.listens_for(SubTarget1, "event_one") + def handler2(x, y): + pass + + s1 = SubTarget1() + assert handler1 in s1.dispatch.event_one + assert handler2 in s1.dispatch.event_one + + s2 = SubTarget2() + assert handler1 in s2.dispatch.event_one + assert handler2 not in s2.dispatch.event_one + + +class TestClsLevelListen(fixtures.TestBase): + def setUp(self): + class TargetEventsOne(event.Events): + def event_one(self, x, y): + pass + class TargetOne(object): + dispatch = event.dispatcher(TargetEventsOne) + self.TargetOne = TargetOne + + def tearDown(self): + event._remove_dispatcher(self.TargetOne.__dict__['dispatch'].events) + + def test_lis_subcalss_lis(self): + @event.listens_for(self.TargetOne, "event_one") + def handler1(x, y): + print 'handler1' + + class SubTarget(self.TargetOne): + pass + + @event.listens_for(self.TargetOne, "event_one") + def handler2(x, y): + pass + + eq_( + len(SubTarget().dispatch.event_one), + 2 + ) class TestAcceptTargets(fixtures.TestBase): """Test default target acceptance.""" diff --git a/test/orm/test_events.py b/test/orm/test_events.py index e52b9299f5..f8158369c2 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -129,11 +129,12 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): sess.flush() eq_(canary, ['init', 'before_insert', - 'after_insert', 'expire', 'translate_row', 'populate_instance', - 'refresh', + 'after_insert', 'expire', 'translate_row', + 'populate_instance', 'refresh', 'append_result', 'translate_row', 'create_instance', 'populate_instance', 'load', 'append_result', - 'before_update', 'after_update', 'before_delete', 'after_delete']) + 'before_update', 'after_update', 'before_delete', + 'after_delete']) def test_merge(self): users, User = self.tables.users, self.classes.User @@ -203,7 +204,8 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): """ - keywords, items, item_keywords, Keyword, Item = (self.tables.keywords, + keywords, items, item_keywords, Keyword, Item = ( + self.tables.keywords, self.tables.items, self.tables.item_keywords, self.classes.Keyword, @@ -466,7 +468,8 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): assert my_listener in s.dispatch.before_flush def test_sessionmaker_listen(self): - """test that listen can be applied to individual scoped_session() classes.""" + """test that listen can be applied to individual + scoped_session() classes.""" def my_listener_one(*arg, **kw): pass @@ -564,7 +567,8 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): mapper(User, users) - sess, canary = self._listener_fixture(autoflush=False, autocommit=True, expire_on_commit=False) + sess, canary = self._listener_fixture(autoflush=False, + autocommit=True, expire_on_commit=False) u = User(name='u1') sess.add(u) -- 2.47.2