NO_RETVAL = util.symbol('NO_RETVAL')
def listen(fn, identifier, target, *args, **kw):
- """Listen for events, accepting an event function that's "raw".
- Only the exact arguments are received in order.
-
- This is used by SQLA internals simply to reduce the overhead
- of creating an event dictionary for each event call.
+ """Register a listener function for the given target.
"""
-
- # rationale - the events on ClassManager, Session, and Mapper
- # will need to accept mapped classes directly as targets and know
- # what to do
for evt_cls in _registrars[identifier]:
- for tgt in evt_cls.accept_with(target):
+ tgt = evt_cls.accept_with(target)
+ if tgt is not None:
tgt.dispatch.listen(fn, identifier, tgt, *args, **kw)
return
raise exc.InvalidRequestError("No such event %s for target %s" %
for evt_cls in _registrars[identifier]:
for tgt in evt_cls.accept_with(target):
tgt.dispatch.remove(fn, identifier, tgt, *args, **kw)
-
+ return
_registrars = util.defaultdict(list)
+class _UnpickleDispatch(object):
+ """Serializable callable that re-generates an instance of :class:`_Dispatch`
+ given a particular :class:`.Events` subclass.
+
+ """
+ def __call__(self, parent_cls):
+ return parent_cls.__dict__['dispatch'].dispatch_cls(parent_cls)
+
class _Dispatch(object):
"""Mirror the event listening definitions of an Events class with
listener collections.
self.parent_cls = parent_cls
def __reduce__(self):
- return dispatcher, (
- self.parent_cls.__dict__['dispatch'].events,
- )
+
+ return _UnpickleDispatch(), (self.parent_cls, )
@property
def descriptors(self):
def __init__(cls, classname, bases, dict_):
_create_dispatcher_class(cls, classname, bases, dict_)
return type.__init__(cls, classname, bases, dict_)
-
+
def _create_dispatcher_class(cls, classname, bases, dict_):
# there's all kinds of ways to do this,
# i.e. make a Dispatch class that shares the 'listen' method
setattr(dispatch_cls, k, _DispatchDescriptor(dict_[k]))
_registrars[k].append(cls)
+def _remove_dispatcher(cls):
+ for k in dir(cls):
+ if k.startswith('on_'):
+ _registrars[k].remove(cls)
+ if not _registrars[k]:
+ del _registrars[k]
+
class Events(object):
"""Define event listening functions for a particular target type."""
isinstance(target.dispatch, type) and \
issubclass(target.dispatch, cls.dispatch)
):
- return [target]
+ return target
else:
- return []
+ return None
@classmethod
def listen(cls, fn, identifier, target):
from sqlalchemy.orm.instrumentation import instrumentation_registry
if isinstance(target, type):
- return [instrumentation_registry]
+ return instrumentation_registry
else:
- return []
+ return None
@classmethod
def listen(cls, fn, identifier, target):
from sqlalchemy.orm.instrumentation import ClassManager, manager_of_class
if isinstance(target, ClassManager):
- return [target]
+ return target
elif isinstance(target, type):
manager = manager_of_class(target)
if manager:
- return [manager]
- return []
+ return manager
+ return None
@classmethod
def listen(cls, fn, identifier, target, raw=False):
def unwrap(cls, identifier, event):
return event['value']
- def on_append(self, state, value, initiator):
+ def on_append(self, target, value, initiator):
"""Receive a collection append event.
- The returned value will be used as the actual value to be
- appended.
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being appended. If this listener
+ is registered with ``retval=True``, the listener
+ function must return this value, or a new value which
+ replaces it.
+ :param initiator: the attribute implementation object
+ which initiated this event.
"""
- def on_remove(self, state, value, initiator):
- """Receive a remove event.
+ def on_remove(self, target, value, initiator):
+ """Receive a collection remove event.
- No return value is defined.
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being removed.
+ :param initiator: the attribute implementation object
+ which initiated this event.
"""
- def on_set(self, state, value, oldvalue, initiator):
- """Receive a set event.
-
- The returned value will be used as the actual value to be
- set.
+ def on_set(self, target, value, oldvalue, initiator):
+ """Receive a scalar set event.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being set. If this listener
+ is registered with ``retval=True``, the listener
+ function must return this value, or a new value which
+ replaces it.
+ :param oldvalue: the previous value being replaced. This
+ may also be the symbol ``NEVER_SET`` or ``NO_VALUE``.
+ If the listener is registered with ``active_history=True``,
+ the previous value of the attribute will be loaded from
+ the database if the existing value is currently unloaded
+ or expired.
+ :param initiator: the attribute implementation object
+ which initiated this event.
"""
self._adapted.instrument_attribute(self.class_, key, inst)
def post_configure_attribute(self, key):
- super(_ClassInstrumentationAdpter, self).post_configure_attribute(key)
+ super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
self._adapted.post_configure_attribute(self.class_, key, self[key])
def install_descriptor(self, key, inst):
obj.__dict__.update(self.mutable_dict)
# re-establishes identity attributes from the key
- self.manager.dispatch.on_resurrect(self, obj)
+ self.manager.dispatch.on_resurrect(self)
return obj
--- /dev/null
+"""Test event registration and listening."""
+
+from sqlalchemy.test.testing import TestBase, eq_, assert_raises
+from sqlalchemy import event, exc, util
+
+class TestEvents(TestBase):
+ """Test class- and instance-level event registration."""
+
+ def setUp(self):
+ global Target
+
+ assert 'on_event_one' not in event._registrars
+ assert 'on_event_two' not in event._registrars
+
+ class TargetEvents(event.Events):
+ def on_event_one(self, x, y):
+ pass
+
+ def on_event_two(self, x):
+ pass
+
+ class Target(object):
+ dispatch = event.dispatcher(TargetEvents)
+
+ def tearDown(self):
+ event._remove_dispatcher(Target.__dict__['dispatch'].events)
+
+ def test_register_class(self):
+ def listen(x, y):
+ pass
+
+ event.listen(listen, "on_event_one", Target)
+
+ eq_(len(Target().dispatch.on_event_one), 1)
+ eq_(len(Target().dispatch.on_event_two), 0)
+
+ def test_register_instance(self):
+ def listen(x, y):
+ pass
+
+ t1 = Target()
+ event.listen(listen, "on_event_one", t1)
+
+ eq_(len(Target().dispatch.on_event_one), 0)
+ eq_(len(t1.dispatch.on_event_one), 1)
+ eq_(len(Target().dispatch.on_event_two), 0)
+ eq_(len(t1.dispatch.on_event_two), 0)
+
+ def test_register_class_instance(self):
+ def listen_one(x, y):
+ pass
+
+ def listen_two(x, y):
+ pass
+
+ event.listen(listen_one, "on_event_one", Target)
+
+ t1 = Target()
+ event.listen(listen_two, "on_event_one", t1)
+
+ eq_(len(Target().dispatch.on_event_one), 1)
+ eq_(len(t1.dispatch.on_event_one), 2)
+ eq_(len(Target().dispatch.on_event_two), 0)
+ eq_(len(t1.dispatch.on_event_two), 0)
+
+ def listen_three(x, y):
+ pass
+
+ event.listen(listen_three, "on_event_one", Target)
+ eq_(len(Target().dispatch.on_event_one), 2)
+ eq_(len(t1.dispatch.on_event_one), 3)
+
+class TestAcceptTargets(TestBase):
+ """Test default target acceptance."""
+
+ def setUp(self):
+ global TargetOne, TargetTwo
+
+ class TargetEventsOne(event.Events):
+ def on_event_one(self, x, y):
+ pass
+
+ class TargetEventsTwo(event.Events):
+ def on_event_one(self, x, y):
+ pass
+
+ class TargetOne(object):
+ dispatch = event.dispatcher(TargetEventsOne)
+
+ class TargetTwo(object):
+ dispatch = event.dispatcher(TargetEventsTwo)
+
+ def tearDown(self):
+ event._remove_dispatcher(TargetOne.__dict__['dispatch'].events)
+ event._remove_dispatcher(TargetTwo.__dict__['dispatch'].events)
+
+ def test_target_accept(self):
+ """Test that events of the same name are routed to the correct
+ collection based on the type of target given.
+
+ """
+ def listen_one(x, y):
+ pass
+
+ def listen_two(x, y):
+ pass
+
+ def listen_three(x, y):
+ pass
+
+ def listen_four(x, y):
+ pass
+
+ event.listen(listen_one, "on_event_one", TargetOne)
+ event.listen(listen_two, "on_event_one", TargetTwo)
+
+ eq_(
+ list(TargetOne().dispatch.on_event_one),
+ [listen_one]
+ )
+
+ eq_(
+ list(TargetTwo().dispatch.on_event_one),
+ [listen_two]
+ )
+
+ t1 = TargetOne()
+ t2 = TargetTwo()
+
+ event.listen(listen_three, "on_event_one", t1)
+ event.listen(listen_four, "on_event_one", t2)
+
+ eq_(
+ list(t1.dispatch.on_event_one),
+ [listen_one, listen_three]
+ )
+
+ eq_(
+ list(t2.dispatch.on_event_one),
+ [listen_two, listen_four]
+ )
+
+class TestCustomTargets(TestBase):
+ """Test custom target acceptance."""
+
+ def setUp(self):
+ global Target
+
+ class TargetEvents(event.Events):
+ @classmethod
+ def accept_with(cls, target):
+ if target == 'one':
+ return Target
+ else:
+ return None
+
+ def on_event_one(self, x, y):
+ pass
+
+ class Target(object):
+ dispatch = event.dispatcher(TargetEvents)
+
+ def tearDown(self):
+ event._remove_dispatcher(Target.__dict__['dispatch'].events)
+
+ def test_indirect(self):
+ def listen(x, y):
+ pass
+
+ event.listen(listen, "on_event_one", "one")
+
+ eq_(
+ list(Target().dispatch.on_event_one),
+ [listen]
+ )
+
+ assert_raises(
+ exc.InvalidRequestError,
+ event.listen,
+ listen, "on_event_one", Target
+ )
+
+class TestListenOverride(TestBase):
+ """Test custom listen functions which change the listener function signature."""
+
+ def setUp(self):
+ global Target
+
+ class TargetEvents(event.Events):
+ @classmethod
+ def listen(cls, fn, identifier, target, add=False):
+ if add:
+ def adapt(x, y):
+ fn(x + y)
+ else:
+ adapt = fn
+
+ event.Events.listen(adapt, identifier, target)
+
+ def on_event_one(self, x, y):
+ pass
+
+ class Target(object):
+ dispatch = event.dispatcher(TargetEvents)
+
+ def tearDown(self):
+ event._remove_dispatcher(Target.__dict__['dispatch'].events)
+
+ def test_listen_override(self):
+ result = []
+ def listen_one(x):
+ result.append(x)
+
+ def listen_two(x, y):
+ result.append((x, y))
+
+ event.listen(listen_one, "on_event_one", Target, add=True)
+ event.listen(listen_two, "on_event_one", Target)
+
+ t1 = Target()
+ t1.dispatch.on_event_one(5, 7)
+ t1.dispatch.on_event_one(10, 5)
+
+ eq_(result,
+ [
+ 12, (5, 7), 15, (10, 5)
+ ]
+ )
+
+
+
\ No newline at end of file