]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- replace GenericBackrefExtension with straight events
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Dec 2010 01:25:22 +0000 (20:25 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Dec 2010 01:25:22 +0000 (20:25 -0500)
- add "backref" argument to register_attribute_impl

lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_attributes.py
test/orm/test_extendedattr.py

index 9ae885bf9930b1cfb859ba3b2077a055a671a865..d80a7fe5a549346bb8056a456e42bb714fb8c33e 100644 (file)
@@ -791,21 +791,10 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return getattr(user_data, '_sa_adapter')
 
-class GenericBackrefExtension(interfaces.AttributeExtension):
-    """An extension which synchronizes a two-way relationship.
+def backref_listeners(attribute, key, uselist):
+    """Apply listeners to synchronize a two-way relationship."""
 
-    A typical two-way relationship is a parent object containing a list of
-    child objects, where each child object references the parent.  The other
-    are two objects which contain scalar references to each other.
-
-    """
-    
-    active_history = False
-    
-    def __init__(self, key):
-        self.key = key
-
-    def set(self, state, child, oldchild, initiator):
+    def set_(state, child, oldchild, initiator):
         if oldchild is child:
             return child
 
@@ -814,7 +803,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
             # present when updating via a backref.
             old_state, old_dict = instance_state(oldchild),\
                                     instance_dict(oldchild)
-            impl = old_state.get_impl(self.key)
+            impl = old_state.get_impl(key)
             try:
                 impl.remove(old_state, 
                             old_dict, 
@@ -826,7 +815,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
         if child is not None:
             child_state, child_dict = instance_state(child),\
                                         instance_dict(child)
-            child_state.get_impl(self.key).append(
+            child_state.get_impl(key).append(
                                             child_state, 
                                             child_dict, 
                                             state.obj(), 
@@ -834,10 +823,10 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
                                             passive=PASSIVE_NO_FETCH)
         return child
 
-    def append(self, state, child, initiator):
+    def append(state, child, initiator):
         child_state, child_dict = instance_state(child), \
                                     instance_dict(child)
-        child_state.get_impl(self.key).append(
+        child_state.get_impl(key).append(
                                             child_state, 
                                             child_dict, 
                                             state.obj(), 
@@ -845,18 +834,24 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
                                             passive=PASSIVE_NO_FETCH)
         return child
 
-    def remove(self, state, child, initiator):
+    def remove(state, child, initiator):
         if child is not None:
             child_state, child_dict = instance_state(child),\
                                         instance_dict(child)
-            child_state.get_impl(self.key).remove(
+            child_state.get_impl(key).remove(
                                             child_state, 
                                             child_dict, 
                                             state.obj(), 
                                             initiator,
                                             passive=PASSIVE_NO_FETCH)
-
-
+    
+    if uselist:
+        event.listen(append, "on_append", attribute, retval=False, raw=True)
+    else:
+        event.listen(set_, "on_set", attribute, retval=False, raw=True)
+    # TODO: need coverage in test/orm/ of remove event
+    event.listen(remove, "on_remove", attribute, retval=False, raw=True)
+        
 class History(tuple):
     """A 3-tuple of added, unchanged and deleted values,
     representing the changes which have occured on an instrumented
@@ -1010,14 +1005,15 @@ def register_attribute(class_, key, **kw):
     comparator = kw.pop('comparator', None)
     parententity = kw.pop('parententity', None)
     doc = kw.pop('doc', None)
-    register_descriptor(class_, key, 
+    desc = register_descriptor(class_, key, 
                             comparator, parententity, doc=doc)
     register_attribute_impl(class_, key, **kw)
+    return desc
     
 def register_attribute_impl(class_, key,         
         uselist=False, callable_=None, 
         useobject=False, mutable_scalars=False, 
-        impl_class=None, **kw):
+        impl_class=None, backref=None, **kw):
     
     manager = manager_of_class(class_)
     if uselist:
@@ -1044,6 +1040,9 @@ def register_attribute_impl(class_, key,
         impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw)
 
     manager[key].impl = impl
+    
+    if backref:
+        backref_listeners(manager[key], backref, uselist)
 
     manager.post_configure_attribute(key)
 
@@ -1058,6 +1057,7 @@ def register_descriptor(class_, key, comparator=None,
     descriptor.__doc__ = doc
         
     manager.instrument_attribute(key, descriptor)
+    return descriptor
 
 def unregister_attribute(class_, key):
     manager_of_class(class_).uninstrument_attribute(key)
index b68290dbdd1d9b1b2d7d7277b69055e095e27a69..81ac9262ce3d8aa09dc78599b2e2baf3101aa363 100644 (file)
@@ -1347,10 +1347,6 @@ class RelationshipProperty(StrategizedProperty):
                 )
             mapper._configure_property(backref_key, relationship)
         if self.back_populates:
-            self.extension = list(util.to_list(self.extension,
-                                  default=[]))
-            self.extension.append(
-                    attributes.GenericBackrefExtension(self.back_populates))
             self._add_reverse_property(self.back_populates)
         
     def _post_init(self):
index d8d4afc37e11ec56a725d42b8ef2697d18d7993c..f23145da55ab6cff6344789c0c7136413baef47b 100644 (file)
@@ -47,7 +47,6 @@ def _register_attribute(strategy, mapper, useobject,
     
     if useobject:
         attribute_ext.append(sessionlib.UOWEventHandler(prop.key))
-
     
     for m in mapper.self_and_descendants:
         if prop is m._props.get(prop.key):
@@ -60,7 +59,7 @@ def _register_attribute(strategy, mapper, useobject,
                 uselist=uselist, 
                 copy_function=copy_function, 
                 compare_function=compare_function, 
-                useobject=useobject, 
+                useobject=useobject,
                 extension=attribute_ext, 
                 trackparent=useobject, 
                 typecallable=typecallable,
@@ -398,6 +397,7 @@ class LazyLoader(AbstractRelationshipLoader):
                 useobject=True,
                 callable_=self._class_level_loader,
                 uselist = self.parent_property.uselist,
+                backref = self.parent_property.back_populates,
                 typecallable = self.parent_property.collection_class,
                 active_history = \
                     self.parent_property.active_history or \
index 913c6ec52e2eb2285f303a9e2f0ad9fba09a96f2..82859cd54660e504d0759b0a5be37db90acf6967 100644 (file)
@@ -348,7 +348,7 @@ class AttributesTest(_base.ORMTest):
             return [bar1, bar2, bar3]
 
         attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lambda o:func1, useobject=True, extension=[ReceiveEvents()])
-        attributes.register_attribute(Bar, 'foos', uselist=True, useobject=True, extension=[attributes.GenericBackrefExtension('bars')])
+        attributes.register_attribute(Bar, 'foos', uselist=True, useobject=True, backref="bars")
 
         x = Foo()
         assert_raises(AssertionError, Bar(id=4).foos.append, x)
@@ -413,10 +413,10 @@ class AttributesTest(_base.ORMTest):
 
         # set up instrumented attributes with backrefs
         attributes.register_attribute(Post, 'blog', uselist=False,
-                                        extension=attributes.GenericBackrefExtension('posts'),
+                                        backref='posts',
                                         trackparent=True, useobject=True)
         attributes.register_attribute(Blog, 'posts', uselist=True,
-                                        extension=attributes.GenericBackrefExtension('blog'),
+                                        backref='blog',
                                         trackparent=True, useobject=True)
 
         # create objects as if they'd been freshly loaded from the database (without history)
@@ -701,11 +701,9 @@ class BackrefTest(_base.ORMTest):
         instrumentation.register_class(Student)
         instrumentation.register_class(Course)
         attributes.register_attribute(Student, 'courses', uselist=True,
-                extension=attributes.GenericBackrefExtension('students'
-                ), useobject=True)
+                backref="students", useobject=True)
         attributes.register_attribute(Course, 'students', uselist=True,
-                extension=attributes.GenericBackrefExtension('courses'
-                ), useobject=True)
+                backref="courses", useobject=True)
 
         s = Student()
         c = Course()
@@ -729,10 +727,10 @@ class BackrefTest(_base.ORMTest):
         instrumentation.register_class(Post)
         instrumentation.register_class(Blog)
         attributes.register_attribute(Post, 'blog', uselist=False,
-                extension=attributes.GenericBackrefExtension('posts'),
+                backref='posts',
                 trackparent=True, useobject=True)
         attributes.register_attribute(Blog, 'posts', uselist=True,
-                extension=attributes.GenericBackrefExtension('blog'),
+                backref='blog',
                 trackparent=True, useobject=True)
         b = Blog()
         (p1, p2, p3) = (Post(), Post(), Post())
@@ -762,12 +760,14 @@ class BackrefTest(_base.ORMTest):
         class Jack(object):pass
         instrumentation.register_class(Port)
         instrumentation.register_class(Jack)
-        attributes.register_attribute(Port, 'jack', uselist=False,
-                extension=attributes.GenericBackrefExtension('port'),
-                useobject=True)
+
+        attributes.register_attribute(Port, 'jack', uselist=False, 
+                                            useobject=True, backref="port")
+        
         attributes.register_attribute(Jack, 'port', uselist=False,
-                extension=attributes.GenericBackrefExtension('jack'),
-                useobject=True)
+                                            useobject=True, backref="jack")
+        
+                
         p = Port()
         j = Jack()
         p.jack = j
@@ -798,16 +798,16 @@ class BackrefTest(_base.ORMTest):
         instrumentation.register_class(Child)
         instrumentation.register_class(SubChild)
         attributes.register_attribute(Parent, 'child', uselist=False,
-                extension=attributes.GenericBackrefExtension('parent'),
+                backref="parent",
                 parent_token = p_token,
                 useobject=True)
         attributes.register_attribute(Child, 'parent', uselist=False,
-                extension=attributes.GenericBackrefExtension('child'),
+                backref="child",
                 parent_token = c_token,
                 useobject=True)
         attributes.register_attribute(SubChild, 'parent',
                 uselist=False,
-                extension=attributes.GenericBackrefExtension('child'),
+                backref="child",
                 parent_token = c_token,
                 useobject=True)
         
@@ -833,15 +833,15 @@ class BackrefTest(_base.ORMTest):
         instrumentation.register_class(SubParent)
         instrumentation.register_class(Child)
         attributes.register_attribute(Parent, 'children', uselist=True,
-                extension=attributes.GenericBackrefExtension('parent'),
+                backref='parent',
                 parent_token = p_token,
                 useobject=True)
         attributes.register_attribute(SubParent, 'children', uselist=True,
-                extension=attributes.GenericBackrefExtension('parent'),
+                backref='parent',
                 parent_token = p_token,
                 useobject=True)
         attributes.register_attribute(Child, 'parent', uselist=False,
-                extension=attributes.GenericBackrefExtension('children'),
+                backref='children',
                 parent_token = c_token,
                 useobject=True)
         
@@ -899,8 +899,8 @@ class PendingBackrefTest(_base.ORMTest):
 
         instrumentation.register_class(Post)
         instrumentation.register_class(Blog)
-        attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
-        attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True)
+        attributes.register_attribute(Post, 'blog', uselist=False, backref='posts', trackparent=True, useobject=True)
+        attributes.register_attribute(Blog, 'posts', uselist=True, backref='blog', callable_=lazy_posts, trackparent=True, useobject=True)
 
     def test_lazy_add(self):
         global lazy_load
@@ -1355,8 +1355,8 @@ class HistoryTest(_base.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, useobject=True)
-        attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True)
+        attributes.register_attribute(Foo, 'bars', uselist=True, backref='foo', trackparent=True, useobject=True)
+        attributes.register_attribute(Bar, 'foo', uselist=False, backref='bars', trackparent=True, useobject=True)
 
         f1 = Foo()
         b1 = Bar()
@@ -1388,8 +1388,8 @@ class HistoryTest(_base.ORMTest):
 
         instrumentation.register_class(Foo)
         instrumentation.register_class(Bar)
-        attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True)
-        attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True)
+        attributes.register_attribute(Foo, 'bars', uselist=True, backref='foo', trackparent=True, callable_=lazyload, useobject=True)
+        attributes.register_attribute(Bar, 'foo', uselist=False, backref='bars', trackparent=True, useobject=True)
 
         bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)]
         lazy_load = [bar1, bar2, bar3]
index aae1ecdbb5a38e3d75fcb9ec2764327a333fe279..ec7963c29382d9482b67c0d9d6f22de5b2fc7286 100644 (file)
@@ -224,8 +224,8 @@ class UserDefinedExtensionTest(_base.ORMTest):
 
             instrumentation.register_class(Post)
             instrumentation.register_class(Blog)
-            attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
-            attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+            attributes.register_attribute(Post, 'blog', uselist=False, backref='posts', trackparent=True, useobject=True)
+            attributes.register_attribute(Blog, 'posts', uselist=True, backref='blog', trackparent=True, useobject=True)
             b = Blog()
             (p1, p2, p3) = (Post(), Post(), Post())
             b.posts.append(p1)