From 3c68c7c0341ac41b11185491cf2165336dfed1de Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 31 Aug 2022 11:07:23 -0400 Subject: [PATCH] run update_subclass anytime we add new clslevel dispatch Fixed event listening issue where event listeners added to a superclass would be lost if a subclass were created which then had its own listeners associated. The practical example is that of the :class:`.sessionmaker` class created after events have been associated with the :class:`_orm.Session` class. Fixes: #8467 Change-Id: I9bdba8769147e30110a09900d4a577e833ac3af9 (cherry picked from commit d3e0b8e750d864766148cdf1a658a601079eed46) --- doc/build/changelog/unreleased_14/8467.rst | 9 +++++ lib/sqlalchemy/event/attr.py | 45 +++++++++------------- test/base/test_events.py | 29 ++++++++++++++ test/orm/test_events.py | 29 ++++++++++++++ 4 files changed, 85 insertions(+), 27 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/8467.rst diff --git a/doc/build/changelog/unreleased_14/8467.rst b/doc/build/changelog/unreleased_14/8467.rst new file mode 100644 index 0000000000..7626f50a39 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8467.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, events, orm + :tickets: 8467 + + Fixed event listening issue where event listeners added to a superclass + would be lost if a subclass were created which then had its own listeners + associated. The practical example is that of the :class:`.sessionmaker` + class created after events have been associated with the + :class:`_orm.Session` class. diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 0d16165c4e..09b5a2267f 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -118,14 +118,14 @@ class _ClsLevelDispatch(RefCollection): return wrap_kw - def insert(self, event_key, propagate): + def _do_insert_or_append(self, event_key, is_append): target = event_key.dispatch_target assert isinstance( target, type ), "Class-level Event targets must be classes." if not getattr(target, "_sa_propagate_class_events", True): raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target + "Can't assign an event directly to the %s class" % (target,) ) for cls in util.walk_subclasses(target): @@ -133,38 +133,28 @@ class _ClsLevelDispatch(RefCollection): self.update_subclass(cls) else: if cls not in self._clslevel: - self._assign_cls_collection(cls) - self._clslevel[cls].appendleft(event_key._listen_fn) + self.update_subclass(cls) + if is_append: + self._clslevel[cls].append(event_key._listen_fn) + else: + self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def append(self, event_key, propagate): - target = event_key.dispatch_target - assert isinstance( - target, type - ), "Class-level Event targets must be classes." - if not getattr(target, "_sa_propagate_class_events", True): - raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target - ) - for cls in util.walk_subclasses(target): - if cls is not target and cls not in self._clslevel: - self.update_subclass(cls) - else: - if cls not in self._clslevel: - self._assign_cls_collection(cls) - self._clslevel[cls].append(event_key._listen_fn) - registry._stored_in_collection(event_key, self) + def insert(self, event_key, propagate): + self._do_insert_or_append(event_key, is_append=False) - def _assign_cls_collection(self, target): - if getattr(target, "_sa_propagate_class_events", True): - self._clslevel[target] = collections.deque() - else: - self._clslevel[target] = _empty_collection() + def append(self, event_key, propagate): + self._do_insert_or_append(event_key, is_append=True) def update_subclass(self, target): if target not in self._clslevel: - self._assign_cls_collection(target) + if getattr(target, "_sa_propagate_class_events", True): + self._clslevel[target] = collections.deque() + else: + self._clslevel[target] = _empty_collection() + clslevel = self._clslevel[target] + for cls in target.__mro__[1:]: if cls in self._clslevel: clslevel.extend( @@ -173,6 +163,7 @@ class _ClsLevelDispatch(RefCollection): def remove(self, event_key): target = event_key.dispatch_target + for cls in util.walk_subclasses(target): if cls in self._clslevel: self._clslevel[cls].remove(event_key._listen_fn) diff --git a/test/base/test_events.py b/test/base/test_events.py index 4409d6b294..e8ed0ff362 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -677,6 +677,35 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): eq_(len(SubTarget().dispatch.event_one), 2) + @testing.combinations(True, False, argnames="m1") + @testing.combinations(True, False, argnames="m2") + @testing.combinations(True, False, argnames="m3") + @testing.combinations(True, False, argnames="use_insert") + def test_subclass_gen_after_clslisten(self, m1, m2, m3, use_insert): + """test #8467""" + m1 = Mock() if m1 else None + m2 = Mock() if m2 else None + m3 = Mock() if m3 else None + + if m1: + event.listen(self.TargetOne, "event_one", m1, insert=use_insert) + + class SubTarget(self.TargetOne): + pass + + if m2: + event.listen(SubTarget, "event_one", m2, insert=use_insert) + + if m3: + event.listen(self.TargetOne, "event_one", m3, insert=use_insert) + + st = SubTarget() + st.dispatch.event_one() + + for m in m1, m2, m3: + if m: + eq_(m.mock_calls, [call()]) + def test_lis_multisub_lis(self): @event.listens_for(self.TargetOne, "event_one") def handler1(x, y): diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 4009dc3aec..5026551004 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -2078,6 +2078,35 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): s = fixture_session() assert my_listener in s.dispatch.before_flush + @testing.combinations(True, False, argnames="m1") + @testing.combinations(True, False, argnames="m2") + @testing.combinations(True, False, argnames="m3") + @testing.combinations(True, False, argnames="use_insert") + def test_sessionmaker_gen_after_session_listen( + self, m1, m2, m3, use_insert + ): + m1 = Mock() if m1 else None + m2 = Mock() if m2 else None + m3 = Mock() if m3 else None + + if m1: + event.listen(Session, "before_flush", m1, insert=use_insert) + + factory = sessionmaker() + + if m2: + event.listen(factory, "before_flush", m2, insert=use_insert) + + if m3: + event.listen(factory, "before_flush", m3, insert=use_insert) + + st = factory() + st.dispatch.before_flush() + + for m in m1, m2, m3: + if m: + eq_(m.mock_calls, [call()]) + def test_sessionmaker_listen(self): """test that listen can be applied to individual scoped_session() classes.""" -- 2.47.2