From 9cf10db8aa4692dc615f1a03db5ffe342c321586 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 24 Apr 2012 18:06:27 -0400 Subject: [PATCH] - [feature] Calling rollback() within a session.begin_nested() will now only expire those objects that had net changes within the scope of that transaction, that is objects which were dirty or were modified on a flush. This allows the typical use case for begin_nested(), that of altering a small subset of objects, to leave in place the data from the larger enclosing set of objects that weren't modified in that sub-transaction. [ticket:2452] - inline the "register_newly_XYZ" functions to operate upon collections to reduce method calls --- CHANGES | 11 ++++ lib/sqlalchemy/orm/query.py | 34 +++++++---- lib/sqlalchemy/orm/session.py | 98 +++++++++++++++++------------- lib/sqlalchemy/orm/unitofwork.py | 15 ++--- test/aaa_profiling/test_zoomark.py | 4 +- test/orm/test_inspect.py | 2 +- test/orm/test_transaction.py | 51 +++++++++++++++- 7 files changed, 148 insertions(+), 67 deletions(-) diff --git a/CHANGES b/CHANGES index 65599b1390..d7b09c9ce5 100644 --- a/CHANGES +++ b/CHANGES @@ -60,6 +60,17 @@ CHANGES of current object state, history of attributes, etc. [ticket:2208] + - [feature] Calling rollback() within a + session.begin_nested() will now only expire + those objects that had net changes within the + scope of that transaction, that is objects which + were dirty or were modified on a flush. This + allows the typical use case for begin_nested(), + that of altering a small subset of objects, to + leave in place the data from the larger enclosing + set of objects that weren't modified in + that sub-transaction. [ticket:2452] + - [bug] The "passive" flag on Session.is_modified() no longer has any effect. is_modified() in all cases looks only at local in-memory diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index dda231e0c7..5cf9ea5cfe 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2424,7 +2424,7 @@ class Query(object): try: state(passive) except orm_exc.ObjectDeletedError: - session._remove_newly_deleted(state) + session._remove_newly_deleted([state]) return None return instance else: @@ -2650,18 +2650,20 @@ class Query(object): result = session.execute(delete_stmt, params=self._params) if synchronize_session == 'evaluate': - for obj in objs_to_expunge: - session._remove_newly_deleted(attributes.instance_state(obj)) + session._remove_newly_deleted([attributes.instance_state(obj) + for obj in objs_to_expunge]) elif synchronize_session == 'fetch': target_mapper = self._mapper_zero() for primary_key in matched_rows: + # TODO: inline this and call remove_newly_deleted + # once identity_key = target_mapper.identity_key_from_primary_key( list(primary_key)) if identity_key in session.identity_map: session._remove_newly_deleted( - attributes.instance_state( + [attributes.instance_state( session.identity_map[identity_key] - ) + )] ) session.dispatch.after_bulk_delete(session, self, context, result) @@ -2788,7 +2790,7 @@ class Query(object): if synchronize_session == 'evaluate': target_cls = self._mapper_zero().class_ - + states = set() for obj in matched_objects: state, dict_ = attributes.instance_state(obj),\ attributes.instance_dict(obj) @@ -2806,18 +2808,24 @@ class Query(object): state.expire_attributes(dict_, set(evaluated_keys). difference(to_evaluate)) + states.add(state) + session._register_altered(states) elif synchronize_session == 'fetch': target_mapper = self._mapper_zero() - for primary_key in matched_rows: - identity_key = target_mapper.identity_key_from_primary_key( + states = set([ + attributes.instance_state(session.identity_map[identity_key]) + for identity_key in [ + target_mapper.identity_key_from_primary_key( list(primary_key)) - if identity_key in session.identity_map: - session.expire( - session.identity_map[identity_key], - [_attr_as_key(k) for k in values] - ) + for primary_key in matched_rows + ] + ]) + attrib = [_attr_as_key(k) for k in values] + for state in states: + session._expire_state(state, attrib) + session._register_altered(states) session.dispatch.after_bulk_update(session, self, context, result) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 7c2cd8f0e1..eb15e033e7 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -211,6 +211,7 @@ class SessionTransaction(object): if not self._is_transaction_boundary: self._new = self._parent._new self._deleted = self._parent._deleted + self._dirty = self._parent._dirty return if not self.session._flushing: @@ -218,8 +219,9 @@ class SessionTransaction(object): self._new = weakref.WeakKeyDictionary() self._deleted = weakref.WeakKeyDictionary() + self._dirty = weakref.WeakKeyDictionary() - def _restore_snapshot(self): + def _restore_snapshot(self, dirty_only=False): assert self._is_transaction_boundary for s in set(self._new).union(self.session._new): @@ -236,7 +238,8 @@ class SessionTransaction(object): assert not self.session._deleted for s in self.session.identity_map.all_states(): - s.expire(s.dict, self.session.identity_map._modified) + if not dirty_only or s.modified or s in self._dirty: + s.expire(s.dict, self.session.identity_map._modified) def _remove_snapshot(self): assert self._is_transaction_boundary @@ -351,7 +354,7 @@ class SessionTransaction(object): "Session's state has been changed on " "a non-active transaction - this state " "will be discarded.") - self._restore_snapshot() + self._restore_snapshot(dirty_only=self.nested) self.close() if self._parent and _capture_exception: @@ -366,7 +369,7 @@ class SessionTransaction(object): t[1].rollback() if self.session._enable_transaction_accounting: - self._restore_snapshot() + self._restore_snapshot(dirty_only=self.nested) self.session.dispatch.after_rollback(self.session) @@ -1185,53 +1188,62 @@ class Session(object): elif self.transaction: self.transaction._deleted.pop(state, None) - def _register_newly_persistent(self, state): - mapper = _state_mapper(state) - - # prevent against last minute dereferences of the object - obj = state.obj() - if obj is not None: - - instance_key = mapper._identity_key_from_state(state) - - if _none_set.issubset(instance_key[1]) and \ - not mapper.allow_partial_pks or \ - _none_set.issuperset(instance_key[1]): - raise exc.FlushError( - "Instance %s has a NULL identity key. If this is an " - "auto-generated value, check that the database table " - "allows generation of new primary key values, and that " - "the mapped Column object is configured to expect these " - "generated values. Ensure also that this flush() is " - "not occurring at an inappropriate time, such as within " - "a load() event." % mapperutil.state_str(state) - ) + def _register_newly_persistent(self, states): + for state in states: + mapper = _state_mapper(state) + + # prevent against last minute dereferences of the object + obj = state.obj() + if obj is not None: + + instance_key = mapper._identity_key_from_state(state) + + if _none_set.issubset(instance_key[1]) and \ + not mapper.allow_partial_pks or \ + _none_set.issuperset(instance_key[1]): + raise exc.FlushError( + "Instance %s has a NULL identity key. If this is an " + "auto-generated value, check that the database table " + "allows generation of new primary key values, and that " + "the mapped Column object is configured to expect these " + "generated values. Ensure also that this flush() is " + "not occurring at an inappropriate time, such as within " + "a load() event." % mapperutil.state_str(state) + ) - if state.key is None: - state.key = instance_key - elif state.key != instance_key: - # primary key switch. use discard() in case another - # state has already replaced this one in the identity - # map (see test/orm/test_naturalpks.py ReversePKsTest) - self.identity_map.discard(state) - state.key = instance_key + if state.key is None: + state.key = instance_key + elif state.key != instance_key: + # primary key switch. use discard() in case another + # state has already replaced this one in the identity + # map (see test/orm/test_naturalpks.py ReversePKsTest) + self.identity_map.discard(state) + state.key = instance_key - self.identity_map.replace(state) - state.commit_all(state.dict, self.identity_map) + self.identity_map.replace(state) + state.commit_all(state.dict, self.identity_map) + self._register_altered(states) # remove from new last, might be the last strong ref - if state in self._new: - if self._enable_transaction_accounting and self.transaction: - self.transaction._new[state] = True + for state in set(states).intersection(self._new): self._new.pop(state) - def _remove_newly_deleted(self, state): + def _register_altered(self, states): if self._enable_transaction_accounting and self.transaction: - self.transaction._deleted[state] = True + for state in states: + if state in self._new: + self.transaction._new[state] = True + else: + self.transaction._dirty[state] = True - self.identity_map.discard(state) - self._deleted.pop(state, None) - state.deleted = True + def _remove_newly_deleted(self, states): + for state in states: + if self._enable_transaction_accounting and self.transaction: + self.transaction._deleted[state] = True + + self.identity_map.discard(state) + self._deleted.pop(state, None) + state.deleted = True def add(self, instance): """Place an object in the ``Session``. diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 3523e7d06e..bc3be8b413 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -339,13 +339,14 @@ class UOWTransaction(object): execute() method has succeeded and the transaction has been committed. """ - for state, (isdelete, listonly) in self.states.iteritems(): - if isdelete: - self.session._remove_newly_deleted(state) - else: - # if listonly: - # debug... would like to see how many do this - self.session._register_newly_persistent(state) + states = set(self.states) + isdel = set( + s for (s, (isdelete, listonly)) in self.states.iteritems() + if isdelete + ) + other = states.difference(isdel) + self.session._remove_newly_deleted(isdel) + self.session._register_newly_persistent(other) class IterateMappersMixin(object): def _mappers(self, uow): diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index d4c66336c0..3706d8e82d 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -377,8 +377,8 @@ class ZooMarkTest(fixtures.TestBase): def test_profile_2_insert(self): self.test_baseline_2_insert() - @profiling.function_call_count(3340, {'2.7':3333, - '2.7+cextension':3317, '2.6':3333}) + @profiling.function_call_count(3340, {'2.7':3109, + '2.7+cextension':3109, '2.6':3109}) def test_profile_3_properties(self): self.test_baseline_3_properties() diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index 9973c31c27..70002ee288 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -237,7 +237,7 @@ class TestORMInspection(_fixtures.FixtureTest): insp = inspect(u1) eq_( insp.identity_key, - (User, (11, )) + (User, (u1.id, )) ) def test_persistence_states(self): diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 8029cd2b21..56b2b79da1 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -111,7 +111,6 @@ class SessionTransactionTest(FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - users.delete().execute() s1 = create_session(bind=testing.db, autocommit=False) s2 = create_session(bind=testing.db, autocommit=False) @@ -454,6 +453,56 @@ class FixtureDataTest(_LocalFixture): assert u1.name == 'will' +class CleanSavepointTest(FixtureTest): + """test the behavior for [ticket:2452] - rollback on begin_nested() + only expires objects tracked as being modified in that transaction. + + """ + run_inserts = None + + def _run_test(self, update_fn): + User, users = self.classes.User, self.tables.users + + mapper(User, users) + + s = Session(bind=testing.db) + u1 = User(name='u1') + u2 = User(name='u2') + s.add_all([u1, u2]) + s.commit() + u1.name + u2.name + s.begin_nested() + update_fn(s, u2) + eq_(u2.name, 'u2modified') + s.rollback() + eq_(u1.__dict__['name'], 'u1') + assert 'name' not in u2.__dict__ + eq_(u2.name, 'u2') + + @testing.requires.savepoints + def test_rollback_ignores_clean_on_savepoint(self): + User, users = self.classes.User, self.tables.users + def update_fn(s, u2): + u2.name = 'u2modified' + self._run_test(update_fn) + + @testing.requires.savepoints + def test_rollback_ignores_clean_on_savepoint_agg_upd_eval(self): + User, users = self.classes.User, self.tables.users + def update_fn(s, u2): + s.query(User).filter_by(name='u2').update(dict(name='u2modified'), + synchronize_session='evaluate') + self._run_test(update_fn) + + @testing.requires.savepoints + def test_rollback_ignores_clean_on_savepoint_agg_upd_fetch(self): + User, users = self.classes.User, self.tables.users + def update_fn(s, u2): + s.query(User).filter_by(name='u2').update(dict(name='u2modified'), + synchronize_session='fetch') + self._run_test(update_fn) + class AutoExpireTest(_LocalFixture): def test_expunge_pending_on_rollback(self): -- 2.47.3