]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- making progress with session.merge() as well as combining its
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Apr 2007 02:16:57 +0000 (02:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Apr 2007 02:16:57 +0000 (02:16 +0000)
usage with entity_name [ticket:543]

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py

diff --git a/CHANGES b/CHANGES
index d285509b0bfbb8df038c8b78cca2b177ed35ca94..5b65adfe89be734e1a22fc5e1e55cb0e5fc65557 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -71,6 +71,8 @@
       methods on them during lazy loads)
     - fix to many-to-many relationships targeting polymorphic mappers
       [ticket:533]
+    - making progress with session.merge() as well as combining its
+      usage with entity_name [ticket:543]
 - sqlite:
     - removed silly behavior where sqlite would reflect UNIQUE indexes
       as part of the primary key (?!)
index 1b6203e063d239fba3e53a974115b4ebe4f72f88..51fb8fb2eb510186c2c7accb52c988a5e2bbf7b3 100644 (file)
@@ -62,9 +62,10 @@ class PropertyLoader(StrategizedProperty):
     of items that correspond to a related database table.
     """
 
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True):
         self.uselist = uselist
         self.argument = argument
+        self.entity_name = entity_name
         self.secondary = secondary
         self.primaryjoin = primaryjoin
         self.secondaryjoin = secondaryjoin
@@ -120,24 +121,24 @@ class PropertyLoader(StrategizedProperty):
         return str(self.parent.class_.__name__) + "." + self.key + " (" + str(self.mapper.class_.__name__)  + ")"
 
     def merge(self, session, source, dest, _recursive):
-        if not "merge" in self.cascade or source in _recursive:
+        if not "merge" in self.cascade or self.mapper in _recursive:
             return
-        _recursive.add(source)
-        try:
-            childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
-            if childlist is None:
-                return
-            if self.uselist:
-                # sets a blank list according to the correct list class
-                dest_list = getattr(self.parent.class_, self.key).initialize(dest)
-                for current in list(childlist):
-                    dest_list.append(session.merge(current, _recursive=_recursive))
-            else:
-                current = list(childlist)[0]
-                if current is not None:
-                    setattr(dest, self.key, session.merge(current, _recursive=_recursive))
-        finally:
-            _recursive.remove(source)
+        childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
+        if childlist is None:
+            return
+        if self.uselist:
+            # sets a blank list according to the correct list class
+            dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+            for current in list(childlist):
+                obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive)
+                if obj is not None:
+                    dest_list.append(obj)
+        else:
+            current = list(childlist)[0]
+            if current is not None:
+                obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive)
+                if obj is not None:
+                    setattr(dest, self.key, obj)
 
     def cascade_iterator(self, type, object, recursive, halt_on=None):
         if not type in self.cascade:
@@ -188,7 +189,7 @@ class PropertyLoader(StrategizedProperty):
 
     def _determine_targets(self):
         if isinstance(self.argument, type):
-            self.mapper = mapper.class_mapper(self.argument, compile=False)._check_compile()
+            self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)._check_compile()
         elif isinstance(self.argument, mapper.Mapper):
             self.mapper = self.argument._check_compile()
         else:
@@ -199,7 +200,7 @@ class PropertyLoader(StrategizedProperty):
 
         if self.association is not None:
             if isinstance(self.association, type):
-                self.association = mapper.class_mapper(self.association, compile=False)._check_compile()
+                self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)._check_compile()
 
         self.target = self.mapper.mapped_table
         self.select_mapper = self.mapper.get_select_mapper()
index 35b35201d894b53a73eb25797227a4118d24214a..1880e6062c73737995fed1a1e78ff0ac5c86bb8f 100644 (file)
@@ -446,7 +446,7 @@ class Session(object):
         for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
             self.uow.register_deleted(c)
 
-    def merge(self, object,entity_name=None, _recursive=None):
+    def merge(self, object, entity_name=None, _recursive=None):
         """Copy the state of the given `object` onto the persistent
         object with the same identifier.
 
@@ -462,21 +462,31 @@ class Session(object):
 
         if _recursive is None:
             _recursive = util.Set()
-        mapper = _object_mapper(object, entity_name=entity_name)
-        key = getattr(object, '_instance_key', None)
-        if key is None:
-            merged = mapper._create_instance(self)
+        if entity_name is not None:
+            mapper = _class_mapper(object.__class__, entity_name=entity_name)
         else:
-            if key in self.identity_map:
-                merged = self.identity_map[key]
+            mapper = _object_mapper(object)
+        if mapper in _recursive or object in _recursive:
+            return None
+        _recursive.add(mapper)
+        _recursive.add(object)
+        try:
+            key = getattr(object, '_instance_key', None)
+            if key is None:
+                merged = mapper._create_instance(self)
             else:
-                merged = self.get(mapper.class_, key[1])
-        for prop in mapper.props.values():
-            prop.merge(self, object, merged, _recursive)
-        if key is None:
-            self.save(merged)
-        return merged
-
+                if key in self.identity_map:
+                    merged = self.identity_map[key]
+                else:
+                    merged = self.get(mapper.class_, key[1])
+            for prop in mapper.props.values():
+                prop.merge(self, object, merged, _recursive)
+            if key is None:
+                self.save(merged, entity_name=mapper.entity_name)
+            return merged
+        finally:
+            _recursive.remove(mapper)
+            
     def identity_key(self, *args, **kwargs):
         """Get an identity key.