From: Mike Bayer Date: Wed, 1 Dec 2010 01:25:22 +0000 (-0500) Subject: - replace GenericBackrefExtension with straight events X-Git-Tag: rel_0_7b1~215 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3c9d2d7b2f76fc18c0f1141a813a7045ac8cb853;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - replace GenericBackrefExtension with straight events - add "backref" argument to register_attribute_impl --- diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 9ae885bf99..d80a7fe5a5 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -791,21 +791,10 @@ class CollectionAttributeImpl(AttributeImpl): return getattr(user_data, '_sa_adapter') -class GenericBackrefExtension(interfaces.AttributeExtension): - """An extension which synchronizes a two-way relationship. +def backref_listeners(attribute, key, uselist): + """Apply listeners to synchronize a two-way relationship.""" - A typical two-way relationship is a parent object containing a list of - child objects, where each child object references the parent. The other - are two objects which contain scalar references to each other. - - """ - - active_history = False - - def __init__(self, key): - self.key = key - - def set(self, state, child, oldchild, initiator): + def set_(state, child, oldchild, initiator): if oldchild is child: return child @@ -814,7 +803,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension): # present when updating via a backref. old_state, old_dict = instance_state(oldchild),\ instance_dict(oldchild) - impl = old_state.get_impl(self.key) + impl = old_state.get_impl(key) try: impl.remove(old_state, old_dict, @@ -826,7 +815,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension): if child is not None: child_state, child_dict = instance_state(child),\ instance_dict(child) - child_state.get_impl(self.key).append( + child_state.get_impl(key).append( child_state, child_dict, state.obj(), @@ -834,10 +823,10 @@ class GenericBackrefExtension(interfaces.AttributeExtension): passive=PASSIVE_NO_FETCH) return child - def append(self, state, child, initiator): + def append(state, child, initiator): child_state, child_dict = instance_state(child), \ instance_dict(child) - child_state.get_impl(self.key).append( + child_state.get_impl(key).append( child_state, child_dict, state.obj(), @@ -845,18 +834,24 @@ class GenericBackrefExtension(interfaces.AttributeExtension): passive=PASSIVE_NO_FETCH) return child - def remove(self, state, child, initiator): + def remove(state, child, initiator): if child is not None: child_state, child_dict = instance_state(child),\ instance_dict(child) - child_state.get_impl(self.key).remove( + child_state.get_impl(key).remove( child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_FETCH) - - + + if uselist: + event.listen(append, "on_append", attribute, retval=False, raw=True) + else: + event.listen(set_, "on_set", attribute, retval=False, raw=True) + # TODO: need coverage in test/orm/ of remove event + event.listen(remove, "on_remove", attribute, retval=False, raw=True) + class History(tuple): """A 3-tuple of added, unchanged and deleted values, representing the changes which have occured on an instrumented @@ -1010,14 +1005,15 @@ def register_attribute(class_, key, **kw): comparator = kw.pop('comparator', None) parententity = kw.pop('parententity', None) doc = kw.pop('doc', None) - register_descriptor(class_, key, + desc = register_descriptor(class_, key, comparator, parententity, doc=doc) register_attribute_impl(class_, key, **kw) + return desc def register_attribute_impl(class_, key, uselist=False, callable_=None, useobject=False, mutable_scalars=False, - impl_class=None, **kw): + impl_class=None, backref=None, **kw): manager = manager_of_class(class_) if uselist: @@ -1044,6 +1040,9 @@ def register_attribute_impl(class_, key, impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) manager[key].impl = impl + + if backref: + backref_listeners(manager[key], backref, uselist) manager.post_configure_attribute(key) @@ -1058,6 +1057,7 @@ def register_descriptor(class_, key, comparator=None, descriptor.__doc__ = doc manager.instrument_attribute(key, descriptor) + return descriptor def unregister_attribute(class_, key): manager_of_class(class_).uninstrument_attribute(key) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b68290dbdd..81ac9262ce 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -1347,10 +1347,6 @@ class RelationshipProperty(StrategizedProperty): ) mapper._configure_property(backref_key, relationship) if self.back_populates: - self.extension = list(util.to_list(self.extension, - default=[])) - self.extension.append( - attributes.GenericBackrefExtension(self.back_populates)) self._add_reverse_property(self.back_populates) def _post_init(self): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d8d4afc37e..f23145da55 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -47,7 +47,6 @@ def _register_attribute(strategy, mapper, useobject, if useobject: attribute_ext.append(sessionlib.UOWEventHandler(prop.key)) - for m in mapper.self_and_descendants: if prop is m._props.get(prop.key): @@ -60,7 +59,7 @@ def _register_attribute(strategy, mapper, useobject, uselist=uselist, copy_function=copy_function, compare_function=compare_function, - useobject=useobject, + useobject=useobject, extension=attribute_ext, trackparent=useobject, typecallable=typecallable, @@ -398,6 +397,7 @@ class LazyLoader(AbstractRelationshipLoader): useobject=True, callable_=self._class_level_loader, uselist = self.parent_property.uselist, + backref = self.parent_property.back_populates, typecallable = self.parent_property.collection_class, active_history = \ self.parent_property.active_history or \ diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 913c6ec52e..82859cd546 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -348,7 +348,7 @@ class AttributesTest(_base.ORMTest): return [bar1, bar2, bar3] attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lambda o:func1, useobject=True, extension=[ReceiveEvents()]) - attributes.register_attribute(Bar, 'foos', uselist=True, useobject=True, extension=[attributes.GenericBackrefExtension('bars')]) + attributes.register_attribute(Bar, 'foos', uselist=True, useobject=True, backref="bars") x = Foo() assert_raises(AssertionError, Bar(id=4).foos.append, x) @@ -413,10 +413,10 @@ class AttributesTest(_base.ORMTest): # set up instrumented attributes with backrefs attributes.register_attribute(Post, 'blog', uselist=False, - extension=attributes.GenericBackrefExtension('posts'), + backref='posts', trackparent=True, useobject=True) attributes.register_attribute(Blog, 'posts', uselist=True, - extension=attributes.GenericBackrefExtension('blog'), + backref='blog', trackparent=True, useobject=True) # create objects as if they'd been freshly loaded from the database (without history) @@ -701,11 +701,9 @@ class BackrefTest(_base.ORMTest): instrumentation.register_class(Student) instrumentation.register_class(Course) attributes.register_attribute(Student, 'courses', uselist=True, - extension=attributes.GenericBackrefExtension('students' - ), useobject=True) + backref="students", useobject=True) attributes.register_attribute(Course, 'students', uselist=True, - extension=attributes.GenericBackrefExtension('courses' - ), useobject=True) + backref="courses", useobject=True) s = Student() c = Course() @@ -729,10 +727,10 @@ class BackrefTest(_base.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) attributes.register_attribute(Post, 'blog', uselist=False, - extension=attributes.GenericBackrefExtension('posts'), + backref='posts', trackparent=True, useobject=True) attributes.register_attribute(Blog, 'posts', uselist=True, - extension=attributes.GenericBackrefExtension('blog'), + backref='blog', trackparent=True, useobject=True) b = Blog() (p1, p2, p3) = (Post(), Post(), Post()) @@ -762,12 +760,14 @@ class BackrefTest(_base.ORMTest): class Jack(object):pass instrumentation.register_class(Port) instrumentation.register_class(Jack) - attributes.register_attribute(Port, 'jack', uselist=False, - extension=attributes.GenericBackrefExtension('port'), - useobject=True) + + attributes.register_attribute(Port, 'jack', uselist=False, + useobject=True, backref="port") + attributes.register_attribute(Jack, 'port', uselist=False, - extension=attributes.GenericBackrefExtension('jack'), - useobject=True) + useobject=True, backref="jack") + + p = Port() j = Jack() p.jack = j @@ -798,16 +798,16 @@ class BackrefTest(_base.ORMTest): instrumentation.register_class(Child) instrumentation.register_class(SubChild) attributes.register_attribute(Parent, 'child', uselist=False, - extension=attributes.GenericBackrefExtension('parent'), + backref="parent", parent_token = p_token, useobject=True) attributes.register_attribute(Child, 'parent', uselist=False, - extension=attributes.GenericBackrefExtension('child'), + backref="child", parent_token = c_token, useobject=True) attributes.register_attribute(SubChild, 'parent', uselist=False, - extension=attributes.GenericBackrefExtension('child'), + backref="child", parent_token = c_token, useobject=True) @@ -833,15 +833,15 @@ class BackrefTest(_base.ORMTest): instrumentation.register_class(SubParent) instrumentation.register_class(Child) attributes.register_attribute(Parent, 'children', uselist=True, - extension=attributes.GenericBackrefExtension('parent'), + backref='parent', parent_token = p_token, useobject=True) attributes.register_attribute(SubParent, 'children', uselist=True, - extension=attributes.GenericBackrefExtension('parent'), + backref='parent', parent_token = p_token, useobject=True) attributes.register_attribute(Child, 'parent', uselist=False, - extension=attributes.GenericBackrefExtension('children'), + backref='children', parent_token = c_token, useobject=True) @@ -899,8 +899,8 @@ class PendingBackrefTest(_base.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True) + attributes.register_attribute(Post, 'blog', uselist=False, backref='posts', trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, backref='blog', callable_=lazy_posts, trackparent=True, useobject=True) def test_lazy_add(self): global lazy_load @@ -1355,8 +1355,8 @@ class HistoryTest(_base.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, useobject=True) - attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + attributes.register_attribute(Foo, 'bars', uselist=True, backref='foo', trackparent=True, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, backref='bars', trackparent=True, useobject=True) f1 = Foo() b1 = Bar() @@ -1388,8 +1388,8 @@ class HistoryTest(_base.ORMTest): instrumentation.register_class(Foo) instrumentation.register_class(Bar) - attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True) - attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True) + attributes.register_attribute(Foo, 'bars', uselist=True, backref='foo', trackparent=True, callable_=lazyload, useobject=True) + attributes.register_attribute(Bar, 'foo', uselist=False, backref='bars', trackparent=True, useobject=True) bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)] lazy_load = [bar1, bar2, bar3] diff --git a/test/orm/test_extendedattr.py b/test/orm/test_extendedattr.py index aae1ecdbb5..ec7963c293 100644 --- a/test/orm/test_extendedattr.py +++ b/test/orm/test_extendedattr.py @@ -224,8 +224,8 @@ class UserDefinedExtensionTest(_base.ORMTest): instrumentation.register_class(Post) instrumentation.register_class(Blog) - attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True) - attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True) + attributes.register_attribute(Post, 'blog', uselist=False, backref='posts', trackparent=True, useobject=True) + attributes.register_attribute(Blog, 'posts', uselist=True, backref='blog', trackparent=True, useobject=True) b = Blog() (p1, p2, p3) = (Post(), Post(), Post()) b.posts.append(p1)