]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- propagate flag on event.listen() results in the listener being placed
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Nov 2010 17:49:48 +0000 (12:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Nov 2010 17:49:48 +0000 (12:49 -0500)
in a separate collection.  this collection also propagates during update()
- ClassManager now handles bases, subclasses collections.
- ClassManager looks at __bases__ instead of __mro__ for superclasses.
It's assumed ClassManagers are in an unbroken chain upwards through __mro__.
- trying to get a clear() that makes sense on cls.dispatch
- implemented propagate for attribute events, plus permutation-based test
- implemented propagate for mapper / instance events with rudimentary test
- some pool events tests are failing for some reason

lib/sqlalchemy/event.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/instrumentation.py
lib/sqlalchemy/test/util.py
test/base/test_dependency.py
test/base/test_events.py
test/orm/test_attributes.py
test/orm/test_mapper.py

index 955c33797c0c63685f06930911c5ca8cd719b674..75512f7d2902e359b57f871ba0645981e492f06a 100644 (file)
@@ -18,8 +18,6 @@ def listen(fn, identifier, target, *args, **kw):
     for evt_cls in _registrars[identifier]:
         tgt = evt_cls.accept_with(target)
         if tgt is not None:
-            if kw.pop('propagate', False):
-                fn._sa_event_propagate = True
             tgt.dispatch.listen(fn, identifier, tgt, *args, **kw)
             return
     raise exc.InvalidRequestError("No such event %s for target %s" %
@@ -37,7 +35,7 @@ def remove(fn, identifier, target):
         for tgt in evt_cls.accept_with(target):
             tgt.dispatch.remove(fn, identifier, tgt, *args, **kw)
             return
-        
+
 _registrars = util.defaultdict(list)
 
 class _UnpickleDispatch(object):
@@ -60,7 +58,7 @@ class _Dispatch(object):
     def __reduce__(self):
         
         return _UnpickleDispatch(), (self.parent_cls, )
-        
+    
     @property
     def descriptors(self):
         return (getattr(self, k) for k in dir(self) if k.startswith("on_"))
@@ -71,7 +69,8 @@ class _Dispatch(object):
 
         for ls in other.descriptors:
             getattr(self, ls.name).update(ls, only_propagate=only_propagate)
-
+            
+            
 class _EventMeta(type):
     """Intercept new Event subclasses and create 
     associated _Dispatch classes."""
@@ -88,7 +87,8 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
     cls.dispatch = dispatch_cls = type("%sDispatch" % classname, 
                                         (dispatch_base, ), {})
     dispatch_cls.listen = cls.listen
-
+    dispatch_cls.clear = cls.clear
+    
     for k in dict_:
         if k.startswith('on_'):
             setattr(dispatch_cls, k, _DispatchDescriptor(dict_[k]))
@@ -121,13 +121,19 @@ class Events(object):
             return None
 
     @classmethod
-    def listen(cls, fn, identifier, target):
-        getattr(target.dispatch, identifier).append(fn, target)
+    def listen(cls, fn, identifier, target, propagate=False):
+        getattr(target.dispatch, identifier).append(fn, target, propagate)
     
     @classmethod
     def remove(cls, fn, identifier, target):
         getattr(target.dispatch, identifier).remove(fn, target)
-        
+    
+    @classmethod
+    def clear(cls):
+        for attr in dir(cls.dispatch):
+            if attr.startswith("on_"):
+                getattr(cls.dispatch, attr).clear()
+                
 class _DispatchDescriptor(object):
     """Class-level attributes on _Dispatch classes."""
     
@@ -136,7 +142,7 @@ class _DispatchDescriptor(object):
         self.__doc__ = fn.__doc__
         self._clslevel = util.defaultdict(list)
     
-    def append(self, obj, target):
+    def append(self, obj, target, propagate):
         assert isinstance(target, type), \
                 "Class-level Event targets must be classes."
                 
@@ -146,7 +152,13 @@ class _DispatchDescriptor(object):
     def remove(self, obj, target):
         for cls in [target] + target.__subclasses__():
             self._clslevel[cls].remove(obj)
+    
+    def clear(self):
+        """Clear all class level listeners"""
         
+        for dispatcher in self._clslevel.values():
+            dispatcher[:] = []
+            
     def __get__(self, obj, cls):
         if obj is None:
             return self
@@ -166,7 +178,8 @@ class _ListenerCollection(object):
         self.parent_listeners = parent._clslevel[target_cls]
         self.name = parent.__name__
         self.listeners = []
-
+        self.propagate = set()
+        
     def exec_once(self, *args, **kw):
         """Execute this event, but only if it has not been
         executed already for this collection."""
@@ -200,22 +213,30 @@ class _ListenerCollection(object):
     def update(self, other, only_propagate=True):
         """Populate from the listeners in another :class:`_Dispatch`
             object."""
-
+        
         existing_listeners = self.listeners
         existing_listener_set = set(existing_listeners)
+        self.propagate.update(other.propagate)
         existing_listeners.extend([l for l 
                                 in other.listeners 
                                 if l not in existing_listener_set
-                                and not only_propagate or getattr(l, '_sa_event_propagate', False)
+                                and not only_propagate or l in self.propagate
                                 ])
 
-    def append(self, obj, target):
+    def append(self, obj, target, propagate):
         if obj not in self.listeners:
             self.listeners.append(obj)
+            if propagate:
+                self.propagate.add(obj)
     
     def remove(self, obj, target):
         if obj in self.listeners:
             self.listeners.remove(obj)
+            self.propagate.discard(obj)
+    
+    def clear(self):
+        self.listeners[:] = []
+        self.propagate.clear()
         
 class dispatcher(object):
     """Descriptor used by target classes to 
index 6bc15d8f5efdef7462296b8a4f6e8ca075120d93..fcaabfdddb8e1050a530e0526d2a286bdaae74ce 100644 (file)
@@ -55,12 +55,22 @@ PASSIVE_OFF = False #util.symbol('PASSIVE_OFF')
 class QueryableAttribute(interfaces.PropComparator):
     """Base class for class-bound attributes. """
     
-    def __init__(self, class_, key, impl=None, comparator=None, parententity=None):
+    def __init__(self, class_, key, impl=None, 
+                        comparator=None, parententity=None):
         self.class_ = class_
         self.key = key
         self.impl = impl
         self.comparator = comparator
         self.parententity = parententity
+        
+        manager = manager_of_class(class_)
+        # manager is None in the case of AliasedClass
+        if manager:
+            # propagate existing event listeners from 
+            # immediate superclass
+            for base in manager._bases:
+                if key in base:
+                    self.dispatch.update(base[key].dispatch)
 
     dispatch = event.dispatcher(events.AttributeEvents)
     dispatch.dispatch_cls.active_history = False
index 36c12cf8aea2786c3b6849d20768d27430841766..205ec32355af33d92f781d965cadf170c2c45f23 100644 (file)
@@ -1,7 +1,7 @@
 """ORM event interfaces.
 
 """
-from sqlalchemy import event, util, exc
+from sqlalchemy import event, exc
 import inspect
 
 class InstrumentationEvents(event.Events):
@@ -26,13 +26,8 @@ class InstrumentationEvents(event.Events):
             return None
 
     @classmethod
-    def listen(cls, fn, identifier, target):
-        
-        @util.decorator
-        def adapt_to_target(fn, cls, *arg):
-            if issubclass(cls, target):
-                fn(cls, *arg)
-        event.Events.listen(fn, identifier, target)
+    def listen(cls, fn, identifier, target, propagate=False):
+        event.Events.listen(fn, identifier, target, propagate=propagate)
 
     @classmethod
     def remove(cls, fn, identifier, target):
@@ -68,9 +63,14 @@ class InstanceEvents(event.Events):
     @classmethod
     def accept_with(cls, target):
         from sqlalchemy.orm.instrumentation import ClassManager, manager_of_class
+        from sqlalchemy.orm import Mapper, mapper
         
         if isinstance(target, ClassManager):
             return target
+        elif isinstance(target, Mapper):
+            return target.class_manager
+        elif target is Mapper or target is mapper:
+            return ClassManager
         elif isinstance(target, type):
             manager = manager_of_class(target)
             if manager:
@@ -78,14 +78,18 @@ class InstanceEvents(event.Events):
         return None
     
     @classmethod
-    def listen(cls, fn, identifier, target, raw=False):
+    def listen(cls, fn, identifier, target, raw=False, propagate=False):
         if not raw:
             orig_fn = fn
             def wrap(state, *arg, **kw):
                 return orig_fn(state.obj(), *arg, **kw)
             fn = wrap
-        event.Events.listen(fn, identifier, target)
-        
+
+        event.Events.listen(fn, identifier, target, propagate=propagate)
+        if propagate:
+            for mgr in target.subclass_managers(True):
+                event.Events.listen(fn, identifier, mgr, True)
+            
     @classmethod
     def remove(cls, fn, identifier, target):
         raise NotImplementedError("Removal of instance events not yet implemented")
@@ -148,6 +152,8 @@ class MapperEvents(event.Events):
     
     Several modifiers are available to the listen() function.
     
+    :param propagate=False: When True, the event listener should 
+      be applied to all inheriting mappers as well.
     :param raw=False: When True, the "target" argument to the
       event, if applicable will be the :class:`.InstanceState` management
       object, rather than the mapped instance itself.
@@ -178,7 +184,7 @@ class MapperEvents(event.Events):
         
     @classmethod
     def listen(cls, fn, identifier, target, 
-                                        raw=False, retval=False):
+                            raw=False, retval=False, propagate=False):
         from sqlalchemy.orm.interfaces import EXT_CONTINUE
 
         if not raw or not retval:
@@ -201,9 +207,11 @@ class MapperEvents(event.Events):
                     return wrapped_fn(*arg, **kw)
             fn = wrap
         
-        for mapper in target.self_and_descendants:
-            event.Events.listen(fn, identifier, mapper)
-    
+        if propagate:
+            for mapper in target.self_and_descendants:
+                event.Events.listen(fn, identifier, mapper, propagate=True)
+        else:
+            event.Events.listen(fn, identifier, self)
         
     def on_instrument_class(self, mapper, class_):
         """Receive a class when the mapper is first constructed, and has
@@ -437,32 +445,26 @@ class AttributeEvents(event.Events):
         # of the wrapper with the original function.
         
         if not raw or not retval:
-            @util.decorator
-            def wrap(fn, target, value, *arg):
+            orig_fn = fn
+            def wrap(target, value, *arg):
                 if not raw:
                     target = target.obj()
                 if not retval:
-                    fn(target, value, *arg)
+                    orig_fn(target, value, *arg)
                     return value
                 else:
-                    return fn(target, value, *arg)
-            fn = wrap(fn)
+                    return orig_fn(target, value, *arg)
+            fn = wrap
             
-        event.Events.listen(fn, identifier, target)
+        event.Events.listen(fn, identifier, target, propagate)
         
         if propagate:
-
-            raise NotImplementedError()
-
-            # TODO: for removal, need to implement
-            # packaging this info for operation in reverse.
-
-            class_ = target.class_
-            for cls in class_.__subclasses__():
-                impl = getattr(cls, target.key)
-                if impl is not target:
-                    event.Events.listen(fn, identifier, impl)
+            from sqlalchemy.orm.instrumentation import manager_of_class
+            
+            manager = manager_of_class(target.class_)
             
+            for mgr in manager.subclass_managers(True):
+                event.Events.listen(fn, identifier, mgr[target.key], True)
         
     @classmethod
     def remove(cls, fn, identifier, target):
index 6e357e15795c5f77dd66dd3b73d626dab75e07cc..dba3a683072742af101b8844b2c71939eb5b74f2 100644 (file)
@@ -83,12 +83,16 @@ class ClassManager(dict):
         self.mutable_attributes = set()
         self.local_attrs = {}
         self.originals = {}
-        for base in class_.__mro__[-2:0:-1]:  # reverse, skipping 1st and last
-            if not isinstance(base, type):
-                continue
-            cls_state = manager_of_class(base)
-            if cls_state:
-                self.update(cls_state)
+
+        self._bases = [mgr for mgr in [
+                        manager_of_class(base) 
+                        for base in self.class_.__bases__
+                        if isinstance(base, type)
+                 ] if mgr is not None]
+
+        for base in self._bases:
+            self.update(base)
+
         self.manage()
         self._instrument_init()
     
@@ -194,6 +198,15 @@ class ClassManager(dict):
             manager = self._subclass_manager(cls)
             manager.instrument_attribute(key, inst, True)
 
+    def subclass_managers(self, recursive):
+        for cls in self.class_.__subclasses__():
+            mgr = manager_of_class(cls)
+            if mgr is not None and mgr is not self:
+                yield mgr
+                if recursive:
+                    for m in mgr.subclass_managers(True):
+                        yield m
+        
     def post_configure_attribute(self, key):
         instrumentation_registry.dispatch.\
                 on_attribute_instrument(self.class_, key, self[key])
index ff2c3d7b79bbc373a79e7b3b6729e0a1dec59f03..98667d8c26f5d232dddfb5ba91598155e90dcf11 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy.util import jython, function_named
+from sqlalchemy.util import jython, function_named, defaultdict
 
 import gc
 import time
@@ -75,4 +75,35 @@ class RandomSet(set):
         
     def copy(self):
         return RandomSet(self)
-        
\ No newline at end of file
+        
+def conforms_partial_ordering(tuples, sorted_elements):
+    """True if the given sorting conforms to the given partial ordering."""
+    
+    deps = defaultdict(set)
+    for parent, child in tuples:
+        deps[parent].add(child)
+    for i, node in enumerate(sorted_elements):
+        for n in sorted_elements[i:]:
+            if node in deps[n]:
+                return False
+    else:
+        return True
+
+def all_partial_orderings(tuples, elements):
+    edges = defaultdict(set)
+    for parent, child in tuples:
+        edges[child].add(parent)
+
+    def _all_orderings(elements):
+
+        if len(elements) == 1:
+            yield list(elements)
+        else:
+            for elem in elements:
+                subset = set(elements).difference([elem])
+                if not subset.intersection(edges[elem]):
+                    for sub_ordering in _all_orderings(subset):
+                        yield [elem] + sub_ordering
+    
+    return iter(_all_orderings(elements))
+
index 9fddfc47ffedeef31ff4b1d21d80a0a4daac7fa9..605a16cc340d8e278a4d135f73c8aab619b5f7dc 100644 (file)
@@ -1,6 +1,7 @@
 import sqlalchemy.topological as topological
 from sqlalchemy.test import TestBase
 from sqlalchemy.test.testing import assert_raises, eq_
+from sqlalchemy.test.util import conforms_partial_ordering
 from sqlalchemy import exc, util
 
 
@@ -12,13 +13,7 @@ class DependencySortTest(TestBase):
         else:
             allitems = self._nodes_from_tuples(tuples).union(allitems)
         result = list(topological.sort(tuples, allitems))
-        deps = util.defaultdict(set)
-        for parent, child in tuples:
-            deps[parent].add(child)
-        assert len(result)
-        for i, node in enumerate(result):
-            for n in result[i:]:
-                assert node not in deps[n]
+        assert conforms_partial_ordering(tuples, result)
 
     def _nodes_from_tuples(self, tups):
         s = set()
index 995569c60ff2f9217d05ff98bd45c0ba7340cda0..2920598775e29283642d4e3d419fb168efcac59e 100644 (file)
@@ -262,17 +262,3 @@ class TestPropagate(TestBase):
         t2.dispatch.on_event_one(t2, 1)
         t2.dispatch.on_event_two(t2, 2)
         eq_(result, [(t2, 1)])
-        
-        
-        
-            
-        
-        
-        
-        
-        
-        
-        
-    
-    
-        
\ No newline at end of file
index 60f9c6b9ee8dc37612ef4ef6cb3a391dce420514..b5a6c1f5e9739baf5f4ccb0bea3469b779b51c44 100644 (file)
@@ -6,8 +6,9 @@ from sqlalchemy import exc as sa_exc
 from sqlalchemy.test import *
 from sqlalchemy.test.testing import eq_, ne_, assert_raises
 from test.orm import _base
-from sqlalchemy.test.util import gc_collect
+from sqlalchemy.test.util import gc_collect, all_partial_orderings
 from sqlalchemy.util import cmp, jython
+from sqlalchemy import event, topological
 
 # global for pickling tests
 MyTest = None
@@ -1569,33 +1570,32 @@ class HistoryTest(_base.ORMTest):
 
 class ListenerTest(_base.ORMTest):
     def test_receive_changes(self):
-        """test that Listeners can mutate the given value.
+        """test that Listeners can mutate the given value."""
         
-        This is a rudimentary test which would be better suited by a full-blown inclusion
-        into collection.py.
-        
-        """
         class Foo(object):
             pass
         class Bar(object):
             pass
+            
+        def on_append(state, child, initiator):
+            b2 = Bar()
+            b2.data = b1.data + " appended"
+            return b2
 
-        class AlteringListener(AttributeExtension):
-            def append(self, state, child, initiator):
-                b2 = Bar()
-                b2.data = b1.data + " appended"
-                return b2
-
-            def set(self, state, value, oldvalue, initiator):
-                return value + " modified"
+        def on_set(state, value, oldvalue, initiator):
+            return value + " modified"
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener())
-        attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener())
-        attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener())
+        attributes.register_attribute(Foo, 'data', uselist=False, useobject=False)
+        attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True)
+        attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True)
         attributes.register_attribute(Bar, 'data', uselist=False, useobject=False)
         
+        event.listen(on_set, 'on_set', Foo.data, retval=True)
+        event.listen(on_append, 'on_append', Foo.barlist, retval=True)
+        event.listen(on_append, 'on_append', Foo.barset, retval=True)
+        
         f1 = Foo()
         f1.data = "some data"
         eq_(f1.data, "some data modified")
@@ -1608,4 +1608,84 @@ class ListenerTest(_base.ORMTest):
         f1.barset.add(b1)
         assert f1.barset.pop().data == "some bar appended"
     
-    
+    def test_propagate(self):
+        classes = [None, None, None]
+        canary = []
+        def make_a():
+            class A(object):
+                pass
+            classes[0] = A
+            
+        def make_b():
+            class B(classes[0]):
+                pass
+            classes[1] = B
+        
+        def make_c():
+            class C(classes[1]):
+                pass
+            classes[2] = C
+            
+        def instrument_a():
+            instrumentation.register_class(classes[0])
+
+        def instrument_b():
+            instrumentation.register_class(classes[1])
+        
+        def instrument_c():
+            instrumentation.register_class(classes[2])
+            
+        def attr_a():
+            attributes.register_attribute(classes[0], 'attrib', uselist=False, useobject=False)
+        
+        def attr_b():
+            attributes.register_attribute(classes[1], 'attrib', uselist=False, useobject=False)
+
+        def attr_c():
+            attributes.register_attribute(classes[2], 'attrib', uselist=False, useobject=False)
+        
+        def on_set(state, value, oldvalue, initiator):
+            canary.append(value)
+            
+        def events_a():
+            event.listen(on_set, 'on_set', classes[0].attrib, propagate=True)
+        
+        def teardown():
+            classes[:] = [None, None, None]
+            canary[:] = []
+        
+        ordering = [
+            (instrument_a, instrument_b),
+            (instrument_b, instrument_c),
+            (attr_a, attr_b),
+            (attr_b, attr_c),
+            (make_a, instrument_a),
+            (instrument_a, attr_a),
+            (attr_a, events_a),
+            (make_b, instrument_b),
+            (instrument_b, attr_b),
+            (make_c, instrument_c),
+            (instrument_c, attr_c),
+            (make_a, make_b),
+            (make_b, make_c)
+        ]
+        elements = [make_a, make_b, make_c, 
+                    instrument_a, instrument_b, instrument_c, 
+                    attr_a, attr_b, attr_c, events_a]
+        
+        for i, series in enumerate(all_partial_orderings(ordering, elements)):
+            for fn in series:
+                fn()
+                
+            b = classes[1]()
+            b.attrib = "foo"
+            eq_(b.attrib, "foo")
+            eq_(canary, ["foo"])
+            
+            c = classes[2]()
+            c.attrib = "bar"
+            eq_(c.attrib, "bar")
+            eq_(canary, ["foo", "bar"])
+            
+            teardown()
+        
index b6432a39aa462a8b5c7c1dca03e8f3e51882e182..9cc25c872864f55569ed38602f9a4a8e2dfd5bc5 100644 (file)
@@ -6,12 +6,17 @@ from sqlalchemy.test import testing, pickleable
 from sqlalchemy import MetaData, Integer, String, ForeignKey, func, util
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.engine import default
-from sqlalchemy.orm import mapper, relationship, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased
-from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relationship, dynamic_loader, comparable_property,AttributeExtension
+from sqlalchemy.orm import mapper, relationship, backref, \
+    create_session, class_mapper, compile_mappers, reconstructor, \
+    validates, aliased, Mapper
+from sqlalchemy.orm import defer, deferred, synonym, attributes, \
+    column_property, composite, relationship, dynamic_loader, \
+    comparable_property, AttributeExtension
+from sqlalchemy.orm.instrumentation import ClassManager
 from sqlalchemy.test.testing import eq_, AssertsCompiledSQL
 from test.orm import _base, _fixtures
 from sqlalchemy.test.assertsql import AllOf, CompiledSQL
-
+from sqlalchemy import event
 
 class MapperTest(_fixtures.FixtureTest):
 
@@ -2549,7 +2554,57 @@ class AttributeExtensionTest(_base.MappedTest):
         
         eq_(ext_msg, ["Ex1 'a1'", "Ex1 'b1'", "Ex2 'c1'", "Ex1 'a2'", "Ex1 'b2'", "Ex2 'c2'"])
         
+class MapperEventsTest(_fixtures.FixtureTest):
+    @testing.resolve_artifact_names
+    def test_instance_event_listen(self):
+        """test listen targets for instance events"""
+        
+        canary = []
+        class A(object):
+            pass
+        class B(A):
+            pass
+            
+        mapper(A, users)
+        mapper(B, addresses, inherits=A)
+        
+        def on_init_a(target, args, kwargs):
+            canary.append(('on_init_a', target))
+            
+        def on_init_b(target, args, kwargs):
+            canary.append(('on_init_b', target))
+
+        def on_init_c(target, args, kwargs):
+            canary.append(('on_init_c', target))
+
+        def on_init_d(target, args, kwargs):
+            canary.append(('on_init_d', target))
+
+        def on_init_e(target, args, kwargs):
+            canary.append(('on_init_e', target))
+        
+        event.listen(on_init_a, 'on_init', mapper)
+        event.listen(on_init_b, 'on_init', Mapper)
+        event.listen(on_init_c, 'on_init', class_mapper(A))
+        event.listen(on_init_d, 'on_init', A)
+        event.listen(on_init_e, 'on_init', A, propagate=True)
+        
+        a = A()
+        eq_(canary, [('on_init_a', a),('on_init_b', a),
+                        ('on_init_c', a),('on_init_d', a),('on_init_e', a)])
+        
+        # test propagate flag
+        canary[:] = []
+        b = B()
+        eq_(canary, [('on_init_a', b), ('on_init_b', b),('on_init_e', b)])
     
+    def teardown(self):
+        # TODO: need to get remove() functionality
+        # going
+        Mapper.dispatch.clear()
+        ClassManager.dispatch.clear()
+        
+        
 class MapperExtensionTest(_fixtures.FixtureTest):
     run_inserts = None