]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merge() may actually work now, though we've heard that before...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Apr 2008 17:13:09 +0000 (17:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Apr 2008 17:13:09 +0000 (17:13 +0000)
- merge() uses the priamry key attributes on the object if _instance_key not present.  so merging works for instances that dont have an instnace_key, will still issue UPDATE for existing rows.
- improved collection behavior for merge() - will remove elements from a destination collection that are not in the source.
- fixed naive set-mutation issue in Select._get_display_froms
- simplified fixtures.Base a bit

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/expression.py
test/orm/merge.py
test/testlib/fixtures.py

diff --git a/CHANGES b/CHANGES
index 1fc3cda0ca1d165aa32ef582f45d36d90b7245d0..3450bb9060ed3f59f3e16c41146880a00c2408b3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,6 +5,25 @@ CHANGES
 0.4.5
 =====
 - orm
+    - a small change in behavior to session.merge() - existing
+      objects are checked for based on primary key attributes,
+      not necessarily _instance_key.  So the widely requested
+      capability, that:
+      
+            x = MyObject(id=1)
+            x = sess.merge(x)
+            
+      will in fact load MyObject with id #1 from the database
+      if present, is now available.  merge() still 
+      copies the state of the given object to the persistent
+      one, so an example like the above would typically have
+      copied "None" from all attributes of "x" onto the persistent 
+      copy.  These can be reverted using session.expire(x).
+    
+    - also fixed behavior in merge() whereby collection elements
+      present on the destination but not the merged collection 
+      were not being removed from the destination.
+        
     - Added a more aggressive check for "uncompiled mappers",
       helps particularly with declarative layer [ticket:995]
 
@@ -158,7 +177,13 @@ CHANGES
 
     - random() is now a generic sql function and will compile to
       the database's random implementation, if any.
-
+    
+    - fixed an issue in select() regarding its generation of 
+      FROM clauses, in rare circumstances two clauses could
+      be produced when one was intended to cancel out the
+      other.  Some ORM queries with lots of eager loads
+      might have seen this symptom.
+      
 - declarative extension
     - The "synonym" function is now directly usable with
       "declarative".  Pass in the decorated property using the
index a511c9bbb6ee34b34c37080bd04e6199ca7c859e..f57298d7cd1b6b381bb0853dfb27afcb3691d6e6 100644 (file)
@@ -569,7 +569,7 @@ class CollectionAttributeImpl(AttributeImpl):
             self.fire_remove_event(state, value, initiator)
         else:
             collection.remove_with_event(value, initiator)
-
+    
     def set(self, state, value, initiator):
         """Set a value on the given object.
 
index 970e49ea4d572bd75ac4e7649b3449e1f99f8618..d050f40d796d83cabce80c66109c76ec19cef184 100644 (file)
@@ -416,7 +416,6 @@ class PropertyLoader(StrategizedProperty):
             return
             
         if not "merge" in self.cascade:
-            # TODO: lazy callable should merge to the new instance
             dest._state.expire_attributes([self.key])
             return
 
@@ -425,15 +424,18 @@ class PropertyLoader(StrategizedProperty):
             return
         
         if self.uselist:
-            dest_list = attributes.init_collection(dest, self.key)
+            dest_list = []
             for current in instances:
                 _recursive[(current, self)] = True
                 obj = session.merge(current, entity_name=self.mapper.entity_name, dont_load=dont_load, _recursive=_recursive)
                 if obj is not None:
-                    if dont_load:
-                        dest_list.append_without_event(obj)
-                    else:
-                        dest_list.append_with_event(obj)
+                    dest_list.append(obj)
+            if dont_load:
+                coll = attributes.init_collection(dest, self.key)
+                for c in dest_list:
+                    coll.append_without_event(c) 
+            else:
+                getattr(dest.__class__, self.key).impl._set_iterable(dest._state, dest_list)
         else:
             current = instances[0]
             if current is not None:
index 391bc925b89cbbbd6a444b12739f719b4982430f..b7a4aa911eaf20bc9bda759a07119be28de3039a 100644 (file)
@@ -955,8 +955,10 @@ class Session(object):
         if key is None:
             if dont_load:
                 raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects.  flush() all changes on mapped instances before merging with dont_load=True.")
-            merged = attributes.new_instance(mapper.class_)
-        else:
+            key = mapper.identity_key_from_instance(instance)
+
+        merged = None
+        if key:
             if key in self.identity_map:
                 merged = self.identity_map[key]
             elif dont_load:
@@ -969,15 +971,19 @@ class Session(object):
                 self._update_impl(merged, entity_name=mapper.entity_name)
             else:
                 merged = self.get(mapper.class_, key[1])
-                if merged is None:
-                    raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(instance))
+        
+        if merged is None:
+            merged = attributes.new_instance(mapper.class_)
+            self.save(merged, entity_name=mapper.entity_name)
+            
         _recursive[instance] = merged
+        
         for prop in mapper.iterate_properties:
             prop.merge(self, instance, merged, dont_load, _recursive)
-        if key is None:
-            self.save(merged, entity_name=mapper.entity_name)
-        elif dont_load:
-            merged._state.commit_all()
+            
+        if dont_load:
+            merged._state.commit_all()  # remove any history
+
         return merged
 
     def identity_key(cls, *args, **kwargs):
index 2cd10720ab91a83ba01a962aa80c657bd1d52a78..758f75ebe7335cf3d2a2ea565b63fc65b9ae3377 100644 (file)
@@ -3096,9 +3096,9 @@ class Select(_SelectBaseMixin, FromClause):
 
         if self._froms:
             froms.update(self._froms)
-
-        for f in froms:
-            froms.difference_update(f._hide_froms)
+        
+        toremove = itertools.chain(*[f._hide_froms for f in froms])
+        froms.difference_update(toremove)
 
         if len(froms) > 1 or self.__correlate:
             if self.__correlate:
index b9a2472895d68bbbc99cd8e23f6815c4d497088e..fd61ccc28c4548bbfc75db94d352a259f5b3ac95 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import *
 from sqlalchemy import exceptions
 from sqlalchemy.orm import *
 from sqlalchemy.orm import mapperlib
+from sqlalchemy.util import OrderedSet
 from testlib import *
 from testlib import fixtures
 from testlib.tables import *
@@ -12,31 +13,139 @@ class MergeTest(TestBase, AssertsExecutionResults):
     """tests session.merge() functionality"""
     def setUpAll(self):
         tables.create()
+
     def tearDownAll(self):
         tables.drop()
+
     def tearDown(self):
         clear_mappers()
         tables.delete()
-    def setUp(self):
-        pass
 
-    def test_unsaved(self):
-        """test merge of a single transient entity."""
+    def test_transient_to_pending(self):
+        class User(fixtures.Base):
+            pass
         mapper(User, users)
         sess = create_session()
 
-        u = User()
-        u.user_id = 7
-        u.user_name = "fred"
+        u = User(user_id=7, user_name='fred')
         u2 = sess.merge(u)
         assert u2 in sess
-        assert u2.user_id == 7
-        assert u2.user_name == 'fred'
+        self.assertEquals(u2, User(user_id=7, user_name='fred'))
         sess.flush()
         sess.clear()
-        u2 = sess.query(User).get(7)
-        assert u2.user_name == 'fred'
+        self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred'))
+    
+    def test_transient_to_pending_collection(self):
+        class User(fixtures.Base):
+            pass
+        class Address(fixtures.Base):
+            pass
+        mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+        mapper(Address, addresses)
 
+        u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+            Address(address_id=1, email_address='fred1'),
+            Address(address_id=2, email_address='fred2'),
+        ]))
+        sess = create_session()
+        sess.merge(u)
+        sess.flush()
+        sess.clear()
+
+        self.assertEquals(sess.query(User).one(), 
+            User(user_id=7, user_name='fred', addresses=OrderedSet([
+                Address(address_id=1, email_address='fred1'),
+                Address(address_id=2, email_address='fred2'),
+            ]))
+        )
+        
+    def test_transient_to_persistent(self):
+        class User(fixtures.Base):
+            pass
+        mapper(User, users)
+        sess = create_session()
+        u = User(user_id=7, user_name='fred')
+        sess.save(u)
+        sess.flush()
+        sess.clear()
+        
+        u2 = User(user_id=7, user_name='fred jones')
+        u2 = sess.merge(u2)
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred jones'))
+        
+    def test_transient_to_persistent_collection(self):
+        class User(fixtures.Base):
+            pass
+        class Address(fixtures.Base):
+            pass
+        mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+        mapper(Address, addresses)
+        
+        u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+            Address(address_id=1, email_address='fred1'),
+            Address(address_id=2, email_address='fred2'),
+        ]))
+        sess = create_session()
+        sess.save(u)
+        sess.flush()
+        sess.clear()
+        
+        u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+            Address(address_id=3, email_address='fred3'),
+            Address(address_id=4, email_address='fred4'),
+        ]))
+        
+        u = sess.merge(u)
+        self.assertEquals(u, 
+            User(user_id=7, user_name='fred', addresses=OrderedSet([
+                Address(address_id=3, email_address='fred3'),
+                Address(address_id=4, email_address='fred4'),
+            ]))
+        )
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(User).one(), 
+            User(user_id=7, user_name='fred', addresses=OrderedSet([
+                Address(address_id=3, email_address='fred3'),
+                Address(address_id=4, email_address='fred4'),
+            ]))
+        )
+        
+    def test_detached_to_persistent_collection(self):
+        class User(fixtures.Base):
+            pass
+        class Address(fixtures.Base):
+            pass
+        mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+        mapper(Address, addresses)
+        
+        a = Address(address_id=1, email_address='fred1')
+        u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+            a,
+            Address(address_id=2, email_address='fred2'),
+        ]))
+        sess = create_session()
+        sess.save(u)
+        sess.flush()
+        sess.clear()
+        
+        u.user_name='fred jones'
+        u.addresses.add(Address(address_id=3, email_address='fred3'))
+        u.addresses.remove(a)
+        
+        u = sess.merge(u)
+        sess.flush()
+        sess.clear()
+        
+        self.assertEquals(sess.query(User).first(), 
+            User(user_id=7, user_name='fred jones', addresses=OrderedSet([
+                Address(address_id=2, email_address='fred2'),
+                Address(address_id=3, email_address='fred3'),
+            ]))
+        )
+        
     def test_unsaved_cascade(self):
         """test merge of a transient entity with two child transient entities, with a bidirectional relation."""
         
@@ -63,18 +172,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         u2 = sess.query(User).get(7)
         self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
 
-    def test_transient_dontload(self):
-        mapper(User, users)
-        
-        sess = create_session()
-        u = User()
-        try:
-            u2 = sess.merge(u, dont_load=True)
-            assert False
-        except exceptions.InvalidRequestError, err:
-            assert str(err) == "merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects.  flush() all changes on mapped instances before merging with dont_load=True."
-    
-    def test_saved_cascade(self):
+    def test_attribute_cascade(self):
         """test merge of a persistent entity with two child persistent entities."""
 
         class User(fixtures.Base):
@@ -132,7 +230,6 @@ class MergeTest(TestBase, AssertsExecutionResults):
 
         # test with "dontload" merge
         sess5 = create_session()
-        print "------------------"
         u = sess5.merge(u, dont_load=True)
         assert len(u.addresses)
         for a in u.addresses:
@@ -158,7 +255,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         assert u2.user_name == 'fred2'
         assert u2.addresses[1].email_address == 'afafds'
 
-    def test_saved_cascade_2(self):
+    def test_one_to_many_cascade(self):
 
         mapper(Order, orders, properties={
             'items':relation(mapper(Item, orderitems))
@@ -197,8 +294,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         sess2.merge(o)
         assert o2.customer.user_name == 'also fred'
 
-    def test_saved_cascade_3(self):
-        """test merge of a persistent entity with one_to_one relationship"""
+    def test_one_to_one_cascade(self):
 
         mapper(User, users, properties={
             'address':relation(mapper(Address, addresses),uselist = False)
@@ -221,6 +317,14 @@ class MergeTest(TestBase, AssertsExecutionResults):
 
         u3 = sess.merge(u2)
     
+    def test_transient_dontload(self):
+        mapper(User, users)
+
+        sess = create_session()
+        u = User()
+        self.assertRaisesMessage(exceptions.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+
+
     def test_dontload_with_backrefs(self):
         """test that dontload populates relations in both directions without requiring a load"""
         
@@ -254,8 +358,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
         self.assertEquals(u.addresses[1].user, User(user_id=7, user_name='fred'))
         
         
-    def test_noload_with_eager(self):
-        """this test illustrates that with noload=True, we can't just
+    def test_dontload_with_eager(self):
+        """this test illustrates that with dont_load=True, we can't just
         copy the committed_state of the merged instance over; since it references collection objects
         which themselves are to be merged.  This committed_state would instead need to be piecemeal
         'converted' to represent the correct objects.
@@ -286,8 +390,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
             sess3.flush()
         self.assert_sql_count(testing.db, go, 0)
 
-    def test_noload_disallows_dirty(self):
-        """noload doesnt support 'dirty' objects right now (see test_noload_with_eager()).
+    def test_dont_load_disallows_dirty(self):
+        """dont_load doesnt support 'dirty' objects right now (see test_dont_load_with_eager()).
         Therefore lets assert it."""
 
         mapper(User, users)
@@ -315,8 +419,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
             sess3.flush()
         self.assert_sql_count(testing.db, go, 0)
 
-    def test_noload_sets_entityname(self):
-        """test that a noload-merged entity has entity_name set, has_mapper() passes, and lazyloads work"""
+    def test_dont_load_sets_entityname(self):
+        """test that a dont_load-merged entity has entity_name set, has_mapper() passes, and lazyloads work"""
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses),uselist = True)
         })
@@ -346,7 +450,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
             assert len(u2.addresses) == 1
         self.assert_sql_count(testing.db, go, 1)
 
-    def test_noload_sets_backrefs(self):
+    def test_dont_load_sets_backrefs(self):
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses),backref='user')
         })
@@ -370,10 +474,10 @@ class MergeTest(TestBase, AssertsExecutionResults):
             assert u2.addresses[0].user is u2
         self.assert_sql_count(testing.db, go, 0)
 
-    def test_noload_preserves_parents(self):
-        """test that merge with noload does not trigger a 'delete-orphan' operation.
+    def test_dont_load_preserves_parents(self):
+        """test that merge with dont_load does not trigger a 'delete-orphan' operation.
 
-        merge with noload sets attributes without using events.  this means the
+        merge with dont_load sets attributes without using events.  this means the
         'hasparent' flag is not propagated to the newly merged instance.  in fact this
         works out OK, because the '_state.parents' collection on the newly
         merged instance is empty; since the mapper doesn't see an active 'False' setting
index bbd27a39f5e16ee3e5a61746d0b61f89f4402cda..a1aa717e9364e1ea90c5ab857a8a4c2d74440c00 100644 (file)
@@ -54,19 +54,11 @@ class Base(object):
                     except AttributeError:
                         #print "b class does not have attribute named '%s'" % attr
                         return False
-                    #print "other:", battr
-                    if not hasattr(value, '__len__'):
-                        value = list(iter(value))
-                        battr = list(iter(battr))
-                    if len(value) != len(battr):
-                        #print "Length of collection '%s' does not match that of b" % attr
-                        return False
-                    for (us, them) in zip(value, battr):
-                        if us != them:
-                            #print "1. Attribute named '%s' does not match b" % attr
-                            return False
-                    else:
+                    
+                    if list(value) == list(battr):
                         continue
+                    else:
+                        return False
                 else:
                     if value is not None:
                         if value != getattr(b, attr, None):