]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
run update_subclass anytime we add new clslevel dispatch
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Aug 2022 15:07:23 +0000 (11:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Aug 2022 18:31:34 +0000 (14:31 -0400)
Fixed event listening issue where event listeners added to a superclass
would be lost if a subclass were created which then had its own listeners
associated. The practical example is that of the :class:`.sessionmaker`
class created after events have been associated with the
:class:`_orm.Session` class.

Fixes: #8467
Change-Id: I9bdba8769147e30110a09900d4a577e833ac3af9

doc/build/changelog/unreleased_14/8467.rst [new file with mode: 0644]
lib/sqlalchemy/event/attr.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_events.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_14/8467.rst b/doc/build/changelog/unreleased_14/8467.rst
new file mode 100644 (file)
index 0000000..7626f50
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, events, orm
+    :tickets: 8467
+
+    Fixed event listening issue where event listeners added to a superclass
+    would be lost if a subclass were created which then had its own listeners
+    associated. The practical example is that of the :class:`.sessionmaker`
+    class created after events have been associated with the
+    :class:`_orm.Session` class.
index f8d70a06a6f817fa8a8ed8f6e4a862398d131102..21d0a2274193bd534e092ede09f2db1ef8a8b132 100644 (file)
@@ -175,57 +175,48 @@ class _ClsLevelDispatch(RefCollection[_ET]):
 
         return wrap_kw
 
-    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+    def _do_insert_or_append(
+        self, event_key: _EventKey[_ET], is_append: bool
+    ) -> None:
         target = event_key.dispatch_target
         assert isinstance(
             target, type
         ), "Class-level Event targets must be classes."
         if not getattr(target, "_sa_propagate_class_events", True):
             raise exc.InvalidRequestError(
-                "Can't assign an event directly to the %s class" % target
+                f"Can't assign an event directly to the {target} class"
             )
 
-        for cls in util.walk_subclasses(target):
-            cls = cast(Type[_ET], cls)
-            if cls is not target and cls not in self._clslevel:
-                self.update_subclass(cls)
-            else:
-                if cls not in self._clslevel:
-                    self._assign_cls_collection(cls)
-                self._clslevel[cls].appendleft(event_key._listen_fn)
-        registry._stored_in_collection(event_key, self)
+        cls: Type[_ET]
 
-    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
-        target = event_key.dispatch_target
-        assert isinstance(
-            target, type
-        ), "Class-level Event targets must be classes."
-        if not getattr(target, "_sa_propagate_class_events", True):
-            raise exc.InvalidRequestError(
-                "Can't assign an event directly to the %s class" % target
-            )
         for cls in util.walk_subclasses(target):
-            cls = cast("Type[_ET]", cls)
             if cls is not target and cls not in self._clslevel:
                 self.update_subclass(cls)
             else:
                 if cls not in self._clslevel:
-                    self._assign_cls_collection(cls)
-                self._clslevel[cls].append(event_key._listen_fn)
+                    self.update_subclass(cls)
+                if is_append:
+                    self._clslevel[cls].append(event_key._listen_fn)
+                else:
+                    self._clslevel[cls].appendleft(event_key._listen_fn)
         registry._stored_in_collection(event_key, self)
 
-    def _assign_cls_collection(self, target: Type[_ET]) -> None:
-        if getattr(target, "_sa_propagate_class_events", True):
-            self._clslevel[target] = collections.deque()
-        else:
-            self._clslevel[target] = _empty_collection()
+    def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+        self._do_insert_or_append(event_key, is_append=False)
+
+    def append(self, event_key: _EventKey[_ET], propagate: bool) -> None:
+        self._do_insert_or_append(event_key, is_append=True)
 
     def update_subclass(self, target: Type[_ET]) -> None:
         if target not in self._clslevel:
-            self._assign_cls_collection(target)
+            if getattr(target, "_sa_propagate_class_events", True):
+                self._clslevel[target] = collections.deque()
+            else:
+                self._clslevel[target] = _empty_collection()
+
         clslevel = self._clslevel[target]
+        cls: Type[_ET]
         for cls in target.__mro__[1:]:
-            cls = cast("Type[_ET]", cls)
             if cls in self._clslevel:
                 clslevel.extend(
                     [fn for fn in self._clslevel[cls] if fn not in clslevel]
@@ -233,8 +224,8 @@ class _ClsLevelDispatch(RefCollection[_ET]):
 
     def remove(self, event_key: _EventKey[_ET]) -> None:
         target = event_key.dispatch_target
+        cls: Type[_ET]
         for cls in util.walk_subclasses(target):
-            cls = cast("Type[_ET]", cls)
             if cls in self._clslevel:
                 self._clslevel[cls].remove(event_key._listen_fn)
         registry._removed_from_collection(event_key, self)
index 95dc1d4d4953f5e4a39056bf5dafc3652b6bad64..86b2952cb83e3dab5f827e4ad88c340e257352a7 100644 (file)
@@ -559,7 +559,8 @@ def _new_annotation_type(
 
 
 def _prepare_annotations(
-    target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated]
+    target_hierarchy: Type[SupportsWrappingAnnotations],
+    base_cls: Type[Annotated],
 ) -> None:
     for cls in util.walk_subclasses(target_hierarchy):
         _new_annotation_type(cls, base_cls)
index 66354f6b642c6eac99d0ecda339168a9dce0a977..70c9bba9f8b0d02f178d0d8b46f8a7bef7dd95fa 100644 (file)
@@ -152,7 +152,7 @@ class safe_reraise:
             raise value.with_traceback(traceback)
 
 
-def walk_subclasses(cls: type) -> Iterator[type]:
+def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]:
     seen: Set[Any] = set()
 
     stack = [cls]
index 7e978d23b00faf3fd00c50059ee79dddb5c91666..67933a5fe0fb003a9783a3026fe59343eeb2d5d0 100644 (file)
@@ -677,6 +677,35 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
         eq_(len(SubTarget().dispatch.event_one), 2)
 
+    @testing.combinations(True, False, argnames="m1")
+    @testing.combinations(True, False, argnames="m2")
+    @testing.combinations(True, False, argnames="m3")
+    @testing.combinations(True, False, argnames="use_insert")
+    def test_subclass_gen_after_clslisten(self, m1, m2, m3, use_insert):
+        """test #8467"""
+        m1 = Mock() if m1 else None
+        m2 = Mock() if m2 else None
+        m3 = Mock() if m3 else None
+
+        if m1:
+            event.listen(self.TargetOne, "event_one", m1, insert=use_insert)
+
+        class SubTarget(self.TargetOne):
+            pass
+
+        if m2:
+            event.listen(SubTarget, "event_one", m2, insert=use_insert)
+
+        if m3:
+            event.listen(self.TargetOne, "event_one", m3, insert=use_insert)
+
+        st = SubTarget()
+        st.dispatch.event_one()
+
+        for m in m1, m2, m3:
+            if m:
+                eq_(m.mock_calls, [call()])
+
     def test_lis_multisub_lis(self):
         @event.listens_for(self.TargetOne, "event_one")
         def handler1(x, y):
index 7e1b29cb1bb6d83102bc2a84242f013219ec8e17..24870e20f1219b164d8a9765f738ac0d3e31f015 100644 (file)
@@ -2195,6 +2195,35 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
         s = fixture_session()
         assert my_listener in s.dispatch.before_flush
 
+    @testing.combinations(True, False, argnames="m1")
+    @testing.combinations(True, False, argnames="m2")
+    @testing.combinations(True, False, argnames="m3")
+    @testing.combinations(True, False, argnames="use_insert")
+    def test_sessionmaker_gen_after_session_listen(
+        self, m1, m2, m3, use_insert
+    ):
+        m1 = Mock() if m1 else None
+        m2 = Mock() if m2 else None
+        m3 = Mock() if m3 else None
+
+        if m1:
+            event.listen(Session, "before_flush", m1, insert=use_insert)
+
+        factory = sessionmaker()
+
+        if m2:
+            event.listen(factory, "before_flush", m2, insert=use_insert)
+
+        if m3:
+            event.listen(factory, "before_flush", m3, insert=use_insert)
+
+        st = factory()
+        st.dispatch.before_flush()
+
+        for m in m1, m2, m3:
+            if m:
+                eq_(m.mock_calls, [call()])
+
     def test_sessionmaker_listen(self):
         """test that listen can be applied to individual
         scoped_session() classes."""