]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- session.merge() will not expire attributes on the returned
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 May 2010 20:09:48 +0000 (16:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 May 2010 20:09:48 +0000 (16:09 -0400)
instance if that instance is "pending".  [ticket:1789]

CHANGES
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/test_merge.py

diff --git a/CHANGES b/CHANGES
index 2508a3cd052dc40955079fee495a48a238e12a23..e2d3303b1f89749a0342441cc59e2399f116a6c4 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,7 +8,10 @@ CHANGES
 - orm
   - Fixed regression introduced in 0.6.0 involving improper
     history accounting on mutable attributes.  [ticket:1782]
-    
+  
+  - session.merge() will not expire attributes on the returned
+    instance if that instance is "pending".  [ticket:1789]
+
 - sql
   - Fixed bug that prevented implicit RETURNING from functioning
     properly with composite primary key that contained zeroes.
index d7960406b344eb38e4e32606792f2b39b3e82aa0..92632ac8904a2a5020554f18bb792d2d52aef733 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.orm import (
     attributes, object_session, util as mapperutil, strategies, object_mapper
     )
 from sqlalchemy.orm.query import Query
-from sqlalchemy.orm.util import _state_has_identity, has_identity
+from sqlalchemy.orm.util import has_identity
 from sqlalchemy.orm import attributes, collections
 
 class DynaLoader(strategies.AbstractRelationshipLoader):
@@ -116,7 +116,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         collection_history = self._modified_event(state, dict_)
         new_values = list(iterable)
 
-        if _state_has_identity(state):
+        if state.has_identity:
             old_collection = list(self.get(state, dict_))
         else:
             old_collection = []
index ccbb273d56213203ff4cb246eee6f26c83868fa7..aec7794f31d8d04e2e01297b240a905843d57b63 100644 (file)
@@ -27,7 +27,7 @@ from sqlalchemy.orm.interfaces import (
     MapperProperty, EXT_CONTINUE, PropComparator
     )
 from sqlalchemy.orm.util import (
-     ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _state_has_identity,
+     ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, 
      _state_mapper, class_mapper, instance_str, state_str,
      )
 
@@ -603,20 +603,25 @@ class Mapper(object):
                 # column is coming in after _readonly_props was initialized; check
                 # for 'readonly'
                 if hasattr(self, '_readonly_props') and \
-                    (not hasattr(col, 'table') or col.table not in self._cols_by_table):
+                    (not hasattr(col, 'table') or 
+                    col.table not in self._cols_by_table):
                         self._readonly_props.add(prop)
 
             else:
-                # if column is coming in after _cols_by_table was initialized, ensure the col is in the
-                # right set
-                if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
+                # if column is coming in after _cols_by_table was 
+                # initialized, ensure the col is in the right set
+                if hasattr(self, '_cols_by_table') and \
+                                    col.table in self._cols_by_table and \
+                                    col not in self._cols_by_table[col.table]:
                     self._cols_by_table[col.table].add(col)
             
             # if this ColumnProperty represents the "polymorphic discriminator"
             # column, mark it.  We'll need this when rendering columns
             # in SELECT statements.
             if not hasattr(prop, '_is_polymorphic_discriminator'):
-                prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on)
+                prop._is_polymorphic_discriminator = \
+                                    (col is self.polymorphic_on or
+                                    prop.columns[0] is self.polymorphic_on)
                 
             self.columns[key] = col
             for col in prop.columns:
@@ -801,7 +806,7 @@ class Mapper(object):
         for mapper in self.iterate_to_root():
             for (key, cls) in mapper.delete_orphans:
                 if attributes.manager_of_class(cls).has_parent(
-                    state, key, optimistic=_state_has_identity(state)):
+                    state, key, optimistic=state.has_identity):
                     return False
             o = o or bool(mapper.delete_orphans)
         return o
@@ -1326,7 +1331,7 @@ class Mapper(object):
                 connection_callable(self, state.obj()) or \
                 connection
 
-            has_identity = _state_has_identity(state)
+            has_identity = state.has_identity
             mapper = _state_mapper(state)
             instance_key = state.key or mapper._identity_key_from_state(state)
 
@@ -1525,7 +1530,8 @@ class Mapper(object):
                         c = connection.execute(statement.values(value_params), params)
                         
                     mapper._postfetch(uowtransaction, table, 
-                                        state, state_dict, c, c.last_updated_params(), value_params)
+                                        state, state_dict, c, 
+                                        c.last_updated_params(), value_params)
 
                     rows += c.rowcount
 
@@ -1562,12 +1568,14 @@ class Mapper(object):
                     if primary_key is not None:
                         # set primary key attributes
                         for i, col in enumerate(mapper._pks_by_table[table]):
-                            if mapper._get_state_attr_by_column(state, state_dict, col) is None and \
-                                                                len(primary_key) > i:
-                                mapper._set_state_attr_by_column(state, state_dict, col, primary_key[i])
+                            if mapper._get_state_attr_by_column(state, state_dict, col) \
+                                        is None and len(primary_key) > i:
+                                mapper._set_state_attr_by_column(state, state_dict, col,
+                                                                    primary_key[i])
                                 
                     mapper._postfetch(uowtransaction, table, 
-                                        state, state_dict, c, c.last_inserted_params(), value_params)
+                                        state, state_dict, c, c.last_inserted_params(),
+                                        value_params)
 
         if not postupdate:
             for state, state_dict, mapper, connection, has_identity, \
@@ -1577,7 +1585,7 @@ class Mapper(object):
                 readonly = state.unmodified.intersection(
                     p.key for p in mapper._readonly_props
                 )
-
+                
                 if readonly:
                     _expire_state(state, state.dict, readonly)
 
@@ -1675,7 +1683,7 @@ class Mapper(object):
             tups.append((state, 
                     state.dict,
                     _state_mapper(state), 
-                    _state_has_identity(state),
+                    state.has_identity,
                     conn))
 
         table_to_mapper = self._sorted_tables
@@ -2070,8 +2078,8 @@ def _load_scalar_attributes(state, attribute_names):
         raise orm_exc.DetachedInstanceError("Instance %s is not bound to a Session; "
                     "attribute refresh operation cannot proceed" % (state_str(state)))
 
-    has_key = _state_has_identity(state)
-
+    has_key = state.has_identity
+    
     result = False
     if mapper.inherits and not mapper.concrete:
         statement = mapper._optimized_get_statement(state, attribute_names)
@@ -2086,6 +2094,7 @@ def _load_scalar_attributes(state, attribute_names):
             identity_key = state.key
         else:
             identity_key = mapper._identity_key_from_state(state)
+        
         result = session.query(mapper)._get(
                                             identity_key, 
                                             refresh_state=state, 
@@ -2094,4 +2103,6 @@ def _load_scalar_attributes(state, attribute_names):
     # if instance is pending, a refresh operation 
     # may not complete (even if PK attributes are assigned)
     if has_key and result is None:
-        raise orm_exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state))
+        raise orm_exc.ObjectDeletedError(
+                            "Instance '%s' has been deleted." % 
+                            state_str(state))
index 41a8877bfaf1a2d9a036597c26c76a201e7eb7ec..50a8a1084a58ded20967cca0e53084642e13d0c7 100644 (file)
@@ -97,18 +97,23 @@ class ColumnProperty(StrategizedProperty):
                                        self.columns[0], self.key))
 
     def copy(self):
-        return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
+        return ColumnProperty(
+                        deferred=self.deferred, 
+                        group=self.group, 
+                        *self.columns)
 
     def _getattr(self, state, dict_, column):
         return state.get_impl(self.key).get(state, dict_)
 
     def _getcommitted(self, state, dict_, column, passive=False):
-        return state.get_impl(self.key).get_committed_value(state, dict_, passive=passive)
+        return state.get_impl(self.key).\
+                    get_committed_value(state, dict_, passive=passive)
 
     def _setattr(self, state, dict_, value, column):
         state.get_impl(self.key).set(state, dict_, value, None)
 
-    def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
+    def merge(self, session, source_state, source_dict, dest_state, 
+                                dest_dict, load, _recursive):
         if self.key in source_dict:
             value = source_dict[self.key]
         
@@ -118,7 +123,7 @@ class ColumnProperty(StrategizedProperty):
                 impl = dest_state.get_impl(self.key)
                 impl.set(dest_state, dest_dict, value, None)
         else:
-            if self.key not in dest_dict:
+            if dest_state.has_identity and self.key not in dest_dict:
                 dest_state.expire_attributes(dest_dict, [self.key])
                 
     def get_col_value(self, column, value):
@@ -130,7 +135,9 @@ class ColumnProperty(StrategizedProperty):
             if self.adapter:
                 return self.adapter(self.prop.columns[0])
             else:
-                return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper})
+                return self.prop.columns[0]._annotate({
+                                                "parententity": self.mapper,
+                                                "parentmapper":self.mapper})
                 
         def operate(self, op, *other, **kwargs):
             return op(self.__clause_element__(), *other, **kwargs)
index 42b1b3cb5447e7206f46831697aff142cc1c2452..713cd8c3d2d826cea177528e26c9c5e05ca73e3a 100644 (file)
@@ -17,7 +17,7 @@ from sqlalchemy.orm import (
 from sqlalchemy.orm.util import object_mapper as _object_mapper
 from sqlalchemy.orm.util import class_mapper as _class_mapper
 from sqlalchemy.orm.util import (
-    _class_to_mapper, _state_has_identity, _state_mapper,
+    _class_to_mapper, _state_mapper,
     )
 from sqlalchemy.orm.mapper import Mapper, _none_set
 from sqlalchemy.orm.unitofwork import UOWTransaction
@@ -1018,9 +1018,9 @@ class Session(object):
             if state.key is None:
                 state.key = instance_key
             elif state.key != instance_key:
-                # primary key switch.
-                # use discard() in case another state has already replaced this
-                # one in the identity map (see test/orm/test_naturalpks.py ReversePKsTest)
+                # primary key switch. use discard() in case another 
+                # state has already replaced this one in the identity 
+                # map (see test/orm/test_naturalpks.py ReversePKsTest)
                 self.identity_map.discard(state)
                 state.key = instance_key
             
@@ -1396,7 +1396,7 @@ class Session(object):
             
         for state in proc:
             is_orphan = _state_mapper(state)._is_orphan(state)
-            if is_orphan and not _state_has_identity(state):
+            if is_orphan and not state.has_identity:
                 path = ", nor ".join(
                     ["any parent '%s' instance "
                      "via that classes' '%s' attribute" %
index 2b43d2cacd9128070f3e656004ec57eed159726c..9ee31bff80a74be954739469eb43c15e7b76755f 100644 (file)
@@ -43,6 +43,10 @@ class InstanceState(object):
     @util.memoized_property
     def callables(self):
         return {}
+
+    @property
+    def has_identity(self):
+        return bool(self.key)
         
     def detach(self):
         if self.session_id:
@@ -219,6 +223,14 @@ class InstanceState(object):
         If all attributes are expired, the "expired" flag is set to True.
         
         """
+        # we would like to assert that 'self.key is not None' here, 
+        # but there are many cases where the mapper will expire
+        # a newly persisted instance within the flush, before the
+        # key is assigned, and even cases where the attribute refresh
+        # occurs fully, within the flush(), before this key is assigned.
+        # the key is assigned late within the flush() to assist in
+        # "key switch" bookkeeping scenarios.
+        
         if attribute_names is None:
             attribute_names = self.manager.keys()
             self.expired = True
index 96aac7d3a67659c4a12542b5d234cd4efab307e8..5b5dd312d82cfd1fce63e4b601e50f1d5940de21 100644 (file)
@@ -239,7 +239,7 @@ class DeferredColumnLoader(LoaderStrategy):
                                         path, adapter, **kwargs)
     
     def _class_level_loader(self, state):
-        if not mapperutil._state_has_identity(state):
+        if not state.has_identity:
             return None
             
         return LoadDeferredColumns(state, self.key)
@@ -465,7 +465,7 @@ class LazyLoader(AbstractRelationshipLoader):
         return criterion
         
     def _class_level_loader(self, state):
-        if not mapperutil._state_has_identity(state):
+        if not state.has_identity:
             return None
 
         return LoadLazyAttribute(state, self.key)
index 03f8c0c3758a5822409fbac09d6fcff5fc0b604b..651b0256b1ef781493efd5ec454dbdd38ccff0a4 100644 (file)
@@ -612,10 +612,7 @@ def _class_to_mapper(class_or_mapper, compile=True):
 
 def has_identity(object):
     state = attributes.instance_state(object)
-    return _state_has_identity(state)
-
-def _state_has_identity(state):
-    return bool(state.key)
+    return state.has_identity
 
 def _is_mapped_class(cls):
     global mapperlib
index e80b92699a01c504613c19fc4ed676ed53b77ae3..d63d7e086ec0b382cf4f5189944e8d73280f45fb 100644 (file)
@@ -927,6 +927,19 @@ class MergeTest(_fixtures.FixtureTest):
         assert sess.autoflush
         sess.commit()
 
+    @testing.resolve_artifact_names
+    def test_dont_expire_pending(self):
+        """test that pending instances aren't expired during a merge."""
+        
+        mapper(User, users)
+        u = User(id=7)
+        sess = create_session(autoflush=True, autocommit=False)
+        u = sess.merge(u)
+        assert not bool(attributes.instance_state(u).expired_attributes)
+        def go():
+            eq_(u.name, None)
+        self.assert_sql_count(testing.db, go, 0)
+    
     @testing.resolve_artifact_names
     def test_option_state(self):
         """test that the merged takes on the MapperOption characteristics