From: Mike Bayer Date: Wed, 31 Aug 2022 15:07:23 +0000 (-0400) Subject: run update_subclass anytime we add new clslevel dispatch X-Git-Tag: rel_2_0_0b1~80 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d3e0b8e750d864766148cdf1a658a601079eed46;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git run update_subclass anytime we add new clslevel dispatch Fixed event listening issue where event listeners added to a superclass would be lost if a subclass were created which then had its own listeners associated. The practical example is that of the :class:`.sessionmaker` class created after events have been associated with the :class:`_orm.Session` class. Fixes: #8467 Change-Id: I9bdba8769147e30110a09900d4a577e833ac3af9 --- diff --git a/doc/build/changelog/unreleased_14/8467.rst b/doc/build/changelog/unreleased_14/8467.rst new file mode 100644 index 0000000000..7626f50a39 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8467.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, events, orm + :tickets: 8467 + + Fixed event listening issue where event listeners added to a superclass + would be lost if a subclass were created which then had its own listeners + associated. The practical example is that of the :class:`.sessionmaker` + class created after events have been associated with the + :class:`_orm.Session` class. diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index f8d70a06a6..21d0a22741 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -175,57 +175,48 @@ class _ClsLevelDispatch(RefCollection[_ET]): return wrap_kw - def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + def _do_insert_or_append( + self, event_key: _EventKey[_ET], is_append: bool + ) -> None: target = event_key.dispatch_target assert isinstance( target, type ), "Class-level Event targets must be classes." if not getattr(target, "_sa_propagate_class_events", True): raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target + f"Can't assign an event directly to the {target} class" ) - for cls in util.walk_subclasses(target): - cls = cast(Type[_ET], cls) - if cls is not target and cls not in self._clslevel: - self.update_subclass(cls) - else: - if cls not in self._clslevel: - self._assign_cls_collection(cls) - self._clslevel[cls].appendleft(event_key._listen_fn) - registry._stored_in_collection(event_key, self) + cls: Type[_ET] - def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: - target = event_key.dispatch_target - assert isinstance( - target, type - ), "Class-level Event targets must be classes." - if not getattr(target, "_sa_propagate_class_events", True): - raise exc.InvalidRequestError( - "Can't assign an event directly to the %s class" % target - ) for cls in util.walk_subclasses(target): - cls = cast("Type[_ET]", cls) if cls is not target and cls not in self._clslevel: self.update_subclass(cls) else: if cls not in self._clslevel: - self._assign_cls_collection(cls) - self._clslevel[cls].append(event_key._listen_fn) + self.update_subclass(cls) + if is_append: + self._clslevel[cls].append(event_key._listen_fn) + else: + self._clslevel[cls].appendleft(event_key._listen_fn) registry._stored_in_collection(event_key, self) - def _assign_cls_collection(self, target: Type[_ET]) -> None: - if getattr(target, "_sa_propagate_class_events", True): - self._clslevel[target] = collections.deque() - else: - self._clslevel[target] = _empty_collection() + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=False) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=True) def update_subclass(self, target: Type[_ET]) -> None: if target not in self._clslevel: - self._assign_cls_collection(target) + if getattr(target, "_sa_propagate_class_events", True): + self._clslevel[target] = collections.deque() + else: + self._clslevel[target] = _empty_collection() + clslevel = self._clslevel[target] + cls: Type[_ET] for cls in target.__mro__[1:]: - cls = cast("Type[_ET]", cls) if cls in self._clslevel: clslevel.extend( [fn for fn in self._clslevel[cls] if fn not in clslevel] @@ -233,8 +224,8 @@ class _ClsLevelDispatch(RefCollection[_ET]): def remove(self, event_key: _EventKey[_ET]) -> None: target = event_key.dispatch_target + cls: Type[_ET] for cls in util.walk_subclasses(target): - cls = cast("Type[_ET]", cls) if cls in self._clslevel: self._clslevel[cls].remove(event_key._listen_fn) registry._removed_from_collection(event_key, self) diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 95dc1d4d49..86b2952cb8 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -559,7 +559,8 @@ def _new_annotation_type( def _prepare_annotations( - target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated] + target_hierarchy: Type[SupportsWrappingAnnotations], + base_cls: Type[Annotated], ) -> None: for cls in util.walk_subclasses(target_hierarchy): _new_annotation_type(cls, base_cls) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 66354f6b64..70c9bba9f8 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -152,7 +152,7 @@ class safe_reraise: raise value.with_traceback(traceback) -def walk_subclasses(cls: type) -> Iterator[type]: +def walk_subclasses(cls: Type[_T]) -> Iterator[Type[_T]]: seen: Set[Any] = set() stack = [cls] diff --git a/test/base/test_events.py b/test/base/test_events.py index 7e978d23b0..67933a5fe0 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -677,6 +677,35 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): eq_(len(SubTarget().dispatch.event_one), 2) + @testing.combinations(True, False, argnames="m1") + @testing.combinations(True, False, argnames="m2") + @testing.combinations(True, False, argnames="m3") + @testing.combinations(True, False, argnames="use_insert") + def test_subclass_gen_after_clslisten(self, m1, m2, m3, use_insert): + """test #8467""" + m1 = Mock() if m1 else None + m2 = Mock() if m2 else None + m3 = Mock() if m3 else None + + if m1: + event.listen(self.TargetOne, "event_one", m1, insert=use_insert) + + class SubTarget(self.TargetOne): + pass + + if m2: + event.listen(SubTarget, "event_one", m2, insert=use_insert) + + if m3: + event.listen(self.TargetOne, "event_one", m3, insert=use_insert) + + st = SubTarget() + st.dispatch.event_one() + + for m in m1, m2, m3: + if m: + eq_(m.mock_calls, [call()]) + def test_lis_multisub_lis(self): @event.listens_for(self.TargetOne, "event_one") def handler1(x, y): diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 7e1b29cb1b..24870e20f1 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -2195,6 +2195,35 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): s = fixture_session() assert my_listener in s.dispatch.before_flush + @testing.combinations(True, False, argnames="m1") + @testing.combinations(True, False, argnames="m2") + @testing.combinations(True, False, argnames="m3") + @testing.combinations(True, False, argnames="use_insert") + def test_sessionmaker_gen_after_session_listen( + self, m1, m2, m3, use_insert + ): + m1 = Mock() if m1 else None + m2 = Mock() if m2 else None + m3 = Mock() if m3 else None + + if m1: + event.listen(Session, "before_flush", m1, insert=use_insert) + + factory = sessionmaker() + + if m2: + event.listen(factory, "before_flush", m2, insert=use_insert) + + if m3: + event.listen(factory, "before_flush", m3, insert=use_insert) + + st = factory() + st.dispatch.before_flush() + + for m in m1, m2, m3: + if m: + eq_(m.mock_calls, [call()]) + def test_sessionmaker_listen(self): """test that listen can be applied to individual scoped_session() classes."""