]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- basic idea of "session.merge()" actually implemented. needs more testing.
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Jan 2007 01:39:15 +0000 (01:39 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Jan 2007 01:39:15 +0000 (01:39 +0000)
CHANGES
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
test/orm/alltests.py
test/orm/merge.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 29b226a2e4e30b8be9e3fa68ef32a4ae6f243503..7ab2c0714581379f41063f3767d1833493808fa6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,15 +5,7 @@
   completely illiterate, but its definitely sub-optimal to "ensure" which is
   non-ambiguous.
 - sql:
-  - postgres no longer uses client-side cursors, uses more efficient server side 
-  cursors via apparently undocumented psycopg2 behavior recently discovered on the 
-  mailing list.  disable it via create_engine('postgres://', client_side_cursors=True)
-  - mysql is inconsistent with what kinds of quotes it uses in foreign keys during a
-  SHOW CREATE TABLE, reflection updated to accomodate for all three styles [ticket:420]
   - added "fetchmany()" support to ResultProxy
-  - added "BIGSERIAL" support for postgres table with PGBigInteger/autoincrement
-  - fixes to postgres reflection to better handle when schema names are present;
-  thanks to jason (at) ncsmags.com [ticket:402]
   - fix to correlation of subqueries when the column list of the select statement
   is constructed with individual calls to append_column(); this fixes an ORM
   bug whereby nested select statements were not getting correlated with the 
   - the "op()" function is now treated as an "operation", rather than a "comparison".
   the difference is, an operation produces a BinaryExpression from which further operations
   can occur whereas comparison produces the more restrictive BooleanExpression
+- postgres
+  - postgres no longer uses client-side cursors, uses more efficient server side 
+    cursors via apparently undocumented psycopg2 behavior recently discovered on the 
+    mailing list.  disable it via create_engine('postgres://', client_side_cursors=True)
+  - added "BIGSERIAL" support for postgres table with PGBigInteger/autoincrement
+  - fixes to postgres reflection to better handle when schema names are present;
+    thanks to jason (at) ncsmags.com [ticket:402]
+- mysql
+  - mysql is inconsistent with what kinds of quotes it uses in foreign keys during a
+  SHOW CREATE TABLE, reflection updated to accomodate for all three styles [ticket:420]
 - firebird:
   - order of constraint creation puts primary key first before all other constraints;
   required for firebird, not a bad idea for others [ticket:408]
@@ -36,6 +38,7 @@
   initial compilation step does modify internal state significantly, and this step usually
   occurs not at module-level initialization time (unless you call compile()) but at first-request 
   time
+  - basic idea of "session.merge()" actually implemented.  needs more testing.
   - added "compile_mappers()" function as a shortcut to compiling all mappers
   - fix to MapperExtension create_instance so that entity_name properly associated
 with new instance
@@ -46,7 +49,7 @@ with new instance
   - fix to post_update to ensure rows are updated even for non insert/delete scenarios
   [ticket:413]
   - added an error message if you actually try to modify primary key values on an entity
-  and then flush it.  
+  and then flush it [ticket:412]
 
 0.3.3
 - string-based FROM clauses fixed, i.e. select(..., from_obj=["sometext"])
index 0c0ad2a724a10301a353835789f7e56610db1eeb..4e9fe55f4363a52d015941034c8f45ebc99fbd60 100644 (file)
@@ -55,7 +55,10 @@ class MapperProperty(object):
         This flag is used to indicate that the MapperProperty can define attribute instrumentation
         for the class at the class level (as opposed to the individual instance level.)"""
         return self.parent._is_primary_mapper()
-
+    def merge(self, session, source, dest):
+        """merges the attribute represented by this MapperProperty from source to destination object"""
+        raise NotImplementedError()
+        
 class StrategizedProperty(MapperProperty):
     """a MapperProperty which uses selectable strategies to affect loading behavior.
     There is a single default strategy selected, and alternate strategies can be selected
index cdab5464a2f6b139836ab91f750ce55cfbd7e65c..888ac0442c7cb48f746fff48ba88659e71e647b4 100644 (file)
@@ -39,6 +39,8 @@ class SynonymProperty(MapperProperty):
                     return s
                 return getattr(obj, self.name)
         setattr(self.parent.class_, self.key, SynonymProp())
+    def merge(self, session, source, dest):
+        pass
         
 class ColumnProperty(StrategizedProperty):
     """describes an object attribute that corresponds to a table column."""
@@ -60,6 +62,8 @@ class ColumnProperty(StrategizedProperty):
         setattr(object, self.key, value)
     def get_history(self, obj, passive=False):
         return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive)
+    def merge(self, session, source, dest):
+        setattr(dest, self.key, getattr(source, self.key, None))
         
 ColumnProperty.logger = logging.class_logger(ColumnProperty)
         
@@ -118,6 +122,20 @@ class PropertyLoader(StrategizedProperty):
             
     def __str__(self):
         return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper)
+
+    def merge(self, session, source, dest):
+        if not "merge" in self.cascade:
+            return
+        childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
+        if childlist is None:
+            return
+        if self.uselist:
+            # sets a blank list according to the correct list class
+            dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+            for current in list(childlist):
+                dest_list.append(session.merge(current))
+        else:
+            setattr(dest, self.key, session.merge(current))
         
     def cascade_iterator(self, type, object, recursive, halt_on=None):
         if not type in self.cascade:
index 113341e1793f0cffaa93bf6434869cd844181f0b..82922068852534ca994960ea6ccfdec21dda5c48 100644 (file)
@@ -324,29 +324,27 @@ class Session(object):
             self.uow.register_deleted(c)
 
     def merge(self, object, entity_name=None):
-        """merge the object into a newly loaded or existing instance from this Session.
+        """copy the state of the given object onto the persistent object with the same identifier. 
         
-        note: this method is currently not completely implemented."""
-        instance = None
-        for obj in [object] + list(_object_mapper(object).cascade_iterator('merge', object)):
-            key = getattr(obj, '_instance_key', None)
-            if key is None:
-                mapper = _object_mapper(object)
-                ident = mapper.identity(object)
-                for k in ident:
-                    if k is None:
-                        raise exceptions.InvalidRequestError("Instance '%s' does not have a full set of identity values, and does not represent a saved entity in the database.  Use the add() method to add unsaved instances to this Session." % repr(obj))
-                key = mapper.identity_key(ident)
-            u = self.uow
-            if u.identity_map.has_key(key):
-                # TODO: copy the state of the given object into this one.  tricky !
-                inst = u.identity_map[key]
+        If there is no persistent instance currently associated with the session, it will be loaded. 
+        Return the persistent instance. If the given instance is unsaved, save a copy of and return it as 
+        a newly persistent instance. The given instance does not become associated with the session. 
+        This operation cascades to associated instances if the association is mapped with cascade="merge".
+        """
+        mapper = _object_mapper(object)
+        key = getattr(object, '_instance_key', None)
+        if key is None:
+            merged = mapper._create_instance(self)
+        else:
+            if key in self.identity_map:
+                merged = self.identity_map[key]
             else:
-                inst = self.get(object.__class__, key[1])
-            if obj is object:
-                instance = inst
-                
-        return instance
+                merged = self.get(mapper.class_, key[1])
+        for prop in mapper.props.values():
+            prop.merge(self, object, merged)
+        if key is None:
+            self.save(merged)
+        return merged
                     
     def _save_impl(self, object, **kwargs):
         if hasattr(object, '_instance_key'):
index 70d0bb6f4c26dc30a7d661e8cdcf093b8cfbd74f..3fbdbadfc8ecff1ebf17b0123eabfdc27ccb8f03 100644 (file)
@@ -16,6 +16,7 @@ def suite():
         'orm.cascade',
         'orm.relationships',
         'orm.association',
+        'orm.merge',
         
         'orm.cycles',
         'orm.poly_linked_list',
diff --git a/test/orm/merge.py b/test/orm/merge.py
new file mode 100644 (file)
index 0000000..1d2f24a
--- /dev/null
@@ -0,0 +1,106 @@
+from testbase import PersistTest, AssertMixin
+import testbase
+from sqlalchemy import *
+from tables import *
+import tables
+
+class MergeTest(AssertMixin):
+    """tests session.merge() functionality"""
+    def setUpAll(self):
+        tables.create()
+    def tearDownAll(self):
+        tables.drop()
+    def tearDown(self):
+        clear_mappers()
+        tables.delete()
+    def setUp(self):
+        pass
+        
+    def test_unsaved(self):
+        """test merge of a single transient entity."""
+        mapper(User, users)
+        sess = create_session()
+        
+        u = User()
+        u.user_id = 7
+        u.user_name = "fred"
+        u2 = sess.merge(u)
+        assert u2 in sess
+        assert u2.user_id == 7
+        assert u2.user_name == 'fred'
+        sess.flush()
+        sess.clear()
+        u2 = sess.query(User).get(7)
+        assert u2.user_name == 'fred'
+
+    def test_unsaved_cascade(self):
+        """test merge of a transient entity with two child transient entities."""
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), cascade="all")
+        })
+        sess = create_session()
+        u = User()
+        u.user_id = 7
+        u.user_name = "fred"
+        a1 = Address()
+        a1.email_address='foo@bar.com'
+        a2 = Address()
+        a2.email_address = 'hoho@la.com'
+        u.addresses.append(a1)
+        u.addresses.append(a2)
+        
+        u2 = sess.merge(u)
+        self.assert_result([u], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])})
+        self.assert_result([u2], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])})
+        sess.flush()
+        sess.clear()
+        u2 = sess.query(User).get(7)
+        self.assert_result([u2], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])})
+
+    def test_saved_cascade(self):
+        """test merge of a persistent entity with two child persistent entities."""
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), cascade="all")
+        })
+        sess = create_session()
+        
+        # set up data and save
+        u = User()
+        u.user_id = 7
+        u.user_name = "fred"
+        a1 = Address()
+        a1.email_address='foo@bar.com'
+        a2 = Address()
+        a2.email_address = 'hoho@la.com'
+        u.addresses.append(a1)
+        u.addresses.append(a2)
+        sess.save(u)
+        sess.flush()
+
+        # assert data was saved
+        sess2 = create_session()
+        u2 = sess2.query(User).get(7)
+        self.assert_result([u2], User, {'user_id':7, 'user_name':'fred', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@la.com'}])})
+        
+        # make local changes to data
+        u.user_name = 'fred2'
+        u.addresses[1].email_address = 'hoho@lalala.com'
+        
+        # new session, merge modified data into session
+        sess3 = create_session()
+        u3 = sess3.merge(u)
+        # insure local changes are pending
+        self.assert_result([u3], User, {'user_id':7, 'user_name':'fred2', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@lalala.com'}])})
+        
+        # save merged data
+        sess3.flush()
+        
+        # assert modified/merged data was saved
+        sess.clear()
+        u = sess.query(User).get(7)
+        self.assert_result([u], User, {'user_id':7, 'user_name':'fred2', 'addresses':(Address, [{'email_address':'foo@bar.com'}, {'email_address':'hoho@lalala.com'}])})
+        
+if __name__ == "__main__":    
+    testbase.main()
+
+                
\ No newline at end of file