return wrap_kw
- def insert(self, event_key, propagate):
+ def _do_insert_or_append(self, event_key, is_append):
target = event_key.dispatch_target
assert isinstance(
target, type
), "Class-level Event targets must be classes."
if not getattr(target, "_sa_propagate_class_events", True):
raise exc.InvalidRequestError(
- "Can't assign an event directly to the %s class" % target
+ "Can't assign an event directly to the %s class" % (target,)
)
for cls in util.walk_subclasses(target):
self.update_subclass(cls)
else:
if cls not in self._clslevel:
- self._assign_cls_collection(cls)
- self._clslevel[cls].appendleft(event_key._listen_fn)
+ self.update_subclass(cls)
+ if is_append:
+ self._clslevel[cls].append(event_key._listen_fn)
+ else:
+ self._clslevel[cls].appendleft(event_key._listen_fn)
registry._stored_in_collection(event_key, self)
- def append(self, event_key, propagate):
- target = event_key.dispatch_target
- assert isinstance(
- target, type
- ), "Class-level Event targets must be classes."
- if not getattr(target, "_sa_propagate_class_events", True):
- raise exc.InvalidRequestError(
- "Can't assign an event directly to the %s class" % target
- )
- for cls in util.walk_subclasses(target):
- if cls is not target and cls not in self._clslevel:
- self.update_subclass(cls)
- else:
- if cls not in self._clslevel:
- self._assign_cls_collection(cls)
- self._clslevel[cls].append(event_key._listen_fn)
- registry._stored_in_collection(event_key, self)
+ def insert(self, event_key, propagate):
+ self._do_insert_or_append(event_key, is_append=False)
- def _assign_cls_collection(self, target):
- if getattr(target, "_sa_propagate_class_events", True):
- self._clslevel[target] = collections.deque()
- else:
- self._clslevel[target] = _empty_collection()
+ def append(self, event_key, propagate):
+ self._do_insert_or_append(event_key, is_append=True)
def update_subclass(self, target):
if target not in self._clslevel:
- self._assign_cls_collection(target)
+ if getattr(target, "_sa_propagate_class_events", True):
+ self._clslevel[target] = collections.deque()
+ else:
+ self._clslevel[target] = _empty_collection()
+
clslevel = self._clslevel[target]
+
for cls in target.__mro__[1:]:
if cls in self._clslevel:
clslevel.extend(
def remove(self, event_key):
target = event_key.dispatch_target
+
for cls in util.walk_subclasses(target):
if cls in self._clslevel:
self._clslevel[cls].remove(event_key._listen_fn)
eq_(len(SubTarget().dispatch.event_one), 2)
+ @testing.combinations(True, False, argnames="m1")
+ @testing.combinations(True, False, argnames="m2")
+ @testing.combinations(True, False, argnames="m3")
+ @testing.combinations(True, False, argnames="use_insert")
+ def test_subclass_gen_after_clslisten(self, m1, m2, m3, use_insert):
+ """test #8467"""
+ m1 = Mock() if m1 else None
+ m2 = Mock() if m2 else None
+ m3 = Mock() if m3 else None
+
+ if m1:
+ event.listen(self.TargetOne, "event_one", m1, insert=use_insert)
+
+ class SubTarget(self.TargetOne):
+ pass
+
+ if m2:
+ event.listen(SubTarget, "event_one", m2, insert=use_insert)
+
+ if m3:
+ event.listen(self.TargetOne, "event_one", m3, insert=use_insert)
+
+ st = SubTarget()
+ st.dispatch.event_one()
+
+ for m in m1, m2, m3:
+ if m:
+ eq_(m.mock_calls, [call()])
+
def test_lis_multisub_lis(self):
@event.listens_for(self.TargetOne, "event_one")
def handler1(x, y):
s = fixture_session()
assert my_listener in s.dispatch.before_flush
+ @testing.combinations(True, False, argnames="m1")
+ @testing.combinations(True, False, argnames="m2")
+ @testing.combinations(True, False, argnames="m3")
+ @testing.combinations(True, False, argnames="use_insert")
+ def test_sessionmaker_gen_after_session_listen(
+ self, m1, m2, m3, use_insert
+ ):
+ m1 = Mock() if m1 else None
+ m2 = Mock() if m2 else None
+ m3 = Mock() if m3 else None
+
+ if m1:
+ event.listen(Session, "before_flush", m1, insert=use_insert)
+
+ factory = sessionmaker()
+
+ if m2:
+ event.listen(factory, "before_flush", m2, insert=use_insert)
+
+ if m3:
+ event.listen(factory, "before_flush", m3, insert=use_insert)
+
+ st = factory()
+ st.dispatch.before_flush()
+
+ for m in m1, m2, m3:
+ if m:
+ eq_(m.mock_calls, [call()])
+
def test_sessionmaker_listen(self):
"""test that listen can be applied to individual
scoped_session() classes."""