]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- a major behavioral change to collection-based backrefs: they no
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Dec 2007 20:43:16 +0000 (20:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Dec 2007 20:43:16 +0000 (20:43 +0000)
longer trigger lazy loads !  "reverse" adds and removes
are queued up and are merged with the collection when it is
actually read from and loaded; but do not trigger a load beforehand.
For users who have noticed this behavior, this should be much more
convenient than using dynamic relations in some cases; for those who
have not, you might notice your apps using a lot fewer queries than
before in some situations. [ticket:871]

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py
test/orm/attributes.py
test/orm/lazy_relations.py
test/orm/mapper.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 935bfb6ce3fbeea4d6ff4ea8d9c602de98add611..3b212fe82f619259eea0e2e7e5438794ef74c42d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -31,7 +31,15 @@ CHANGES
     - from_obj keyword argument to select() can be a scalar or a list.
     
 - orm
-
+   - a major behavioral change to collection-based backrefs: they no 
+     longer trigger lazy loads !  "reverse" adds and removes 
+     are queued up and are merged with the collection when it is 
+     actually read from and loaded; but do not trigger a load beforehand.
+     For users who have noticed this behavior, this should be much more
+     convenient than using dynamic relations in some cases; for those who 
+     have not, you might notice your apps using a lot fewer queries than
+     before in some situations. [ticket:871]
+     
    - new synonym() behavior: an attribute will be placed on the mapped
      class, if one does not exist already, in all cases. if a property
      already exists on the class, the synonym will decorate the property
index 8268d0816c511747337f78dafdf4f8e424c87b6e..bb7085402d590e13e53be0276760785c75050afd 100644 (file)
@@ -151,6 +151,7 @@ class AttributeImpl(object):
                 value = state.dict[self.key]
         if value is not NO_VALUE:
             state.committed_state[self.key] = self.copy(value)
+        state.pending.pop(self.key, None)
 
     def hasparent(self, state, optimistic=False):
         """Return the boolean value of a `hasparent` flag attached to the given item.
@@ -181,7 +182,7 @@ class AttributeImpl(object):
         current = self.get(state, passive=passive)
         if current is PASSIVE_NORESULT:
             return None
-        return AttributeHistory(self, state, current, passive=passive)
+        return AttributeHistory(self, state, current)
         
     def set_callable(self, state, callable_, clear=False):
         """Set a callable function for this attribute on the given object.
@@ -249,10 +250,10 @@ class AttributeImpl(object):
                 # Return a new, empty value
                 return self.initialize(state)
 
-    def append(self, state, value, initiator):
+    def append(self, state, value, initiator, passive=False):
         self.set(state, value, initiator)
 
-    def remove(self, state, value, initiator):
+    def remove(self, state, value, initiator, passive=False):
         self.set(state, None, initiator)
 
     def set(self, state, value, initiator):
@@ -433,17 +434,27 @@ class CollectionAttributeImpl(AttributeImpl):
         state.dict[self.key] = user_data
         return user_data
 
-    def append(self, state, value, initiator):
+    def append(self, state, value, initiator, passive=False):
         if initiator is self:
             return
-        collection = self.get_collection(state)
-        collection.append_with_event(value, initiator)
 
-    def remove(self, state, value, initiator):
+        collection = self.get_collection(state, passive=passive)
+        if collection is PASSIVE_NORESULT:
+            state.get_pending(self).append(value)
+            self.fire_append_event(state, value, initiator)
+        else:
+            collection.append_with_event(value, initiator)
+
+    def remove(self, state, value, initiator, passive=False):
         if initiator is self:
             return
-        collection = self.get_collection(state)
-        collection.remove_with_event(value, initiator)
+
+        collection = self.get_collection(state, passive=passive)
+        if collection is PASSIVE_NORESULT:
+            state.get_pending(self).remove(value)
+            self.fire_remove_event(state, value, initiator)
+        else:
+            collection.remove_with_event(value, initiator)
 
     def set(self, state, value, initiator):
         """Set a value on the given object.
@@ -470,7 +481,7 @@ class CollectionAttributeImpl(AttributeImpl):
 
         old = self.get(state)
         old_collection = self.get_collection(state, old)
-
+        
         new_collection, user_data = self._build_collection(state)
 
         idset = util.IdentitySet
@@ -494,7 +505,10 @@ class CollectionAttributeImpl(AttributeImpl):
             old_collection.unlink(old)
 
     def set_committed_value(self, state, value):
-        """Set an attribute value on the given instance and 'commit' it."""
+        """Set an attribute value on the given instance and 'commit' it.
+        
+        Loads the existing collection from lazy callables in all cases.
+        """
 
         collection, user_data = self._build_collection(state)
         self._load_collection(state, value or [], emit_events=False,
@@ -509,24 +523,45 @@ class CollectionAttributeImpl(AttributeImpl):
         return value
 
     def _build_collection(self, state):
+        """build a new, blank collection and return it wrapped in a CollectionAdapter."""
+        
         user_data = self.collection_factory()
         collection = collections.CollectionAdapter(self, state, user_data)
         return collection, user_data
 
     def _load_collection(self, state, values, emit_events=True, collection=None):
+        """given an empty CollectionAdapter, load the collection with current values.
+        
+        Loads the collection from lazy callables in all cases.
+        """
+        
         collection = collection or self.get_collection(state)
         if values is None:
             return
-        elif emit_events:
+
+        appender = emit_events and collection.append_with_event or collection.append_without_event
+        
+        if self.key in state.pending:
+            # move 'pending' items into the newly loaded collection
+            added = state.pending[self.key].added_items
+            removed = state.pending[self.key].deleted_items
             for item in values:
-                collection.append_with_event(item)
+                if item not in removed:
+                    appender(item)
+            for item in added:
+                appender(item)
+            del state.pending[self.key]
         else:
             for item in values:
-                collection.append_without_event(item)
+                appender(item)
 
-    def get_collection(self, state, user_data=None):
+    def get_collection(self, state, user_data=None, passive=False):
+        """retrieve the CollectionAdapter associated with the given state."""
+        
         if user_data is None:
-            user_data = self.get(state)
+            user_data = self.get(state, passive=passive)
+            if user_data is PASSIVE_NORESULT:
+                return user_data
         try:
             return getattr(user_data, '_sa_adapter')
         except AttributeError:
@@ -554,18 +589,18 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
             # present when updating via a backref.
             impl = getattr(oldchild.__class__, self.key).impl
             try:                
-                impl.remove(oldchild._state, instance, initiator)
+                impl.remove(oldchild._state, instance, initiator, passive=True)
             except (ValueError, KeyError, IndexError):
                 pass
         if child is not None:
-            getattr(child.__class__, self.key).impl.append(child._state, instance, initiator)
+            getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True)
 
     def append(self, instance, child, initiator):
-        getattr(child.__class__, self.key).impl.append(child._state, instance, initiator)
+        getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True)
 
     def remove(self, instance, child, initiator):
         if child is not None:
-            getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator)
+            getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator, passive=True)
 
 class ClassState(object):
     """tracks state information at the class level."""
@@ -577,7 +612,7 @@ class ClassState(object):
 class InstanceState(object):
     """tracks state information at the instance level."""
 
-    __slots__ = 'class_', 'obj', 'dict', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes'
+    __slots__ = 'class_', 'obj', 'dict', 'pending', 'committed_state', 'modified', 'trigger', 'callables', 'parents', 'instance_dict', '_strong_obj', 'expired_attributes'
     
     def __init__(self, obj):
         self.class_ = obj.__class__
@@ -588,6 +623,7 @@ class InstanceState(object):
         self.trigger = None
         self.callables = {}
         self.parents = {}
+        self.pending = {}
         self.instance_dict = None
         
     def __cleanup(self, ref):
@@ -627,6 +663,11 @@ class InstanceState(object):
         finally:
             instance_dict._mutex.release()
 
+    def get_pending(self, attributeimpl):
+        if attributeimpl.key not in self.pending:
+            self.pending[attributeimpl.key] = PendingCollection()
+        return self.pending[attributeimpl.key]
+        
     def is_modified(self):
         if self.modified:
             return True
@@ -654,11 +695,12 @@ class InstanceState(object):
             return None
             
     def __getstate__(self):
-        return {'committed_state':self.committed_state, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()}
+        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()}
     
     def __setstate__(self, state):
         self.committed_state = state['committed_state']
         self.parents = state['parents']
+        self.pending = state['pending']
         self.modified = state['modified']
         self.obj = weakref.ref(state['instance'])
         self.class_ = self.obj().__class__
@@ -857,7 +899,7 @@ class AttributeHistory(object):
     particular instance.
     """
 
-    def __init__(self, attr, state, current, passive=False):
+    def __init__(self, attr, state, current):
         self.attr = attr
 
         # get the "original" value.  if a lazy load was fired when we got
@@ -919,6 +961,27 @@ class AttributeHistory(object):
     def deleted_items(self):
         return list(self._deleted_items)
 
+class PendingCollection(object):
+    """stores items appended and removed from a collection that has not been loaded yet.
+    
+    When the collection is loaded, the changes present in PendingCollection are applied
+    to produce the final result.
+    """
+    
+    def __init__(self):
+        self.deleted_items = util.IdentitySet()
+        self.added_items = util.OrderedIdentitySet()
+
+    def append(self, value):
+        if value in self.deleted_items:
+            self.deleted_items.remove(value)
+        self.added_items.add(value)
+    
+    def remove(self, value):
+        if value in self.added_items:
+            self.added_items.remove(value)
+        self.deleted_items.add(value)
+    
 def _managed_attributes(class_):
     """return all InstrumentedAttributes associated with the given class_ and its superclasses."""
     
index 56cf58d9b52adc0be6edb28b6a2a0109e91f3849..0c49bcfc39efbb3faf2d606831f25ee219d51c7d 100644 (file)
@@ -44,12 +44,12 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
             state.dict[self.key] = c = CollectionHistory(self, state)
             return c
 
-    def append(self, state, value, initiator):
+    def append(self, state, value, initiator, passive=False):
         if initiator is not self:
             self.get_history(state)._added_items.append(value)
             self.fire_append_event(state, value, initiator)
     
-    def remove(self, state, value, initiator):
+    def remove(self, state, value, initiator, passive=False):
         if initiator is not self:
             self.get_history(state)._deleted_items.append(value)
             self.fire_remove_event(state, value, initiator)
index 4e41f0a2951e28aa96b71e165d44c08611d2490c..b321dc50a17e7f816d700c4fd559318eb4276a53 100644 (file)
@@ -150,82 +150,9 @@ class AttributesTest(PersistTest):
             self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
             self.assert_(len(attributes.get_history(u, 'addresses').unchanged_items()) == 1)
 
-    def test_backref(self):
-        class Student(object):pass
-        class Course(object):pass
-        
-        attributes.register_class(Student)
-        attributes.register_class(Course)
-        attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True)
-        attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True)
-        
-        s = Student()
-        c = Course()
-        s.courses.append(c)
-        self.assert_(c.students == [s])
-        s.courses.remove(c)
-        self.assert_(c.students == [])
-        
-        (s1, s2, s3) = (Student(), Student(), Student())
-             
-        c.students = [s1, s2, s3]
-        self.assert_(s2.courses == [c])
-        self.assert_(s1.courses == [c])
-        print "--------------------------------"
-        print s1
-        print s1.courses
-        print c
-        print c.students
-        s1.courses.remove(c)
-        self.assert_(c.students == [s2,s3])        
-        class Post(object):pass
-        class Blog(object):pass
-
-        attributes.register_class(Post)
-        attributes.register_class(Blog)
-        attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
-        attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
-        b = Blog()
-        (p1, p2, p3) = (Post(), Post(), Post())
-        b.posts.append(p1)
-        b.posts.append(p2)
-        b.posts.append(p3)
-        self.assert_(b.posts == [p1, p2, p3])
-        self.assert_(p2.blog is b)
         
-        p3.blog = None
-        self.assert_(b.posts == [p1, p2])
-        p4 = Post()
-        p4.blog = b
-        self.assert_(b.posts == [p1, p2, p4])
-        
-        p4.blog = b
-        p4.blog = b
-        self.assert_(b.posts == [p1, p2, p4])
-
-        # assert no failure removing None
-        p5 = Post()
-        p5.blog = None
-        del p5.blog
-
-        class Port(object):pass
-        class Jack(object):pass
-        attributes.register_class(Port)
-        attributes.register_class(Jack)
-        attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True)
-        attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True)
-        p = Port()
-        j = Jack()
-        p.jack = j
-        self.assert_(j.port is p)
-        self.assert_(p.jack is not None)
-        
-        j.port = None
-        self.assert_(p.jack is None)
-
     def test_lazytrackparent(self):
         """test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
-        
 
         class Post(object):pass
         class Blog(object):pass
@@ -449,6 +376,173 @@ class AttributesTest(PersistTest):
             assert True
         except exceptions.ArgumentError, e:
             assert False
-            
+
+
+class BackrefTest(PersistTest):
+        
+    def test_manytomany(self):
+        class Student(object):pass
+        class Course(object):pass
+
+        attributes.register_class(Student)
+        attributes.register_class(Course)
+        attributes.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'), useobject=True)
+        attributes.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'), useobject=True)
+
+        s = Student()
+        c = Course()
+        s.courses.append(c)
+        self.assert_(c.students == [s])
+        s.courses.remove(c)
+        self.assert_(c.students == [])
+
+        (s1, s2, s3) = (Student(), Student(), Student())
+
+        c.students = [s1, s2, s3]
+        self.assert_(s2.courses == [c])
+        self.assert_(s1.courses == [c])
+        print "--------------------------------"
+        print s1
+        print s1.courses
+        print c
+        print c.students
+        s1.courses.remove(c)
+        self.assert_(c.students == [s2,s3])        
+    
+    def test_onetomany(self):
+        class Post(object):pass
+        class Blog(object):pass
+        
+        attributes.register_class(Post)
+        attributes.register_class(Blog)
+        attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+        attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+        b = Blog()
+        (p1, p2, p3) = (Post(), Post(), Post())
+        b.posts.append(p1)
+        b.posts.append(p2)
+        b.posts.append(p3)
+        self.assert_(b.posts == [p1, p2, p3])
+        self.assert_(p2.blog is b)
+
+        p3.blog = None
+        self.assert_(b.posts == [p1, p2])
+        p4 = Post()
+        p4.blog = b
+        self.assert_(b.posts == [p1, p2, p4])
+
+        p4.blog = b
+        p4.blog = b
+        self.assert_(b.posts == [p1, p2, p4])
+
+        # assert no failure removing None
+        p5 = Post()
+        p5.blog = None
+        del p5.blog
+
+    def test_onetoone(self):
+        class Port(object):pass
+        class Jack(object):pass
+        attributes.register_class(Port)
+        attributes.register_class(Jack)
+        attributes.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'), useobject=True)
+        attributes.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'), useobject=True)
+        p = Port()
+        j = Jack()
+        p.jack = j
+        self.assert_(j.port is p)
+        self.assert_(p.jack is not None)
+
+        j.port = None
+        self.assert_(p.jack is None)
+
+class DeferredBackrefTest(PersistTest):
+    def setUp(self):
+        global Post, Blog, called, lazy_load
+        
+        class Post(object):
+            def __init__(self, name):
+                self.name = name
+            def __eq__(self, other):
+                return other.name == self.name
+
+        class Blog(object):
+            def __init__(self, name):
+                self.name = name
+            def __eq__(self, other):
+                return other.name == self.name
+
+        called = [0]
+
+        lazy_load = []
+        def lazy_posts(instance):
+            def load():
+                called[0] += 1
+                return lazy_load
+            return load
+
+        attributes.register_class(Post)
+        attributes.register_class(Blog)
+        attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+        attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), callable_=lazy_posts, trackparent=True, useobject=True)
+
+    def test_lazy_add(self):
+        global lazy_load
+
+        p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3")
+        lazy_load = [p1, p2, p3]
+
+        b = Blog("blog 1")
+        p = Post("post 4")
+        p.blog = b
+        p = Post("post 5")
+        p.blog = b
+        # setting blog doesnt call 'posts' callable
+        assert called[0] == 0
+
+        # calling backref calls the callable, populates extra posts
+        assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")]
+        assert called[0] == 1
+
+    def test_lazy_remove(self):
+        global lazy_load
+        called[0] = 0
+        lazy_load = []
+
+        b = Blog("blog 1")
+        p = Post("post 1")
+        p.blog = b
+        assert called[0] == 0
+
+        lazy_load = [p]
+
+        p.blog = None
+        p2 = Post("post 2")
+        p2.blog = b
+        assert called[0] == 0
+        assert b.posts == [p2]
+        assert called[0] == 1
+
+    def test_normal_load(self):
+        global lazy_load
+        lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")]
+        called[0] = 0
+
+        b = Blog("blog 1")
+
+        # assign without using backref system
+        p2.__dict__['blog'] = b
+
+        assert b.posts == [Post("post 1"), Post("post 2"), Post("post 3")]
+        assert called[0] == 1
+        p2.blog = None
+        p4 = Post("post 4")
+        p4.blog = b
+        assert b.posts == [Post("post 1"), Post("post 3"), Post("post 4")]
+        assert called[0] == 1
+
+        called[0] = 0
+        lazy_load = (p1, p2, p3) = [Post("post 1"), Post("post 2"), Post("post 3")]
+        
 if __name__ == "__main__":
     testbase.main()
index 97eda3006327243f6bfb68770c51c95f17487d1e..487eb77168dbd5d35bca6b1553c5e0ab1a5ad817 100644 (file)
@@ -272,7 +272,41 @@ class LazyTest(FixtureTest):
         u1 = sess.query(User).get(7)
         
         assert a.user is u1
+    
+    def test_backrefs_dont_lazyload(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user')
+        })
+        mapper(Address, addresses)
+        sess = create_session()
+        ad = sess.query(Address).filter_by(id=1).one()
+        assert ad.user.id == 7
+        def go():
+            ad.user = None
+            assert ad.user is None
+        self.assert_sql_count(testbase.db, go, 0)
+
+        u1 = sess.query(User).filter_by(id=7).one()
+        def go():
+            assert ad not in u1.addresses
+        self.assert_sql_count(testbase.db, go, 1)
+
+        sess.expire(u1, ['addresses'])
+        def go():
+            assert ad in u1.addresses
+        self.assert_sql_count(testbase.db, go, 1)
 
+        sess.expire(u1, ['addresses'])
+        ad2 = Address()
+        def go():
+            ad2.user = u1
+            assert ad2.user is u1
+        self.assert_sql_count(testbase.db, go, 0)
+        
+        def go():
+            assert ad2 in u1.addresses
+        self.assert_sql_count(testbase.db, go, 1)
+            
 class M2OGetTest(FixtureTest):
     keep_mappers = False
     keep_data = True
index 36f0561567aa8fc7cced98afe46d241d493a4663..df1b6bba1d71312f63f98ab3e042e450923e1aef 100644 (file)
@@ -1322,15 +1322,17 @@ class RequirementsTest(AssertMixin):
         h1.h1s.append(H1())
 
         s.flush()
-
+        self.assertEquals(t1.count().scalar(), 4)
+        
         h6 = H6()
         h6.h1a = h1
         h6.h1b = h1
 
         h6 = H6()
         h6.h1a = h1
-        h6.h1b = H1()
-
+        h6.h1b = x = H1()
+        assert x in s
+        
         h6.h1b.h2s.append(H2())
 
         s.flush()
index 158813cd7fb9f8a12e2a11bcdee6c82487865159..11d7313775d287e542dc2857495e94a9d3dcf307 100644 (file)
@@ -33,8 +33,7 @@ class HistoryTest(ORMTest):
         u = User(_sa_session=s)
         a = Address(_sa_session=s)
         a.user = u
-        #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items())
-        #print repr(u.addresses.added_items())
+
         self.assert_(u.addresses == [a])
         s.commit()