]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Squeezed a few more unnecessary "lazy loads" out of
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Jul 2009 01:46:41 +0000 (01:46 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Jul 2009 01:46:41 +0000 (01:46 +0000)
relation().  When a collection is mutated, many-to-one
backrefs on the other side will not fire off to load
the "old" value, unless "single_parent=True" is set.
A direct assignment of a many-to-one still loads
the "old" value in order to update backref collections
on that value, which may be present in the session
already, thus maintaining the 0.5 behavioral contract.
[ticket:1483]

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_backref_mutations.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 4ab2422fc15a916c81e2adb848d9eb4dc46cce51..a0246dd0b3118766d59d42ffe49b52a3341fdaeb 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -17,6 +17,16 @@ CHANGES
       many-to-many relations from concrete inheritance setups.
       Outside of that use case, YMMV.  [ticket:1477]
     
+    - Squeezed a few more unnecessary "lazy loads" out of 
+      relation().  When a collection is mutated, many-to-one
+      backrefs on the other side will not fire off to load
+      the "old" value, unless "single_parent=True" is set.  
+      A direct assignment of a many-to-one still loads 
+      the "old" value in order to update backref collections 
+      on that value, which may be present in the session 
+      already, thus maintaining the 0.5 behavioral contract.
+      [ticket:1483]
+      
     - Fixed bug whereby a load/refresh of joined table
       inheritance attributes which were based on 
       column_property() or similar would fail to evaluate.
index 2c26f34f2a8f655876d7bc05d881cec64b7d46da..46e9b00de2bbc389e87a1f7a0b552082a478aa49 100644 (file)
@@ -262,7 +262,6 @@ class AttributeImpl(object):
         active_history
           indicates that get_history() should always return the "old" value,
           even if it means executing a lazy callable upon attribute change.
-          This flag is set to True if any extensions are present.
 
         parent_token
           Usually references the MapperProperty, used as a key for
@@ -286,6 +285,10 @@ class AttributeImpl(object):
         else:
             self.is_equal = compare_function
         self.extensions = util.to_list(extension or [])
+        for e in self.extensions:
+            if e.active_history:
+                active_history = True
+                break
         self.active_history = active_history
         self.dont_expire_missing = dont_expire_missing
         
@@ -383,12 +386,12 @@ class AttributeImpl(object):
             return self.initialize(state, dict_)
 
     def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, dict_, value, initiator)
+        self.set(state, dict_, value, initiator, passive=passive)
 
     def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, dict_, None, initiator)
+        self.set(state, dict_, None, initiator, passive=passive)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         raise NotImplementedError()
 
     def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
@@ -421,7 +424,7 @@ class ScalarAttributeImpl(AttributeImpl):
     def delete(self, state, dict_):
 
         # TODO: catch key errors, convert to attributeerror?
-        if self.active_history or self.extensions:
+        if self.active_history:
             old = self.get(state, dict_)
         else:
             old = dict_.get(self.key, NO_VALUE)
@@ -436,11 +439,11 @@ class ScalarAttributeImpl(AttributeImpl):
         return History.from_attribute(
             self, state, dict_.get(self.key, NO_VALUE))
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
-        if self.active_history or self.extensions:
+        if self.active_history:
             old = self.get(state, dict_)
         else:
             old = dict_.get(self.key, NO_VALUE)
@@ -511,7 +514,7 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
         ScalarAttributeImpl.delete(self, state, dict_)
         state.mutable_dict.pop(self.key)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
@@ -559,7 +562,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             else:
                 return History.from_attribute(self, state, current)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         """Set a value on the given InstanceState.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -569,9 +572,21 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         """
         if initiator is self:
             return
-
-        # may want to add options to allow the get() here to be passive
-        old = self.get(state, dict_)
+        
+        if self.active_history:
+            old = self.get(state, dict_)
+        else:
+            # this would be the "laziest" approach,
+            # however it breaks currently expected backref
+            # behavior
+            #old = dict_.get(self.key, None)
+            # instead, use the "passive" setting, which
+            # is only going to be PASSIVE_NOCALLABLES if it
+            # came from a backref
+            old = self.get(state, dict_, passive=passive)
+            if old is PASSIVE_NORESULT:
+                old = None
+             
         value = self.fire_replace_event(state, dict_, value, old, initiator)
         dict_[self.key] = value
 
@@ -707,7 +722,7 @@ class CollectionAttributeImpl(AttributeImpl):
         else:
             collection.remove_with_event(value, initiator)
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         """Set a value on the given object.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -808,6 +823,9 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
     are two objects which contain scalar references to each other.
 
     """
+    
+    active_history = False
+    
     def __init__(self, key):
         self.key = key
 
index 70243291dc3e279bdae383bbe02b7ae1b157f213..0bc7bab24ee0d0dc137450288adfc79746df0c73 100644 (file)
@@ -100,7 +100,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         dict_[self.key] = True
         return state.committed_state[self.key]
 
-    def set(self, state, dict_, value, initiator):
+    def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF):
         if initiator is self:
             return
 
index 9a9ebfcab2b0ef85e6707373f2d016b1a252ce2d..5dffa6774a5efe898bd66fa29990d9915201081f 100644 (file)
@@ -783,6 +783,11 @@ class AttributeExtension(object):
 
     """
 
+    active_history = True
+    """indicates that the set() method would like to receive the 'old' value,
+    even if it means firing lazy callables.
+    """
+    
     def append(self, state, value, initiator):
         """Receive a collection append event.
 
index ebb576a71601602f35a00a2c1bc67b40baef661e..df21b24acc1a303dfcafd9447301097949d9cacd 100644 (file)
@@ -33,7 +33,7 @@ def _register_attribute(strategy, mapper, useobject,
 
     prop = strategy.parent_property
     attribute_ext = util.to_list(prop.extension) or []
-
+        
     if useobject and prop.single_parent:
         attribute_ext.append(_SingleParentValidator(prop))
 
@@ -370,13 +370,16 @@ class LazyLoader(AbstractRelationLoader):
     def init_class_attribute(self, mapper):
         self.is_class_level = True
         
-        
+        # MANYTOONE currently only needs the "old" value for delete-orphan
+        # cascades.  the required _SingleParentValidator will enable active_history
+        # in that case.  otherwise we don't need the "old" value during backref operations.
         _register_attribute(self, 
                 mapper,
                 useobject=True,
                 callable_=self._class_level_loader,
                 uselist = self.parent_property.uselist,
                 typecallable = self.parent_property.collection_class,
+                active_history = self.parent_property.direction is not interfaces.MANYTOONE, 
                 )
 
     def lazy_clause(self, state, reverse_direction=False, alias_secondary=False, adapt_source=None):
index 6a3d0ca5912ce11c4e1171265aa63b91c0c0072c..d650f65a5456cdd62d88c13461f256cc548ea816 100644 (file)
@@ -33,7 +33,9 @@ class UOWEventHandler(interfaces.AttributeExtension):
     """An event handler added to all relation attributes which handles
     session cascade operations.
     """
-
+    
+    active_history = False
+    
     def __init__(self, key):
         self.key = key
 
diff --git a/test/orm/test_backref_mutations.py b/test/orm/test_backref_mutations.py
new file mode 100644 (file)
index 0000000..1ecf027
--- /dev/null
@@ -0,0 +1,474 @@
+"""
+a series of tests which assert the behavior of moving objects between collections
+and scalar attributes resulting in the expected state w.r.t. backrefs, add/remove
+events, etc.
+
+there's a particular focus on collections that have "uselist=False", since in these
+cases the re-assignment of an attribute means the previous owner needs an
+UPDATE in the database.
+
+"""
+
+from sqlalchemy.test.testing import assert_raises, assert_raises_message
+from sqlalchemy import Integer, String, ForeignKey, Sequence, exc as sa_exc
+from sqlalchemy.test.schema import Table
+from sqlalchemy.test.schema import Column
+from sqlalchemy.orm import mapper, relation, create_session, class_mapper, backref, sessionmaker
+from sqlalchemy.orm import attributes, exc as orm_exc
+from sqlalchemy.test import testing
+from sqlalchemy.test.testing import eq_
+from test.orm import _base, _fixtures
+
+class O2MCollectionTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = dict(
+            addresses = relation(Address, backref="user"),
+        ))
+
+    @testing.resolve_artifact_names
+    def test_collection_move_hitslazy(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address2")
+        a3 = Address(email_address="address3")
+        u1= User(name='jack', addresses=[a1, a2, a3])
+        u2= User(name='ed')
+        sess.add_all([u1, a1, a2, a3])
+        sess.commit()
+        
+        #u1.addresses
+        
+        def go():
+            u2.addresses.append(a1)
+            u2.addresses.append(a2)
+            u2.addresses.append(a3)
+        self.assert_sql_count(testing.db, go, 0)
+        
+    @testing.resolve_artifact_names
+    def test_collection_move_preloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', addresses=[a1])
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.addresses collection
+        u1.addresses
+
+        u2.addresses.append(a1)
+
+        # backref fires
+        assert a1.user is u2
+
+        # doesn't extend to the previous collection tho,
+        # which was already loaded.
+        # flushing at this point means its anyone's guess.
+        assert a1 in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_collection_move_notloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', addresses=[a1])
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        u2.addresses.append(a1)
+
+        # backref fires
+        assert a1.user is u2
+        
+        # u1.addresses wasn't loaded,
+        # so when it loads its correct
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', addresses=[a1])
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.addresses collection
+        u1.addresses
+
+        u2.addresses.append(a1)
+
+        # backref fires
+        assert a1.user is u2
+        
+        # everything expires, no changes in 
+        # u1.addresses, so all is fine
+        sess.commit()
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_preloaded(self):
+        sess = sessionmaker()()
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        a1.user = u1
+        sess.add_all([u1, u2, a1])
+        sess.commit()
+
+        # u1.addresses is loaded
+        u1.addresses
+
+        # direct set - the fetching of the 
+        # "old" u1 here allows the backref
+        # to remove it from the addresses collection
+        a1.user = u2
+
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+        
+    @testing.resolve_artifact_names
+    def test_scalar_move_notloaded(self):
+        sess = sessionmaker()()
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        a1.user = u1
+        sess.add_all([u1, u2, a1])
+        sess.commit()
+
+        # direct set - the fetching of the 
+        # "old" u1 here allows the backref
+        # to remove it from the addresses collection
+        a1.user = u2
+
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_commitfirst(self):
+        sess = sessionmaker()()
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        a1.user = u1
+        sess.add_all([u1, u2, a1])
+        sess.commit()
+
+        # u1.addresses is loaded
+        u1.addresses
+
+        # direct set - the fetching of the 
+        # "old" u1 here allows the backref
+        # to remove it from the addresses collection
+        a1.user = u2
+        
+        sess.commit()
+        assert a1 not in u1.addresses
+        assert a1 in u2.addresses
+
+class O2OScalarBackrefMoveTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = {
+            'address':relation(Address, backref=backref("user"), uselist=False)
+        })
+
+    @testing.resolve_artifact_names
+    def test_collection_move_preloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.address
+        u1.address
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # backref fires
+        assert a1.user is u2
+
+        # doesn't extend to the previous attribute tho.
+        # flushing at this point means its anyone's guess.
+        assert u1.address is a1
+        assert u2.address is a1
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_preloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        sess.add_all([u1, a1, a2])
+        sess.commit() # everything is expired
+
+        # load a1.user
+        a1.user
+        
+        # reassign
+        a2.user = u1
+
+        # backref fires
+        assert u1.address is a2
+        
+        # stays on both sides
+        assert a1.user is u1
+        assert a2.user is u1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_notloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # backref fires
+        assert a1.user is u2
+        
+        # u1.address loads now after a flush
+        assert u1.address is None
+        assert u2.address is a1
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_notloaded(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        sess.add_all([u1, a1, a2])
+        sess.commit() # everything is expired
+
+        # reassign
+        a2.user = u1
+
+        # backref fires
+        assert u1.address is a2
+
+        # stays on both sides
+        assert a1.user is u1
+        assert a2.user is u1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.address
+        u1.address
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # backref fires
+        assert a1.user is u2
+
+        # the commit cancels out u1.addresses
+        # being loaded, on next access its fine.
+        sess.commit()
+        assert u1.address is None
+        assert u2.address is a1
+
+    @testing.resolve_artifact_names
+    def test_scalar_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        a2 = Address(email_address="address2")
+        u1 = User(name='jack', address=a1)
+
+        sess.add_all([u1, a1, a2])
+        sess.commit() # everything is expired
+
+        # load
+        assert a1.user is u1
+        
+        # reassign
+        a2.user = u1
+
+        # backref fires
+        assert u1.address is a2
+
+        # didnt work this way tho
+        assert a1.user is u1
+        
+        # moves appropriately after commit
+        sess.commit()
+        assert u1.address is a2
+        assert a1.user is None
+        assert a2.user is u1
+
+class O2OScalarMoveTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = {
+            'address':relation(Address, uselist=False)
+        })
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commitfirst(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+
+        u2 = User(name='ed')
+        sess.add_all([u1, u2])
+        sess.commit() # everything is expired
+
+        # load u1.address
+        u1.address
+
+        # reassign
+        u2.address = a1
+        assert u2.address is a1
+
+        # the commit cancels out u1.addresses
+        # being loaded, on next access its fine.
+        sess.commit()
+        assert u1.address is None
+        assert u2.address is a1
+
+class O2OScalarOrphanTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Address, addresses)
+        mapper(User, users, properties = {
+            'address':relation(Address, uselist=False, 
+                backref=backref('user', single_parent=True, cascade="all, delete-orphan"))
+        })
+
+    @testing.resolve_artifact_names
+    def test_m2o_event(self):
+        sess = sessionmaker()()
+        a1 = Address(email_address="address1")
+        u1 = User(name='jack', address=a1)
+        
+        sess.add(u1)
+        sess.commit()
+        sess.expunge(u1)
+        
+        u2= User(name='ed')
+        # the _SingleParent extension sets the backref get to "active" !
+        # u1 gets loaded and deleted
+        u2.address = a1
+        sess.commit()
+        assert sess.query(User).count() == 1
+        
+    
+class M2MScalarMoveTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        mapper(Item, items, properties={
+            'keyword':relation(Keyword, secondary=item_keywords, uselist=False, backref=backref("item", uselist=False))
+        })
+        mapper(Keyword, keywords)
+    
+    @testing.resolve_artifact_names
+    def test_collection_move_preloaded(self):
+        sess = sessionmaker()()
+        
+        k1 = Keyword(name='k1')
+        i1 = Item(description='i1', keyword=k1)
+        i2 = Item(description='i2')
+
+        sess.add_all([i1, i2, k1])
+        sess.commit() # everything is expired
+        
+        # load i1.keyword
+        assert i1.keyword is k1
+        
+        i2.keyword = k1
+
+        assert k1.item is i2
+        
+        # nothing happens.
+        assert i1.keyword is k1
+        assert i2.keyword is k1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_notloaded(self):
+        sess = sessionmaker()()
+
+        k1 = Keyword(name='k1')
+        i1 = Item(description='i1', keyword=k1)
+        i2 = Item(description='i2')
+
+        sess.add_all([i1, i2, k1])
+        sess.commit() # everything is expired
+
+        i2.keyword = k1
+
+        assert k1.item is i2
+
+        assert i1.keyword is None
+        assert i2.keyword is k1
+
+    @testing.resolve_artifact_names
+    def test_collection_move_commit(self):
+        sess = sessionmaker()()
+
+        k1 = Keyword(name='k1')
+        i1 = Item(description='i1', keyword=k1)
+        i2 = Item(description='i2')
+
+        sess.add_all([i1, i2, k1])
+        sess.commit() # everything is expired
+
+        # load i1.keyword
+        assert i1.keyword is k1
+
+        i2.keyword = k1
+
+        assert k1.item is i2
+
+        sess.commit()
+        assert i1.keyword is None
+        assert i2.keyword is k1