]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
run update_subclass anytime we add new clslevel dispatch
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Aug 2022 15:07:23 +0000 (11:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Aug 2022 20:54:33 +0000 (16:54 -0400)
Fixed event listening issue where event listeners added to a superclass
would be lost if a subclass were created which then had its own listeners
associated. The practical example is that of the :class:`.sessionmaker`
class created after events have been associated with the
:class:`_orm.Session` class.

Fixes: #8467
Change-Id: I9bdba8769147e30110a09900d4a577e833ac3af9
(cherry picked from commit d3e0b8e750d864766148cdf1a658a601079eed46)

doc/build/changelog/unreleased_14/8467.rst [new file with mode: 0644]
lib/sqlalchemy/event/attr.py
test/base/test_events.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_14/8467.rst b/doc/build/changelog/unreleased_14/8467.rst
new file mode 100644 (file)
index 0000000..7626f50
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, events, orm
+    :tickets: 8467
+
+    Fixed event listening issue where event listeners added to a superclass
+    would be lost if a subclass were created which then had its own listeners
+    associated. The practical example is that of the :class:`.sessionmaker`
+    class created after events have been associated with the
+    :class:`_orm.Session` class.
index 0d16165c4ee95a77d9d4ae22ab8131ed2eb69bba..09b5a2267f04ff28d2266e317fefa21c05cb64d3 100644 (file)
@@ -118,14 +118,14 @@ class _ClsLevelDispatch(RefCollection):
 
         return wrap_kw
 
-    def insert(self, event_key, propagate):
+    def _do_insert_or_append(self, event_key, is_append):
         target = event_key.dispatch_target
         assert isinstance(
             target, type
         ), "Class-level Event targets must be classes."
         if not getattr(target, "_sa_propagate_class_events", True):
             raise exc.InvalidRequestError(
-                "Can't assign an event directly to the %s class" % target
+                "Can't assign an event directly to the %s class" % (target,)
             )
 
         for cls in util.walk_subclasses(target):
@@ -133,38 +133,28 @@ class _ClsLevelDispatch(RefCollection):
                 self.update_subclass(cls)
             else:
                 if cls not in self._clslevel:
-                    self._assign_cls_collection(cls)
-                self._clslevel[cls].appendleft(event_key._listen_fn)
+                    self.update_subclass(cls)
+                if is_append:
+                    self._clslevel[cls].append(event_key._listen_fn)
+                else:
+                    self._clslevel[cls].appendleft(event_key._listen_fn)
         registry._stored_in_collection(event_key, self)
 
-    def append(self, event_key, propagate):
-        target = event_key.dispatch_target
-        assert isinstance(
-            target, type
-        ), "Class-level Event targets must be classes."
-        if not getattr(target, "_sa_propagate_class_events", True):
-            raise exc.InvalidRequestError(
-                "Can't assign an event directly to the %s class" % target
-            )
-        for cls in util.walk_subclasses(target):
-            if cls is not target and cls not in self._clslevel:
-                self.update_subclass(cls)
-            else:
-                if cls not in self._clslevel:
-                    self._assign_cls_collection(cls)
-                self._clslevel[cls].append(event_key._listen_fn)
-        registry._stored_in_collection(event_key, self)
+    def insert(self, event_key, propagate):
+        self._do_insert_or_append(event_key, is_append=False)
 
-    def _assign_cls_collection(self, target):
-        if getattr(target, "_sa_propagate_class_events", True):
-            self._clslevel[target] = collections.deque()
-        else:
-            self._clslevel[target] = _empty_collection()
+    def append(self, event_key, propagate):
+        self._do_insert_or_append(event_key, is_append=True)
 
     def update_subclass(self, target):
         if target not in self._clslevel:
-            self._assign_cls_collection(target)
+            if getattr(target, "_sa_propagate_class_events", True):
+                self._clslevel[target] = collections.deque()
+            else:
+                self._clslevel[target] = _empty_collection()
+
         clslevel = self._clslevel[target]
+
         for cls in target.__mro__[1:]:
             if cls in self._clslevel:
                 clslevel.extend(
@@ -173,6 +163,7 @@ class _ClsLevelDispatch(RefCollection):
 
     def remove(self, event_key):
         target = event_key.dispatch_target
+
         for cls in util.walk_subclasses(target):
             if cls in self._clslevel:
                 self._clslevel[cls].remove(event_key._listen_fn)
index 4409d6b2947f0917f385e3af33f10e2f205b66b5..e8ed0ff362860ef7e3a7ec9f7eecdc9bfb9604d7 100644 (file)
@@ -677,6 +677,35 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
         eq_(len(SubTarget().dispatch.event_one), 2)
 
+    @testing.combinations(True, False, argnames="m1")
+    @testing.combinations(True, False, argnames="m2")
+    @testing.combinations(True, False, argnames="m3")
+    @testing.combinations(True, False, argnames="use_insert")
+    def test_subclass_gen_after_clslisten(self, m1, m2, m3, use_insert):
+        """test #8467"""
+        m1 = Mock() if m1 else None
+        m2 = Mock() if m2 else None
+        m3 = Mock() if m3 else None
+
+        if m1:
+            event.listen(self.TargetOne, "event_one", m1, insert=use_insert)
+
+        class SubTarget(self.TargetOne):
+            pass
+
+        if m2:
+            event.listen(SubTarget, "event_one", m2, insert=use_insert)
+
+        if m3:
+            event.listen(self.TargetOne, "event_one", m3, insert=use_insert)
+
+        st = SubTarget()
+        st.dispatch.event_one()
+
+        for m in m1, m2, m3:
+            if m:
+                eq_(m.mock_calls, [call()])
+
     def test_lis_multisub_lis(self):
         @event.listens_for(self.TargetOne, "event_one")
         def handler1(x, y):
index 4009dc3aecbb1c811e183e3605b0e7245d087d63..50265510042505984e19272de2eb58a679d8bf88 100644 (file)
@@ -2078,6 +2078,35 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
         s = fixture_session()
         assert my_listener in s.dispatch.before_flush
 
+    @testing.combinations(True, False, argnames="m1")
+    @testing.combinations(True, False, argnames="m2")
+    @testing.combinations(True, False, argnames="m3")
+    @testing.combinations(True, False, argnames="use_insert")
+    def test_sessionmaker_gen_after_session_listen(
+        self, m1, m2, m3, use_insert
+    ):
+        m1 = Mock() if m1 else None
+        m2 = Mock() if m2 else None
+        m3 = Mock() if m3 else None
+
+        if m1:
+            event.listen(Session, "before_flush", m1, insert=use_insert)
+
+        factory = sessionmaker()
+
+        if m2:
+            event.listen(factory, "before_flush", m2, insert=use_insert)
+
+        if m3:
+            event.listen(factory, "before_flush", m3, insert=use_insert)
+
+        st = factory()
+        st.dispatch.before_flush()
+
+        for m in m1, m2, m3:
+            if m:
+                eq_(m.mock_calls, [call()])
+
     def test_sessionmaker_listen(self):
         """test that listen can be applied to individual
         scoped_session() classes."""