From 6ea6673376609ce6a5e26f9f20425cffee96bcd8 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 8 May 2010 16:09:48 -0400 Subject: [PATCH] - session.merge() will not expire attributes on the returned instance if that instance is "pending". [ticket:1789] --- CHANGES | 5 +++- lib/sqlalchemy/orm/dynamic.py | 4 +-- lib/sqlalchemy/orm/mapper.py | 47 ++++++++++++++++++++------------ lib/sqlalchemy/orm/properties.py | 17 ++++++++---- lib/sqlalchemy/orm/session.py | 10 +++---- lib/sqlalchemy/orm/state.py | 12 ++++++++ lib/sqlalchemy/orm/strategies.py | 4 +-- lib/sqlalchemy/orm/util.py | 5 +--- test/orm/test_merge.py | 13 +++++++++ 9 files changed, 80 insertions(+), 37 deletions(-) diff --git a/CHANGES b/CHANGES index 2508a3cd05..e2d3303b1f 100644 --- a/CHANGES +++ b/CHANGES @@ -8,7 +8,10 @@ CHANGES - orm - Fixed regression introduced in 0.6.0 involving improper history accounting on mutable attributes. [ticket:1782] - + + - session.merge() will not expire attributes on the returned + instance if that instance is "pending". [ticket:1789] + - sql - Fixed bug that prevented implicit RETURNING from functioning properly with composite primary key that contained zeroes. diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index d7960406b3..92632ac890 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -19,7 +19,7 @@ from sqlalchemy.orm import ( attributes, object_session, util as mapperutil, strategies, object_mapper ) from sqlalchemy.orm.query import Query -from sqlalchemy.orm.util import _state_has_identity, has_identity +from sqlalchemy.orm.util import has_identity from sqlalchemy.orm import attributes, collections class DynaLoader(strategies.AbstractRelationshipLoader): @@ -116,7 +116,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): collection_history = self._modified_event(state, dict_) new_values = list(iterable) - if _state_has_identity(state): + if state.has_identity: old_collection = list(self.get(state, dict_)) else: old_collection = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index ccbb273d56..aec7794f31 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -27,7 +27,7 @@ from sqlalchemy.orm.interfaces import ( MapperProperty, EXT_CONTINUE, PropComparator ) from sqlalchemy.orm.util import ( - ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _state_has_identity, + ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _state_mapper, class_mapper, instance_str, state_str, ) @@ -603,20 +603,25 @@ class Mapper(object): # column is coming in after _readonly_props was initialized; check # for 'readonly' if hasattr(self, '_readonly_props') and \ - (not hasattr(col, 'table') or col.table not in self._cols_by_table): + (not hasattr(col, 'table') or + col.table not in self._cols_by_table): self._readonly_props.add(prop) else: - # if column is coming in after _cols_by_table was initialized, ensure the col is in the - # right set - if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]: + # if column is coming in after _cols_by_table was + # initialized, ensure the col is in the right set + if hasattr(self, '_cols_by_table') and \ + col.table in self._cols_by_table and \ + col not in self._cols_by_table[col.table]: self._cols_by_table[col.table].add(col) # if this ColumnProperty represents the "polymorphic discriminator" # column, mark it. We'll need this when rendering columns # in SELECT statements. if not hasattr(prop, '_is_polymorphic_discriminator'): - prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on) + prop._is_polymorphic_discriminator = \ + (col is self.polymorphic_on or + prop.columns[0] is self.polymorphic_on) self.columns[key] = col for col in prop.columns: @@ -801,7 +806,7 @@ class Mapper(object): for mapper in self.iterate_to_root(): for (key, cls) in mapper.delete_orphans: if attributes.manager_of_class(cls).has_parent( - state, key, optimistic=_state_has_identity(state)): + state, key, optimistic=state.has_identity): return False o = o or bool(mapper.delete_orphans) return o @@ -1326,7 +1331,7 @@ class Mapper(object): connection_callable(self, state.obj()) or \ connection - has_identity = _state_has_identity(state) + has_identity = state.has_identity mapper = _state_mapper(state) instance_key = state.key or mapper._identity_key_from_state(state) @@ -1525,7 +1530,8 @@ class Mapper(object): c = connection.execute(statement.values(value_params), params) mapper._postfetch(uowtransaction, table, - state, state_dict, c, c.last_updated_params(), value_params) + state, state_dict, c, + c.last_updated_params(), value_params) rows += c.rowcount @@ -1562,12 +1568,14 @@ class Mapper(object): if primary_key is not None: # set primary key attributes for i, col in enumerate(mapper._pks_by_table[table]): - if mapper._get_state_attr_by_column(state, state_dict, col) is None and \ - len(primary_key) > i: - mapper._set_state_attr_by_column(state, state_dict, col, primary_key[i]) + if mapper._get_state_attr_by_column(state, state_dict, col) \ + is None and len(primary_key) > i: + mapper._set_state_attr_by_column(state, state_dict, col, + primary_key[i]) mapper._postfetch(uowtransaction, table, - state, state_dict, c, c.last_inserted_params(), value_params) + state, state_dict, c, c.last_inserted_params(), + value_params) if not postupdate: for state, state_dict, mapper, connection, has_identity, \ @@ -1577,7 +1585,7 @@ class Mapper(object): readonly = state.unmodified.intersection( p.key for p in mapper._readonly_props ) - + if readonly: _expire_state(state, state.dict, readonly) @@ -1675,7 +1683,7 @@ class Mapper(object): tups.append((state, state.dict, _state_mapper(state), - _state_has_identity(state), + state.has_identity, conn)) table_to_mapper = self._sorted_tables @@ -2070,8 +2078,8 @@ def _load_scalar_attributes(state, attribute_names): raise orm_exc.DetachedInstanceError("Instance %s is not bound to a Session; " "attribute refresh operation cannot proceed" % (state_str(state))) - has_key = _state_has_identity(state) - + has_key = state.has_identity + result = False if mapper.inherits and not mapper.concrete: statement = mapper._optimized_get_statement(state, attribute_names) @@ -2086,6 +2094,7 @@ def _load_scalar_attributes(state, attribute_names): identity_key = state.key else: identity_key = mapper._identity_key_from_state(state) + result = session.query(mapper)._get( identity_key, refresh_state=state, @@ -2094,4 +2103,6 @@ def _load_scalar_attributes(state, attribute_names): # if instance is pending, a refresh operation # may not complete (even if PK attributes are assigned) if has_key and result is None: - raise orm_exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state)) + raise orm_exc.ObjectDeletedError( + "Instance '%s' has been deleted." % + state_str(state)) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 41a8877bfa..50a8a1084a 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -97,18 +97,23 @@ class ColumnProperty(StrategizedProperty): self.columns[0], self.key)) def copy(self): - return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) + return ColumnProperty( + deferred=self.deferred, + group=self.group, + *self.columns) def _getattr(self, state, dict_, column): return state.get_impl(self.key).get(state, dict_) def _getcommitted(self, state, dict_, column, passive=False): - return state.get_impl(self.key).get_committed_value(state, dict_, passive=passive) + return state.get_impl(self.key).\ + get_committed_value(state, dict_, passive=passive) def _setattr(self, state, dict_, value, column): state.get_impl(self.key).set(state, dict_, value, None) - def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive): + def merge(self, session, source_state, source_dict, dest_state, + dest_dict, load, _recursive): if self.key in source_dict: value = source_dict[self.key] @@ -118,7 +123,7 @@ class ColumnProperty(StrategizedProperty): impl = dest_state.get_impl(self.key) impl.set(dest_state, dest_dict, value, None) else: - if self.key not in dest_dict: + if dest_state.has_identity and self.key not in dest_dict: dest_state.expire_attributes(dest_dict, [self.key]) def get_col_value(self, column, value): @@ -130,7 +135,9 @@ class ColumnProperty(StrategizedProperty): if self.adapter: return self.adapter(self.prop.columns[0]) else: - return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper}) + return self.prop.columns[0]._annotate({ + "parententity": self.mapper, + "parentmapper":self.mapper}) def operate(self, op, *other, **kwargs): return op(self.__clause_element__(), *other, **kwargs) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 42b1b3cb54..713cd8c3d2 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -17,7 +17,7 @@ from sqlalchemy.orm import ( from sqlalchemy.orm.util import object_mapper as _object_mapper from sqlalchemy.orm.util import class_mapper as _class_mapper from sqlalchemy.orm.util import ( - _class_to_mapper, _state_has_identity, _state_mapper, + _class_to_mapper, _state_mapper, ) from sqlalchemy.orm.mapper import Mapper, _none_set from sqlalchemy.orm.unitofwork import UOWTransaction @@ -1018,9 +1018,9 @@ class Session(object): if state.key is None: state.key = instance_key elif state.key != instance_key: - # primary key switch. - # use discard() in case another state has already replaced this - # one in the identity map (see test/orm/test_naturalpks.py ReversePKsTest) + # primary key switch. use discard() in case another + # state has already replaced this one in the identity + # map (see test/orm/test_naturalpks.py ReversePKsTest) self.identity_map.discard(state) state.key = instance_key @@ -1396,7 +1396,7 @@ class Session(object): for state in proc: is_orphan = _state_mapper(state)._is_orphan(state) - if is_orphan and not _state_has_identity(state): + if is_orphan and not state.has_identity: path = ", nor ".join( ["any parent '%s' instance " "via that classes' '%s' attribute" % diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 2b43d2cacd..9ee31bff80 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -43,6 +43,10 @@ class InstanceState(object): @util.memoized_property def callables(self): return {} + + @property + def has_identity(self): + return bool(self.key) def detach(self): if self.session_id: @@ -219,6 +223,14 @@ class InstanceState(object): If all attributes are expired, the "expired" flag is set to True. """ + # we would like to assert that 'self.key is not None' here, + # but there are many cases where the mapper will expire + # a newly persisted instance within the flush, before the + # key is assigned, and even cases where the attribute refresh + # occurs fully, within the flush(), before this key is assigned. + # the key is assigned late within the flush() to assist in + # "key switch" bookkeeping scenarios. + if attribute_names is None: attribute_names = self.manager.keys() self.expired = True diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 96aac7d3a6..5b5dd312d8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -239,7 +239,7 @@ class DeferredColumnLoader(LoaderStrategy): path, adapter, **kwargs) def _class_level_loader(self, state): - if not mapperutil._state_has_identity(state): + if not state.has_identity: return None return LoadDeferredColumns(state, self.key) @@ -465,7 +465,7 @@ class LazyLoader(AbstractRelationshipLoader): return criterion def _class_level_loader(self, state): - if not mapperutil._state_has_identity(state): + if not state.has_identity: return None return LoadLazyAttribute(state, self.key) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 03f8c0c375..651b0256b1 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -612,10 +612,7 @@ def _class_to_mapper(class_or_mapper, compile=True): def has_identity(object): state = attributes.instance_state(object) - return _state_has_identity(state) - -def _state_has_identity(state): - return bool(state.key) + return state.has_identity def _is_mapped_class(cls): global mapperlib diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index e80b92699a..d63d7e086e 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -927,6 +927,19 @@ class MergeTest(_fixtures.FixtureTest): assert sess.autoflush sess.commit() + @testing.resolve_artifact_names + def test_dont_expire_pending(self): + """test that pending instances aren't expired during a merge.""" + + mapper(User, users) + u = User(id=7) + sess = create_session(autoflush=True, autocommit=False) + u = sess.merge(u) + assert not bool(attributes.instance_state(u).expired_attributes) + def go(): + eq_(u.name, None) + self.assert_sql_count(testing.db, go, 0) + @testing.resolve_artifact_names def test_option_state(self): """test that the merged takes on the MapperOption characteristics -- 2.47.2