From: Mike Bayer Date: Sat, 2 Oct 2010 21:22:37 +0000 (-0400) Subject: - begin adding tests for event registration and dispatch standalone X-Git-Tag: rel_0_7b1~253^2~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d9dc05adb689bc4eab2227a96af0d874696cc63d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - begin adding tests for event registration and dispatch standalone - fix pickling again - other test fixes --- diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index c7df2bf485..0f6342e6b2 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -11,20 +11,13 @@ CANCEL = util.symbol('CANCEL') NO_RETVAL = util.symbol('NO_RETVAL') def listen(fn, identifier, target, *args, **kw): - """Listen for events, accepting an event function that's "raw". - Only the exact arguments are received in order. - - This is used by SQLA internals simply to reduce the overhead - of creating an event dictionary for each event call. + """Register a listener function for the given target. """ - - # rationale - the events on ClassManager, Session, and Mapper - # will need to accept mapped classes directly as targets and know - # what to do for evt_cls in _registrars[identifier]: - for tgt in evt_cls.accept_with(target): + tgt = evt_cls.accept_with(target) + if tgt is not None: tgt.dispatch.listen(fn, identifier, tgt, *args, **kw) return raise exc.InvalidRequestError("No such event %s for target %s" % @@ -41,10 +34,18 @@ def remove(fn, identifier, target): for evt_cls in _registrars[identifier]: for tgt in evt_cls.accept_with(target): tgt.dispatch.remove(fn, identifier, tgt, *args, **kw) - + return _registrars = util.defaultdict(list) +class _UnpickleDispatch(object): + """Serializable callable that re-generates an instance of :class:`_Dispatch` + given a particular :class:`.Events` subclass. + + """ + def __call__(self, parent_cls): + return parent_cls.__dict__['dispatch'].dispatch_cls(parent_cls) + class _Dispatch(object): """Mirror the event listening definitions of an Events class with listener collections. @@ -55,9 +56,8 @@ class _Dispatch(object): self.parent_cls = parent_cls def __reduce__(self): - return dispatcher, ( - self.parent_cls.__dict__['dispatch'].events, - ) + + return _UnpickleDispatch(), (self.parent_cls, ) @property def descriptors(self): @@ -81,7 +81,7 @@ class _EventMeta(type): def __init__(cls, classname, bases, dict_): _create_dispatcher_class(cls, classname, bases, dict_) return type.__init__(cls, classname, bases, dict_) - + def _create_dispatcher_class(cls, classname, bases, dict_): # there's all kinds of ways to do this, # i.e. make a Dispatch class that shares the 'listen' method @@ -96,6 +96,13 @@ def _create_dispatcher_class(cls, classname, bases, dict_): setattr(dispatch_cls, k, _DispatchDescriptor(dict_[k])) _registrars[k].append(cls) +def _remove_dispatcher(cls): + for k in dir(cls): + if k.startswith('on_'): + _registrars[k].remove(cls) + if not _registrars[k]: + del _registrars[k] + class Events(object): """Define event listening functions for a particular target type.""" @@ -111,9 +118,9 @@ class Events(object): isinstance(target.dispatch, type) and \ issubclass(target.dispatch, cls.dispatch) ): - return [target] + return target else: - return [] + return None @classmethod def listen(cls, fn, identifier, target): diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 29274ac3b5..ff9f7dbc6c 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -20,9 +20,9 @@ class InstrumentationEvents(event.Events): from sqlalchemy.orm.instrumentation import instrumentation_registry if isinstance(target, type): - return [instrumentation_registry] + return instrumentation_registry else: - return [] + return None @classmethod def listen(cls, fn, identifier, target): @@ -64,12 +64,12 @@ class InstanceEvents(event.Events): from sqlalchemy.orm.instrumentation import ClassManager, manager_of_class if isinstance(target, ClassManager): - return [target] + return target elif isinstance(target, type): manager = manager_of_class(target) if manager: - return [manager] - return [] + return manager + return None @classmethod def listen(cls, fn, identifier, target, raw=False): @@ -185,26 +185,51 @@ class AttributeEvents(event.Events): def unwrap(cls, identifier, event): return event['value'] - def on_append(self, state, value, initiator): + def on_append(self, target, value, initiator): """Receive a collection append event. - The returned value will be used as the actual value to be - appended. + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being appended. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param initiator: the attribute implementation object + which initiated this event. """ - def on_remove(self, state, value, initiator): - """Receive a remove event. + def on_remove(self, target, value, initiator): + """Receive a collection remove event. - No return value is defined. + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being removed. + :param initiator: the attribute implementation object + which initiated this event. """ - def on_set(self, state, value, oldvalue, initiator): - """Receive a set event. - - The returned value will be used as the actual value to be - set. + def on_set(self, target, value, oldvalue, initiator): + """Receive a scalar set event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being set. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param oldvalue: the previous value being replaced. This + may also be the symbol ``NEVER_SET`` or ``NO_VALUE``. + If the listener is registered with ``active_history=True``, + the previous value of the attribute will be loaded from + the database if the existing value is currently unloaded + or expired. + :param initiator: the attribute implementation object + which initiated this event. """ diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 02ba5e1a22..52c1c7213c 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -361,7 +361,7 @@ class _ClassInstrumentationAdapter(ClassManager): self._adapted.instrument_attribute(self.class_, key, inst) def post_configure_attribute(self, key): - super(_ClassInstrumentationAdpter, self).post_configure_attribute(key) + super(_ClassInstrumentationAdapter, self).post_configure_attribute(key) self._adapted.post_configure_attribute(self.class_, key, self[key]) def install_descriptor(self, key, inst): diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 42fc5b98ee..dc8a07c177 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -509,7 +509,7 @@ class MutableAttrInstanceState(InstanceState): obj.__dict__.update(self.mutable_dict) # re-establishes identity attributes from the key - self.manager.dispatch.on_resurrect(self, obj) + self.manager.dispatch.on_resurrect(self) return obj diff --git a/test/base/test_events.py b/test/base/test_events.py new file mode 100644 index 0000000000..9099619e5f --- /dev/null +++ b/test/base/test_events.py @@ -0,0 +1,231 @@ +"""Test event registration and listening.""" + +from sqlalchemy.test.testing import TestBase, eq_, assert_raises +from sqlalchemy import event, exc, util + +class TestEvents(TestBase): + """Test class- and instance-level event registration.""" + + def setUp(self): + global Target + + assert 'on_event_one' not in event._registrars + assert 'on_event_two' not in event._registrars + + class TargetEvents(event.Events): + def on_event_one(self, x, y): + pass + + def on_event_two(self, x): + pass + + class Target(object): + dispatch = event.dispatcher(TargetEvents) + + def tearDown(self): + event._remove_dispatcher(Target.__dict__['dispatch'].events) + + def test_register_class(self): + def listen(x, y): + pass + + event.listen(listen, "on_event_one", Target) + + eq_(len(Target().dispatch.on_event_one), 1) + eq_(len(Target().dispatch.on_event_two), 0) + + def test_register_instance(self): + def listen(x, y): + pass + + t1 = Target() + event.listen(listen, "on_event_one", t1) + + eq_(len(Target().dispatch.on_event_one), 0) + eq_(len(t1.dispatch.on_event_one), 1) + eq_(len(Target().dispatch.on_event_two), 0) + eq_(len(t1.dispatch.on_event_two), 0) + + def test_register_class_instance(self): + def listen_one(x, y): + pass + + def listen_two(x, y): + pass + + event.listen(listen_one, "on_event_one", Target) + + t1 = Target() + event.listen(listen_two, "on_event_one", t1) + + eq_(len(Target().dispatch.on_event_one), 1) + eq_(len(t1.dispatch.on_event_one), 2) + eq_(len(Target().dispatch.on_event_two), 0) + eq_(len(t1.dispatch.on_event_two), 0) + + def listen_three(x, y): + pass + + event.listen(listen_three, "on_event_one", Target) + eq_(len(Target().dispatch.on_event_one), 2) + eq_(len(t1.dispatch.on_event_one), 3) + +class TestAcceptTargets(TestBase): + """Test default target acceptance.""" + + def setUp(self): + global TargetOne, TargetTwo + + class TargetEventsOne(event.Events): + def on_event_one(self, x, y): + pass + + class TargetEventsTwo(event.Events): + def on_event_one(self, x, y): + pass + + class TargetOne(object): + dispatch = event.dispatcher(TargetEventsOne) + + class TargetTwo(object): + dispatch = event.dispatcher(TargetEventsTwo) + + def tearDown(self): + event._remove_dispatcher(TargetOne.__dict__['dispatch'].events) + event._remove_dispatcher(TargetTwo.__dict__['dispatch'].events) + + def test_target_accept(self): + """Test that events of the same name are routed to the correct + collection based on the type of target given. + + """ + def listen_one(x, y): + pass + + def listen_two(x, y): + pass + + def listen_three(x, y): + pass + + def listen_four(x, y): + pass + + event.listen(listen_one, "on_event_one", TargetOne) + event.listen(listen_two, "on_event_one", TargetTwo) + + eq_( + list(TargetOne().dispatch.on_event_one), + [listen_one] + ) + + eq_( + list(TargetTwo().dispatch.on_event_one), + [listen_two] + ) + + t1 = TargetOne() + t2 = TargetTwo() + + event.listen(listen_three, "on_event_one", t1) + event.listen(listen_four, "on_event_one", t2) + + eq_( + list(t1.dispatch.on_event_one), + [listen_one, listen_three] + ) + + eq_( + list(t2.dispatch.on_event_one), + [listen_two, listen_four] + ) + +class TestCustomTargets(TestBase): + """Test custom target acceptance.""" + + def setUp(self): + global Target + + class TargetEvents(event.Events): + @classmethod + def accept_with(cls, target): + if target == 'one': + return Target + else: + return None + + def on_event_one(self, x, y): + pass + + class Target(object): + dispatch = event.dispatcher(TargetEvents) + + def tearDown(self): + event._remove_dispatcher(Target.__dict__['dispatch'].events) + + def test_indirect(self): + def listen(x, y): + pass + + event.listen(listen, "on_event_one", "one") + + eq_( + list(Target().dispatch.on_event_one), + [listen] + ) + + assert_raises( + exc.InvalidRequestError, + event.listen, + listen, "on_event_one", Target + ) + +class TestListenOverride(TestBase): + """Test custom listen functions which change the listener function signature.""" + + def setUp(self): + global Target + + class TargetEvents(event.Events): + @classmethod + def listen(cls, fn, identifier, target, add=False): + if add: + def adapt(x, y): + fn(x + y) + else: + adapt = fn + + event.Events.listen(adapt, identifier, target) + + def on_event_one(self, x, y): + pass + + class Target(object): + dispatch = event.dispatcher(TargetEvents) + + def tearDown(self): + event._remove_dispatcher(Target.__dict__['dispatch'].events) + + def test_listen_override(self): + result = [] + def listen_one(x): + result.append(x) + + def listen_two(x, y): + result.append((x, y)) + + event.listen(listen_one, "on_event_one", Target, add=True) + event.listen(listen_two, "on_event_one", Target) + + t1 = Target() + t1.dispatch.on_event_one(5, 7) + t1.dispatch.on_event_one(10, 5) + + eq_(result, + [ + 12, (5, 7), 15, (10, 5) + ] + ) + + + \ No newline at end of file