]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- begin adding tests for event registration and dispatch standalone
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Oct 2010 21:22:37 +0000 (17:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Oct 2010 21:22:37 +0000 (17:22 -0400)
- fix pickling again
- other test fixes

lib/sqlalchemy/event.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/orm/state.py
test/base/test_events.py [new file with mode: 0644]

index c7df2bf4859630cc4d6cc072251fe8b2a0a3bb72..0f6342e6b2844171772119bcfed373a4936a5e24 100644 (file)
@@ -11,20 +11,13 @@ CANCEL = util.symbol('CANCEL')
 NO_RETVAL = util.symbol('NO_RETVAL')
 
 def listen(fn, identifier, target, *args, **kw):
-    """Listen for events, accepting an event function that's "raw".
-    Only the exact arguments are received in order.
-    
-    This is used by SQLA internals simply to reduce the overhead
-    of creating an event dictionary for each event call.
+    """Register a listener function for the given target.
     
     """
-
-    # rationale - the events on ClassManager, Session, and Mapper
-    # will need to accept mapped classes directly as targets and know 
-    # what to do
     
     for evt_cls in _registrars[identifier]:
-        for tgt in evt_cls.accept_with(target):
+        tgt = evt_cls.accept_with(target)
+        if tgt is not None:
             tgt.dispatch.listen(fn, identifier, tgt, *args, **kw)
             return
     raise exc.InvalidRequestError("No such event %s for target %s" %
@@ -41,10 +34,18 @@ def remove(fn, identifier, target):
     for evt_cls in _registrars[identifier]:
         for tgt in evt_cls.accept_with(target):
             tgt.dispatch.remove(fn, identifier, tgt, *args, **kw)
-
+            return
         
 _registrars = util.defaultdict(list)
 
+class _UnpickleDispatch(object):
+    """Serializable callable that re-generates an instance of :class:`_Dispatch`
+    given a particular :class:`.Events` subclass.
+    
+    """
+    def __call__(self, parent_cls):
+        return parent_cls.__dict__['dispatch'].dispatch_cls(parent_cls)
+        
 class _Dispatch(object):
     """Mirror the event listening definitions of an Events class with 
     listener collections.
@@ -55,9 +56,8 @@ class _Dispatch(object):
         self.parent_cls = parent_cls
     
     def __reduce__(self):
-        return dispatcher, (
-                            self.parent_cls.__dict__['dispatch'].events, 
-                            )
+        
+        return _UnpickleDispatch(), (self.parent_cls, )
         
     @property
     def descriptors(self):
@@ -81,7 +81,7 @@ class _EventMeta(type):
     def __init__(cls, classname, bases, dict_):
         _create_dispatcher_class(cls, classname, bases, dict_)
         return type.__init__(cls, classname, bases, dict_)
-        
+    
 def _create_dispatcher_class(cls, classname, bases, dict_):
     # there's all kinds of ways to do this,
     # i.e. make a Dispatch class that shares the 'listen' method
@@ -96,6 +96,13 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
             setattr(dispatch_cls, k, _DispatchDescriptor(dict_[k]))
             _registrars[k].append(cls)
 
+def _remove_dispatcher(cls):
+    for k in dir(cls):
+        if k.startswith('on_'):
+            _registrars[k].remove(cls)
+            if not _registrars[k]:
+                del _registrars[k]
+    
 class Events(object):
     """Define event listening functions for a particular target type."""
     
@@ -111,9 +118,9 @@ class Events(object):
                     isinstance(target.dispatch, type) and \
                     issubclass(target.dispatch, cls.dispatch)
                 ):
-            return [target]
+            return target
         else:
-            return []
+            return None
 
     @classmethod
     def listen(cls, fn, identifier, target):
index 29274ac3b56563e656624c17c13666e83e07e31d..ff9f7dbc6ca88ffa553e95b12fc4fac2664b6519 100644 (file)
@@ -20,9 +20,9 @@ class InstrumentationEvents(event.Events):
         from sqlalchemy.orm.instrumentation import instrumentation_registry
         
         if isinstance(target, type):
-            return [instrumentation_registry]
+            return instrumentation_registry
         else:
-            return []
+            return None
 
     @classmethod
     def listen(cls, fn, identifier, target):
@@ -64,12 +64,12 @@ class InstanceEvents(event.Events):
         from sqlalchemy.orm.instrumentation import ClassManager, manager_of_class
         
         if isinstance(target, ClassManager):
-            return [target]
+            return target
         elif isinstance(target, type):
             manager = manager_of_class(target)
             if manager:
-                return [manager]
-        return []
+                return manager
+        return None
     
     @classmethod
     def listen(cls, fn, identifier, target, raw=False):
@@ -185,26 +185,51 @@ class AttributeEvents(event.Events):
     def unwrap(cls, identifier, event):
         return event['value']
         
-    def on_append(self, state, value, initiator):
+    def on_append(self, target, value, initiator):
         """Receive a collection append event.
 
-        The returned value will be used as the actual value to be
-        appended.
+        :param target: the object instance receiving the event.
+          If the listener is registered with ``raw=True``, this will
+          be the :class:`.InstanceState` object.
+        :param value: the value being appended.  If this listener
+          is registered with ``retval=True``, the listener
+          function must return this value, or a new value which 
+          replaces it.
+        :param initiator: the attribute implementation object 
+          which initiated this event.
 
         """
 
-    def on_remove(self, state, value, initiator):
-        """Receive a remove event.
+    def on_remove(self, target, value, initiator):
+        """Receive a collection remove event.
 
-        No return value is defined.
+        :param target: the object instance receiving the event.
+          If the listener is registered with ``raw=True``, this will
+          be the :class:`.InstanceState` object.
+        :param value: the value being removed.
+        :param initiator: the attribute implementation object 
+          which initiated this event.
 
         """
 
-    def on_set(self, state, value, oldvalue, initiator):
-        """Receive a set event.
-
-        The returned value will be used as the actual value to be
-        set.
+    def on_set(self, target, value, oldvalue, initiator):
+        """Receive a scalar set event.
+
+        :param target: the object instance receiving the event.
+          If the listener is registered with ``raw=True``, this will
+          be the :class:`.InstanceState` object.
+        :param value: the value being set.  If this listener
+          is registered with ``retval=True``, the listener
+          function must return this value, or a new value which 
+          replaces it.
+        :param oldvalue: the previous value being replaced.  This
+          may also be the symbol ``NEVER_SET`` or ``NO_VALUE``.
+          If the listener is registered with ``active_history=True``,
+          the previous value of the attribute will be loaded from
+          the database if the existing value is currently unloaded 
+          or expired.
+        :param initiator: the attribute implementation object 
+          which initiated this event.
 
         """
 
index 02ba5e1a22d1a1d05732a304921019f0764ebce3..52c1c7213c061c8621156ffe4cfc9de533dc72ee 100644 (file)
@@ -361,7 +361,7 @@ class _ClassInstrumentationAdapter(ClassManager):
             self._adapted.instrument_attribute(self.class_, key, inst)
 
     def post_configure_attribute(self, key):
-        super(_ClassInstrumentationAdpter, self).post_configure_attribute(key)
+        super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
         self._adapted.post_configure_attribute(self.class_, key, self[key])
 
     def install_descriptor(self, key, inst):
index 42fc5b98ee63db6e2165a26b50a1d89f1bf6c386..dc8a07c177d6704189b6bfdae3abf3b76340d1db 100644 (file)
@@ -509,7 +509,7 @@ class MutableAttrInstanceState(InstanceState):
         obj.__dict__.update(self.mutable_dict)
 
         # re-establishes identity attributes from the key
-        self.manager.dispatch.on_resurrect(self, obj)
+        self.manager.dispatch.on_resurrect(self)
         
         return obj
 
diff --git a/test/base/test_events.py b/test/base/test_events.py
new file mode 100644 (file)
index 0000000..9099619
--- /dev/null
@@ -0,0 +1,231 @@
+"""Test event registration and listening."""
+
+from sqlalchemy.test.testing import TestBase, eq_, assert_raises
+from sqlalchemy import event, exc, util
+
+class TestEvents(TestBase):
+    """Test class- and instance-level event registration."""
+    
+    def setUp(self):
+        global Target
+        
+        assert 'on_event_one' not in event._registrars
+        assert 'on_event_two' not in event._registrars
+        
+        class TargetEvents(event.Events):
+            def on_event_one(self, x, y):
+                pass
+            
+            def on_event_two(self, x):
+                pass
+                
+        class Target(object):
+            dispatch = event.dispatcher(TargetEvents)
+    
+    def tearDown(self):
+        event._remove_dispatcher(Target.__dict__['dispatch'].events)
+    
+    def test_register_class(self):
+        def listen(x, y):
+            pass
+        
+        event.listen(listen, "on_event_one", Target)
+        
+        eq_(len(Target().dispatch.on_event_one), 1)
+        eq_(len(Target().dispatch.on_event_two), 0)
+
+    def test_register_instance(self):
+        def listen(x, y):
+            pass
+        
+        t1 = Target()
+        event.listen(listen, "on_event_one", t1)
+
+        eq_(len(Target().dispatch.on_event_one), 0)
+        eq_(len(t1.dispatch.on_event_one), 1)
+        eq_(len(Target().dispatch.on_event_two), 0)
+        eq_(len(t1.dispatch.on_event_two), 0)
+    
+    def test_register_class_instance(self):
+        def listen_one(x, y):
+            pass
+
+        def listen_two(x, y):
+            pass
+
+        event.listen(listen_one, "on_event_one", Target)
+        
+        t1 = Target()
+        event.listen(listen_two, "on_event_one", t1)
+
+        eq_(len(Target().dispatch.on_event_one), 1)
+        eq_(len(t1.dispatch.on_event_one), 2)
+        eq_(len(Target().dispatch.on_event_two), 0)
+        eq_(len(t1.dispatch.on_event_two), 0)
+        
+        def listen_three(x, y):
+            pass
+        
+        event.listen(listen_three, "on_event_one", Target)
+        eq_(len(Target().dispatch.on_event_one), 2)
+        eq_(len(t1.dispatch.on_event_one), 3)
+        
+class TestAcceptTargets(TestBase):
+    """Test default target acceptance."""
+    
+    def setUp(self):
+        global TargetOne, TargetTwo
+        
+        class TargetEventsOne(event.Events):
+            def on_event_one(self, x, y):
+                pass
+
+        class TargetEventsTwo(event.Events):
+            def on_event_one(self, x, y):
+                pass
+                
+        class TargetOne(object):
+            dispatch = event.dispatcher(TargetEventsOne)
+
+        class TargetTwo(object):
+            dispatch = event.dispatcher(TargetEventsTwo)
+    
+    def tearDown(self):
+        event._remove_dispatcher(TargetOne.__dict__['dispatch'].events)
+        event._remove_dispatcher(TargetTwo.__dict__['dispatch'].events)
+        
+    def test_target_accept(self):
+        """Test that events of the same name are routed to the correct
+        collection based on the type of target given.
+        
+        """
+        def listen_one(x, y):
+            pass
+
+        def listen_two(x, y):
+            pass
+        
+        def listen_three(x, y):
+            pass
+
+        def listen_four(x, y):
+            pass
+            
+        event.listen(listen_one, "on_event_one", TargetOne)
+        event.listen(listen_two, "on_event_one", TargetTwo)
+        
+        eq_(
+            list(TargetOne().dispatch.on_event_one),
+            [listen_one]
+        )
+
+        eq_(
+            list(TargetTwo().dispatch.on_event_one),
+            [listen_two]
+        )
+
+        t1 = TargetOne()
+        t2 = TargetTwo()
+
+        event.listen(listen_three, "on_event_one", t1)
+        event.listen(listen_four, "on_event_one", t2)
+        
+        eq_(
+            list(t1.dispatch.on_event_one),
+            [listen_one, listen_three]
+        )
+
+        eq_(
+            list(t2.dispatch.on_event_one),
+            [listen_two, listen_four]
+        )
+        
+class TestCustomTargets(TestBase):
+    """Test custom target acceptance."""
+
+    def setUp(self):
+        global Target
+        
+        class TargetEvents(event.Events):
+            @classmethod
+            def accept_with(cls, target):
+                if target == 'one':
+                    return Target
+                else:
+                    return None
+                    
+            def on_event_one(self, x, y):
+                pass
+
+        class Target(object):
+            dispatch = event.dispatcher(TargetEvents)
+
+    def tearDown(self):
+        event._remove_dispatcher(Target.__dict__['dispatch'].events)
+    
+    def test_indirect(self):
+        def listen(x, y):
+            pass
+        
+        event.listen(listen, "on_event_one", "one")
+
+        eq_(
+            list(Target().dispatch.on_event_one),
+            [listen]
+        )
+        
+        assert_raises(
+            exc.InvalidRequestError, 
+            event.listen,
+            listen, "on_event_one", Target
+        )
+        
+class TestListenOverride(TestBase):
+    """Test custom listen functions which change the listener function signature."""
+    
+    def setUp(self):
+        global Target
+        
+        class TargetEvents(event.Events):
+            @classmethod
+            def listen(cls, fn, identifier, target, add=False):
+                if add:
+                    def adapt(x, y):
+                        fn(x + y)
+                else:
+                    adapt = fn
+                    
+                event.Events.listen(adapt, identifier, target)
+                    
+            def on_event_one(self, x, y):
+                pass
+
+        class Target(object):
+            dispatch = event.dispatcher(TargetEvents)
+
+    def tearDown(self):
+        event._remove_dispatcher(Target.__dict__['dispatch'].events)
+    
+    def test_listen_override(self):
+        result = []
+        def listen_one(x):
+            result.append(x)
+            
+        def listen_two(x, y):
+            result.append((x, y))
+        
+        event.listen(listen_one, "on_event_one", Target, add=True)
+        event.listen(listen_two, "on_event_one", Target)
+
+        t1 = Target()
+        t1.dispatch.on_event_one(5, 7)
+        t1.dispatch.on_event_one(10, 5)
+        
+        eq_(result,
+            [
+                12, (5, 7), 15, (10, 5)
+            ]
+        )
+        
+        
+        
\ No newline at end of file