From: Mike Bayer Date: Sun, 7 Nov 2010 17:49:48 +0000 (-0500) Subject: - propagate flag on event.listen() results in the listener being placed X-Git-Tag: rel_0_7b1~253^2~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=13fedc23ecca81d0881a994a45efae3a77b74fcb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - propagate flag on event.listen() results in the listener being placed in a separate collection. this collection also propagates during update() - ClassManager now handles bases, subclasses collections. - ClassManager looks at __bases__ instead of __mro__ for superclasses. It's assumed ClassManagers are in an unbroken chain upwards through __mro__. - trying to get a clear() that makes sense on cls.dispatch - implemented propagate for attribute events, plus permutation-based test - implemented propagate for mapper / instance events with rudimentary test - some pool events tests are failing for some reason --- diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index 955c33797c..75512f7d29 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -18,8 +18,6 @@ def listen(fn, identifier, target, *args, **kw): for evt_cls in _registrars[identifier]: tgt = evt_cls.accept_with(target) if tgt is not None: - if kw.pop('propagate', False): - fn._sa_event_propagate = True tgt.dispatch.listen(fn, identifier, tgt, *args, **kw) return raise exc.InvalidRequestError("No such event %s for target %s" % @@ -37,7 +35,7 @@ def remove(fn, identifier, target): for tgt in evt_cls.accept_with(target): tgt.dispatch.remove(fn, identifier, tgt, *args, **kw) return - + _registrars = util.defaultdict(list) class _UnpickleDispatch(object): @@ -60,7 +58,7 @@ class _Dispatch(object): def __reduce__(self): return _UnpickleDispatch(), (self.parent_cls, ) - + @property def descriptors(self): return (getattr(self, k) for k in dir(self) if k.startswith("on_")) @@ -71,7 +69,8 @@ class _Dispatch(object): for ls in other.descriptors: getattr(self, ls.name).update(ls, only_propagate=only_propagate) - + + class _EventMeta(type): """Intercept new Event subclasses and create associated _Dispatch classes.""" @@ -88,7 +87,8 @@ def _create_dispatcher_class(cls, classname, bases, dict_): cls.dispatch = dispatch_cls = type("%sDispatch" % classname, (dispatch_base, ), {}) dispatch_cls.listen = cls.listen - + dispatch_cls.clear = cls.clear + for k in dict_: if k.startswith('on_'): setattr(dispatch_cls, k, _DispatchDescriptor(dict_[k])) @@ -121,13 +121,19 @@ class Events(object): return None @classmethod - def listen(cls, fn, identifier, target): - getattr(target.dispatch, identifier).append(fn, target) + def listen(cls, fn, identifier, target, propagate=False): + getattr(target.dispatch, identifier).append(fn, target, propagate) @classmethod def remove(cls, fn, identifier, target): getattr(target.dispatch, identifier).remove(fn, target) - + + @classmethod + def clear(cls): + for attr in dir(cls.dispatch): + if attr.startswith("on_"): + getattr(cls.dispatch, attr).clear() + class _DispatchDescriptor(object): """Class-level attributes on _Dispatch classes.""" @@ -136,7 +142,7 @@ class _DispatchDescriptor(object): self.__doc__ = fn.__doc__ self._clslevel = util.defaultdict(list) - def append(self, obj, target): + def append(self, obj, target, propagate): assert isinstance(target, type), \ "Class-level Event targets must be classes." @@ -146,7 +152,13 @@ class _DispatchDescriptor(object): def remove(self, obj, target): for cls in [target] + target.__subclasses__(): self._clslevel[cls].remove(obj) + + def clear(self): + """Clear all class level listeners""" + for dispatcher in self._clslevel.values(): + dispatcher[:] = [] + def __get__(self, obj, cls): if obj is None: return self @@ -166,7 +178,8 @@ class _ListenerCollection(object): self.parent_listeners = parent._clslevel[target_cls] self.name = parent.__name__ self.listeners = [] - + self.propagate = set() + def exec_once(self, *args, **kw): """Execute this event, but only if it has not been executed already for this collection.""" @@ -200,22 +213,30 @@ class _ListenerCollection(object): def update(self, other, only_propagate=True): """Populate from the listeners in another :class:`_Dispatch` object.""" - + existing_listeners = self.listeners existing_listener_set = set(existing_listeners) + self.propagate.update(other.propagate) existing_listeners.extend([l for l in other.listeners if l not in existing_listener_set - and not only_propagate or getattr(l, '_sa_event_propagate', False) + and not only_propagate or l in self.propagate ]) - def append(self, obj, target): + def append(self, obj, target, propagate): if obj not in self.listeners: self.listeners.append(obj) + if propagate: + self.propagate.add(obj) def remove(self, obj, target): if obj in self.listeners: self.listeners.remove(obj) + self.propagate.discard(obj) + + def clear(self): + self.listeners[:] = [] + self.propagate.clear() class dispatcher(object): """Descriptor used by target classes to diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6bc15d8f5e..fcaabfdddb 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -55,12 +55,22 @@ PASSIVE_OFF = False #util.symbol('PASSIVE_OFF') class QueryableAttribute(interfaces.PropComparator): """Base class for class-bound attributes. """ - def __init__(self, class_, key, impl=None, comparator=None, parententity=None): + def __init__(self, class_, key, impl=None, + comparator=None, parententity=None): self.class_ = class_ self.key = key self.impl = impl self.comparator = comparator self.parententity = parententity + + manager = manager_of_class(class_) + # manager is None in the case of AliasedClass + if manager: + # propagate existing event listeners from + # immediate superclass + for base in manager._bases: + if key in base: + self.dispatch.update(base[key].dispatch) dispatch = event.dispatcher(events.AttributeEvents) dispatch.dispatch_cls.active_history = False diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 36c12cf8ae..205ec32355 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1,7 +1,7 @@ """ORM event interfaces. """ -from sqlalchemy import event, util, exc +from sqlalchemy import event, exc import inspect class InstrumentationEvents(event.Events): @@ -26,13 +26,8 @@ class InstrumentationEvents(event.Events): return None @classmethod - def listen(cls, fn, identifier, target): - - @util.decorator - def adapt_to_target(fn, cls, *arg): - if issubclass(cls, target): - fn(cls, *arg) - event.Events.listen(fn, identifier, target) + def listen(cls, fn, identifier, target, propagate=False): + event.Events.listen(fn, identifier, target, propagate=propagate) @classmethod def remove(cls, fn, identifier, target): @@ -68,9 +63,14 @@ class InstanceEvents(event.Events): @classmethod def accept_with(cls, target): from sqlalchemy.orm.instrumentation import ClassManager, manager_of_class + from sqlalchemy.orm import Mapper, mapper if isinstance(target, ClassManager): return target + elif isinstance(target, Mapper): + return target.class_manager + elif target is Mapper or target is mapper: + return ClassManager elif isinstance(target, type): manager = manager_of_class(target) if manager: @@ -78,14 +78,18 @@ class InstanceEvents(event.Events): return None @classmethod - def listen(cls, fn, identifier, target, raw=False): + def listen(cls, fn, identifier, target, raw=False, propagate=False): if not raw: orig_fn = fn def wrap(state, *arg, **kw): return orig_fn(state.obj(), *arg, **kw) fn = wrap - event.Events.listen(fn, identifier, target) - + + event.Events.listen(fn, identifier, target, propagate=propagate) + if propagate: + for mgr in target.subclass_managers(True): + event.Events.listen(fn, identifier, mgr, True) + @classmethod def remove(cls, fn, identifier, target): raise NotImplementedError("Removal of instance events not yet implemented") @@ -148,6 +152,8 @@ class MapperEvents(event.Events): Several modifiers are available to the listen() function. + :param propagate=False: When True, the event listener should + be applied to all inheriting mappers as well. :param raw=False: When True, the "target" argument to the event, if applicable will be the :class:`.InstanceState` management object, rather than the mapped instance itself. @@ -178,7 +184,7 @@ class MapperEvents(event.Events): @classmethod def listen(cls, fn, identifier, target, - raw=False, retval=False): + raw=False, retval=False, propagate=False): from sqlalchemy.orm.interfaces import EXT_CONTINUE if not raw or not retval: @@ -201,9 +207,11 @@ class MapperEvents(event.Events): return wrapped_fn(*arg, **kw) fn = wrap - for mapper in target.self_and_descendants: - event.Events.listen(fn, identifier, mapper) - + if propagate: + for mapper in target.self_and_descendants: + event.Events.listen(fn, identifier, mapper, propagate=True) + else: + event.Events.listen(fn, identifier, self) def on_instrument_class(self, mapper, class_): """Receive a class when the mapper is first constructed, and has @@ -437,32 +445,26 @@ class AttributeEvents(event.Events): # of the wrapper with the original function. if not raw or not retval: - @util.decorator - def wrap(fn, target, value, *arg): + orig_fn = fn + def wrap(target, value, *arg): if not raw: target = target.obj() if not retval: - fn(target, value, *arg) + orig_fn(target, value, *arg) return value else: - return fn(target, value, *arg) - fn = wrap(fn) + return orig_fn(target, value, *arg) + fn = wrap - event.Events.listen(fn, identifier, target) + event.Events.listen(fn, identifier, target, propagate) if propagate: - - raise NotImplementedError() - - # TODO: for removal, need to implement - # packaging this info for operation in reverse. - - class_ = target.class_ - for cls in class_.__subclasses__(): - impl = getattr(cls, target.key) - if impl is not target: - event.Events.listen(fn, identifier, impl) + from sqlalchemy.orm.instrumentation import manager_of_class + + manager = manager_of_class(target.class_) + for mgr in manager.subclass_managers(True): + event.Events.listen(fn, identifier, mgr[target.key], True) @classmethod def remove(cls, fn, identifier, target): diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 6e357e1579..dba3a68307 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -83,12 +83,16 @@ class ClassManager(dict): self.mutable_attributes = set() self.local_attrs = {} self.originals = {} - for base in class_.__mro__[-2:0:-1]: # reverse, skipping 1st and last - if not isinstance(base, type): - continue - cls_state = manager_of_class(base) - if cls_state: - self.update(cls_state) + + self._bases = [mgr for mgr in [ + manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ] if mgr is not None] + + for base in self._bases: + self.update(base) + self.manage() self._instrument_init() @@ -194,6 +198,15 @@ class ClassManager(dict): manager = self._subclass_manager(cls) manager.instrument_attribute(key, inst, True) + def subclass_managers(self, recursive): + for cls in self.class_.__subclasses__(): + mgr = manager_of_class(cls) + if mgr is not None and mgr is not self: + yield mgr + if recursive: + for m in mgr.subclass_managers(True): + yield m + def post_configure_attribute(self, key): instrumentation_registry.dispatch.\ on_attribute_instrument(self.class_, key, self[key]) diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py index ff2c3d7b79..98667d8c26 100644 --- a/lib/sqlalchemy/test/util.py +++ b/lib/sqlalchemy/test/util.py @@ -1,4 +1,4 @@ -from sqlalchemy.util import jython, function_named +from sqlalchemy.util import jython, function_named, defaultdict import gc import time @@ -75,4 +75,35 @@ class RandomSet(set): def copy(self): return RandomSet(self) - \ No newline at end of file + +def conforms_partial_ordering(tuples, sorted_elements): + """True if the given sorting conforms to the given partial ordering.""" + + deps = defaultdict(set) + for parent, child in tuples: + deps[parent].add(child) + for i, node in enumerate(sorted_elements): + for n in sorted_elements[i:]: + if node in deps[n]: + return False + else: + return True + +def all_partial_orderings(tuples, elements): + edges = defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + def _all_orderings(elements): + + if len(elements) == 1: + yield list(elements) + else: + for elem in elements: + subset = set(elements).difference([elem]) + if not subset.intersection(edges[elem]): + for sub_ordering in _all_orderings(subset): + yield [elem] + sub_ordering + + return iter(_all_orderings(elements)) + diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 9fddfc47ff..605a16cc34 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -1,6 +1,7 @@ import sqlalchemy.topological as topological from sqlalchemy.test import TestBase from sqlalchemy.test.testing import assert_raises, eq_ +from sqlalchemy.test.util import conforms_partial_ordering from sqlalchemy import exc, util @@ -12,13 +13,7 @@ class DependencySortTest(TestBase): else: allitems = self._nodes_from_tuples(tuples).union(allitems) result = list(topological.sort(tuples, allitems)) - deps = util.defaultdict(set) - for parent, child in tuples: - deps[parent].add(child) - assert len(result) - for i, node in enumerate(result): - for n in result[i:]: - assert node not in deps[n] + assert conforms_partial_ordering(tuples, result) def _nodes_from_tuples(self, tups): s = set() diff --git a/test/base/test_events.py b/test/base/test_events.py index 995569c60f..2920598775 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -262,17 +262,3 @@ class TestPropagate(TestBase): t2.dispatch.on_event_one(t2, 1) t2.dispatch.on_event_two(t2, 2) eq_(result, [(t2, 1)]) - - - - - - - - - - - - - - \ No newline at end of file diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 60f9c6b9ee..b5a6c1f5e9 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -6,8 +6,9 @@ from sqlalchemy import exc as sa_exc from sqlalchemy.test import * from sqlalchemy.test.testing import eq_, ne_, assert_raises from test.orm import _base -from sqlalchemy.test.util import gc_collect +from sqlalchemy.test.util import gc_collect, all_partial_orderings from sqlalchemy.util import cmp, jython +from sqlalchemy import event, topological # global for pickling tests MyTest = None @@ -1569,33 +1570,32 @@ class HistoryTest(_base.ORMTest): class ListenerTest(_base.ORMTest): def test_receive_changes(self): - """test that Listeners can mutate the given value. + """test that Listeners can mutate the given value.""" - This is a rudimentary test which would be better suited by a full-blown inclusion - into collection.py. - - """ class Foo(object): pass class Bar(object): pass + + def on_append(state, child, initiator): + b2 = Bar() + b2.data = b1.data + " appended" + return b2 - class AlteringListener(AttributeExtension): - def append(self, state, child, initiator): - b2 = Bar() - b2.data = b1.data + " appended" - return b2 - - def set(self, state, value, oldvalue, initiator): - return value + " modified" + def on_set(state, value, oldvalue, initiator): + return value + " modified" instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener()) - attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener()) - attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener()) + attributes.register_attribute(Foo, 'data', uselist=False, useobject=False) + attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True) + attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True) attributes.register_attribute(Bar, 'data', uselist=False, useobject=False) + event.listen(on_set, 'on_set', Foo.data, retval=True) + event.listen(on_append, 'on_append', Foo.barlist, retval=True) + event.listen(on_append, 'on_append', Foo.barset, retval=True) + f1 = Foo() f1.data = "some data" eq_(f1.data, "some data modified") @@ -1608,4 +1608,84 @@ class ListenerTest(_base.ORMTest): f1.barset.add(b1) assert f1.barset.pop().data == "some bar appended" - + def test_propagate(self): + classes = [None, None, None] + canary = [] + def make_a(): + class A(object): + pass + classes[0] = A + + def make_b(): + class B(classes[0]): + pass + classes[1] = B + + def make_c(): + class C(classes[1]): + pass + classes[2] = C + + def instrument_a(): + instrumentation.register_class(classes[0]) + + def instrument_b(): + instrumentation.register_class(classes[1]) + + def instrument_c(): + instrumentation.register_class(classes[2]) + + def attr_a(): + attributes.register_attribute(classes[0], 'attrib', uselist=False, useobject=False) + + def attr_b(): + attributes.register_attribute(classes[1], 'attrib', uselist=False, useobject=False) + + def attr_c(): + attributes.register_attribute(classes[2], 'attrib', uselist=False, useobject=False) + + def on_set(state, value, oldvalue, initiator): + canary.append(value) + + def events_a(): + event.listen(on_set, 'on_set', classes[0].attrib, propagate=True) + + def teardown(): + classes[:] = [None, None, None] + canary[:] = [] + + ordering = [ + (instrument_a, instrument_b), + (instrument_b, instrument_c), + (attr_a, attr_b), + (attr_b, attr_c), + (make_a, instrument_a), + (instrument_a, attr_a), + (attr_a, events_a), + (make_b, instrument_b), + (instrument_b, attr_b), + (make_c, instrument_c), + (instrument_c, attr_c), + (make_a, make_b), + (make_b, make_c) + ] + elements = [make_a, make_b, make_c, + instrument_a, instrument_b, instrument_c, + attr_a, attr_b, attr_c, events_a] + + for i, series in enumerate(all_partial_orderings(ordering, elements)): + for fn in series: + fn() + + b = classes[1]() + b.attrib = "foo" + eq_(b.attrib, "foo") + eq_(canary, ["foo"]) + + c = classes[2]() + c.attrib = "bar" + eq_(c.attrib, "bar") + eq_(canary, ["foo", "bar"]) + + teardown() + diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index b6432a39aa..9cc25c8728 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -6,12 +6,17 @@ from sqlalchemy.test import testing, pickleable from sqlalchemy import MetaData, Integer, String, ForeignKey, func, util from sqlalchemy.test.schema import Table, Column from sqlalchemy.engine import default -from sqlalchemy.orm import mapper, relationship, backref, create_session, class_mapper, compile_mappers, reconstructor, validates, aliased -from sqlalchemy.orm import defer, deferred, synonym, attributes, column_property, composite, relationship, dynamic_loader, comparable_property,AttributeExtension +from sqlalchemy.orm import mapper, relationship, backref, \ + create_session, class_mapper, compile_mappers, reconstructor, \ + validates, aliased, Mapper +from sqlalchemy.orm import defer, deferred, synonym, attributes, \ + column_property, composite, relationship, dynamic_loader, \ + comparable_property, AttributeExtension +from sqlalchemy.orm.instrumentation import ClassManager from sqlalchemy.test.testing import eq_, AssertsCompiledSQL from test.orm import _base, _fixtures from sqlalchemy.test.assertsql import AllOf, CompiledSQL - +from sqlalchemy import event class MapperTest(_fixtures.FixtureTest): @@ -2549,7 +2554,57 @@ class AttributeExtensionTest(_base.MappedTest): eq_(ext_msg, ["Ex1 'a1'", "Ex1 'b1'", "Ex2 'c1'", "Ex1 'a2'", "Ex1 'b2'", "Ex2 'c2'"]) +class MapperEventsTest(_fixtures.FixtureTest): + @testing.resolve_artifact_names + def test_instance_event_listen(self): + """test listen targets for instance events""" + + canary = [] + class A(object): + pass + class B(A): + pass + + mapper(A, users) + mapper(B, addresses, inherits=A) + + def on_init_a(target, args, kwargs): + canary.append(('on_init_a', target)) + + def on_init_b(target, args, kwargs): + canary.append(('on_init_b', target)) + + def on_init_c(target, args, kwargs): + canary.append(('on_init_c', target)) + + def on_init_d(target, args, kwargs): + canary.append(('on_init_d', target)) + + def on_init_e(target, args, kwargs): + canary.append(('on_init_e', target)) + + event.listen(on_init_a, 'on_init', mapper) + event.listen(on_init_b, 'on_init', Mapper) + event.listen(on_init_c, 'on_init', class_mapper(A)) + event.listen(on_init_d, 'on_init', A) + event.listen(on_init_e, 'on_init', A, propagate=True) + + a = A() + eq_(canary, [('on_init_a', a),('on_init_b', a), + ('on_init_c', a),('on_init_d', a),('on_init_e', a)]) + + # test propagate flag + canary[:] = [] + b = B() + eq_(canary, [('on_init_a', b), ('on_init_b', b),('on_init_e', b)]) + def teardown(self): + # TODO: need to get remove() functionality + # going + Mapper.dispatch.clear() + ClassManager.dispatch.clear() + + class MapperExtensionTest(_fixtures.FixtureTest): run_inserts = None