From: Mike Bayer Date: Fri, 21 Dec 2012 22:53:57 +0000 (-0500) Subject: - significantly rework the approach to collection events and history within DynamicAt... X-Git-Tag: rel_0_8_0~38^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3522785ef493a6cad6403b4c702bbfe2f1b7dc89;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - significantly rework the approach to collection events and history within DynamicAttributeImpl - Fixes to the "dynamic" loader on :func:`.relationship`, includes that backrefs will work properly even when autoflush is disabled, history events are more accurate in scenarios where multiple add/remove of the same object occurs, as can often be the case in conjunction with the association proxy. [ticket:2637] --- diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst index 7a4ab6fc7f..4eb7fc7122 100644 --- a/doc/build/changelog/changelog_08.rst +++ b/doc/build/changelog/changelog_08.rst @@ -7,6 +7,16 @@ :version: 0.8.0b2 :released: December 14, 2012 + .. change:: + :tags: orm, bug + :tickets: 2637 + + Fixes to the "dynamic" loader on :func:`.relationship`, includes + that backrefs will work properly even when autoflush is disabled, + history events are more accurate in scenarios where multiple add/remove + of the same object occurs, as can often be the case in conjunction + with the association proxy. + .. change:: :tags: sqlite, bug :tickets: 2568 diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 80206011cb..aef94f27c6 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -795,9 +795,10 @@ def bulk_replace(values, existing_adapter, new_adapter): values = list(values) idset = util.IdentitySet - constants = idset(existing_adapter or ()).intersection(values or ()) + existing_idset = idset(existing_adapter or ()) + constants = existing_idset.intersection(values or ()) additions = idset(values or ()).difference(constants) - removals = idset(existing_adapter or ()).difference(constants) + removals = existing_idset.difference(constants) for member in values or (): if member in additions: diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 64353cfafa..28bddd6130 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -15,7 +15,7 @@ from .. import log, util, exc from ..sql import operators from . import ( attributes, object_session, util as orm_util, strategies, - object_mapper, exc as orm_exc, collections + object_mapper, exc as orm_exc ) from .query import Query @@ -31,10 +31,12 @@ class DynaLoader(strategies.AbstractRelationshipLoader): strategies._register_attribute(self, mapper, useobject=True, + uselist=True, impl_class=DynamicAttributeImpl, target_mapper=self.parent_property.mapper, order_by=self.parent_property.order_by, - query_class=self.parent_property.query_class + query_class=self.parent_property.query_class, + backref=self.parent_property.back_populates, ) log.class_logger(DynaLoader) @@ -74,11 +76,14 @@ class DynamicAttributeImpl(attributes.AttributeImpl): passive).added_items else: history = self._get_collection_history(state, passive) - return history.added_items + history.unchanged_items + return history.added_plus_unchanged - def fire_append_event(self, state, dict_, value, initiator): - collection_history = self._modified_event(state, dict_) - collection_history.added_items.append(value) + def fire_append_event(self, state, dict_, value, initiator, + collection_history=None): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_added(value) for fn in self.dispatch.append: value = fn(state, value, initiator or self) @@ -86,9 +91,12 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.trackparent and value is not None: self.sethasparent(attributes.instance_state(value), state, True) - def fire_remove_event(self, state, dict_, value, initiator): - collection_history = self._modified_event(state, dict_) - collection_history.deleted_items.append(value) + def fire_remove_event(self, state, dict_, value, initiator, + collection_history=None): + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_removed(value) if self.trackparent and value is not None: self.sethasparent(attributes.instance_state(value), state, False) @@ -121,16 +129,30 @@ class DynamicAttributeImpl(attributes.AttributeImpl): self._set_iterable(state, dict_, value) def _set_iterable(self, state, dict_, iterable, adapter=None): - collection_history = self._modified_event(state, dict_) new_values = list(iterable) if state.has_identity: - old_collection = list(self.get(state, dict_)) + old_collection = util.IdentitySet(self.get(state, dict_)) + + collection_history = self._modified_event(state, dict_) + if not state.has_identity: + old_collection = collection_history.added_items else: - old_collection = [] - collections.bulk_replace(new_values, DynCollectionAdapter(self, - state, old_collection), - DynCollectionAdapter(self, state, - new_values)) + old_collection = old_collection.union( + collection_history.added_items) + + idset = util.IdentitySet + constants = old_collection.intersection(new_values) + additions = idset(new_values).difference(constants) + removals = old_collection.difference(constants) + + for member in new_values: + if member in additions: + self.fire_append_event(state, dict_, member, None, + collection_history=collection_history) + + for member in removals: + self.fire_remove_event(state, dict_, member, None, + collection_history=collection_history) def delete(self, *args, **kwargs): raise NotImplementedError() @@ -141,8 +163,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF): c = self._get_collection_history(state, passive) - return attributes.History(c.added_items, c.unchanged_items, - c.deleted_items) + return c.as_history() def get_all_pending(self, state, dict_): c = self._get_collection_history( @@ -150,7 +171,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): return [ (attributes.instance_state(x), x) for x in - c.added_items + c.unchanged_items + c.deleted_items + c.all_items ] def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF): @@ -159,9 +180,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: c = CollectionHistory(self, state) - # TODO: consider using a different flag here, possibly - # one local to dynamic - if passive & attributes.INIT_OK: + if state.has_identity: return CollectionHistory(self, state, apply_to=c) else: return c @@ -177,29 +196,6 @@ class DynamicAttributeImpl(attributes.AttributeImpl): self.fire_remove_event(state, dict_, value, initiator) -class DynCollectionAdapter(object): - """the dynamic analogue to orm.collections.CollectionAdapter""" - - def __init__(self, attr, owner_state, data): - self.attr = attr - self.state = owner_state - self.data = data - - def __iter__(self): - return iter(self.data) - - def append_with_event(self, item, initiator=None): - self.attr.append(self.state, self.state.dict, item, initiator) - - def remove_with_event(self, item, initiator=None): - self.attr.remove(self.state, self.state.dict, item, initiator) - - def append_without_event(self, item): - pass - - def remove_without_event(self, item): - pass - class AppenderMixin(object): query_class = None @@ -220,7 +216,7 @@ class AppenderMixin(object): if self.attr.order_by: self._order_by = self.attr.order_by - def __session(self): + def session(self): sess = object_session(self.instance) if sess is not None and self.autoflush and sess.autoflush \ and self.instance in sess: @@ -229,13 +225,10 @@ class AppenderMixin(object): return None else: return sess - - def session(self): - return self.__session() session = property(session, lambda s, x: None) def __iter__(self): - sess = self.__session() + sess = self.session if sess is None: return iter(self.attr._get_collection_history( attributes.instance_state(self.instance), @@ -244,17 +237,16 @@ class AppenderMixin(object): return iter(self._clone(sess)) def __getitem__(self, index): - sess = self.__session() + sess = self.session if sess is None: return self.attr._get_collection_history( attributes.instance_state(self.instance), - attributes.PASSIVE_NO_INITIALIZE).added_items.\ - __getitem__(index) + attributes.PASSIVE_NO_INITIALIZE).indexed(index) else: return self._clone(sess).__getitem__(index) def count(self): - sess = self.__session() + sess = self.session if sess is None: return len(self.attr._get_collection_history( attributes.instance_state(self.instance), @@ -318,14 +310,44 @@ class CollectionHistory(object): def __init__(self, attr, state, apply_to=None): if apply_to: - deleted = util.IdentitySet(apply_to.deleted_items) - added = apply_to.added_items coll = AppenderQuery(attr, state).autoflush(False) - self.unchanged_items = [o for o in util.IdentitySet(coll) - if o not in deleted] + self.unchanged_items = util.OrderedIdentitySet(coll) self.added_items = apply_to.added_items self.deleted_items = apply_to.deleted_items else: - self.deleted_items = [] - self.added_items = [] - self.unchanged_items = [] + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + + @property + def added_plus_unchanged(self): + return list(self.added_items.union(self.unchanged_items)) + + @property + def all_items(self): + return list(self.added_items.union( + self.unchanged_items).union(self.deleted_items)) + + def as_history(self): + added = self.added_items.difference(self.unchanged_items) + deleted = self.deleted_items.intersection(self.unchanged_items) + unchanged = self.unchanged_items.difference(deleted) + + return attributes.History( + list(added), + list(unchanged), + list(deleted), + ) + + def indexed(self, index): + return list(self.added_items)[index] + + def add_added(self, value): + self.added_items.add(value) + + def add_removed(self, value): + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) + diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index a356a562be..0eef8f5a53 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -1,9 +1,9 @@ -from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_, is_ from sqlalchemy.orm import backref, configure_mappers from sqlalchemy import testing from sqlalchemy import desc, select, func, exc from sqlalchemy.orm import mapper, relationship, create_session, Query, \ - attributes, exc as orm_exc + attributes, exc as orm_exc, Session from sqlalchemy.orm.dynamic import AppenderMixin from sqlalchemy.testing import AssertsCompiledSQL, \ assert_raises_message, assert_raises @@ -591,6 +591,16 @@ class UOWTest(_DynamicFixture, _fixtures.FixtureTest): def test_backref_savead(self): self._backref_test(False, False) + def test_backref_events(self): + User, Address = self._user_address_fixture(addresses_args={ + "backref": "user", + }) + + u1 = User() + a1 = Address() + u1.addresses.append(a1) + is_(a1.user, u1) + def test_no_deref(self): User, Address = self._user_address_fixture(addresses_args={ "backref": "user", @@ -626,4 +636,162 @@ class UOWTest(_DynamicFixture, _fixtures.FixtureTest): eq_(query2(), [Address(email_address='joe@joesdomain.example')]) eq_(query3(), [Address(email_address='joe@joesdomain.example')]) +class HistoryTest(_DynamicFixture, _fixtures.FixtureTest): + run_inserts = None + + def _transient_fixture(self): + User, Address = self._user_address_fixture() + + u1 = User() + a1 = Address() + return u1, a1 + + def _persistent_fixture(self, autoflush=True): + User, Address = self._user_address_fixture() + + u1 = User(name='u1') + a1 = Address(email_address='a1') + s = Session(autoflush=autoflush) + s.add(u1) + s.flush() + return u1, a1, s + + def _assert_history(self, obj, compare): + eq_( + attributes.get_history(obj, 'addresses'), + compare + ) + + eq_( + attributes.get_history(obj, 'addresses', + attributes.LOAD_AGAINST_COMMITTED), + compare + ) + + def test_append_transient(self): + u1, a1 = self._transient_fixture() + u1.addresses.append(a1) + + self._assert_history(u1, + ([a1], [], []) + ) + + def test_append_persistent(self): + u1, a1, s = self._persistent_fixture() + u1.addresses.append(a1) + + self._assert_history(u1, + ([a1], [], []) + ) + + def test_remove_transient(self): + u1, a1 = self._transient_fixture() + u1.addresses.append(a1) + u1.addresses.remove(a1) + + self._assert_history(u1, + ([], [], []) + ) + + def test_remove_persistent(self): + u1, a1, s = self._persistent_fixture() + u1.addresses.append(a1) + s.flush() + s.expire_all() + + u1.addresses.remove(a1) + + self._assert_history(u1, + ([], [], [a1]) + ) + + def test_unchanged_persistent(self): + Address = self.classes.Address + + u1, a1, s = self._persistent_fixture() + a2, a3 = Address(email_address='a2'), Address(email_address='a3') + + u1.addresses.append(a1) + u1.addresses.append(a2) + s.flush() + + u1.addresses.append(a3) + u1.addresses.remove(a2) + + self._assert_history(u1, + ([a3], [a1], [a2]) + ) + + def test_replace_transient(self): + Address = self.classes.Address + + u1, a1 = self._transient_fixture() + a2, a3, a4, a5 = Address(email_address='a2'), \ + Address(email_address='a3'), \ + Address(email_address='a4'), \ + Address(email_address='a5') + + u1.addresses = [a1, a2] + u1.addresses = [a2, a3, a4, a5] + + self._assert_history(u1, + ([a2, a3, a4, a5], [], []) + ) + + def test_replace_persistent_noflush(self): + Address = self.classes.Address + + u1, a1, s = self._persistent_fixture(autoflush=False) + a2, a3, a4, a5 = Address(email_address='a2'), \ + Address(email_address='a3'), \ + Address(email_address='a4'), \ + Address(email_address='a5') + + u1.addresses = [a1, a2] + u1.addresses = [a2, a3, a4, a5] + + self._assert_history(u1, + ([a2, a3, a4, a5], [], []) + ) + + def test_replace_persistent_autoflush(self): + Address = self.classes.Address + + u1, a1, s = self._persistent_fixture(autoflush=True) + a2, a3, a4, a5 = Address(email_address='a2'), \ + Address(email_address='a3'), \ + Address(email_address='a4'), \ + Address(email_address='a5') + + u1.addresses = [a1, a2] + u1.addresses = [a2, a3, a4, a5] + + self._assert_history(u1, + ([a3, a4, a5], [a2], [a1]) + ) + + + def test_persistent_but_readded_noflush(self): + u1, a1, s = self._persistent_fixture(autoflush=False) + u1.addresses.append(a1) + s.flush() + + u1.addresses.append(a1) + + self._assert_history(u1, ([], [a1], [])) + + def test_persistent_but_readded_autoflush(self): + u1, a1, s = self._persistent_fixture(autoflush=True) + u1.addresses.append(a1) + s.flush() + + u1.addresses.append(a1) + + self._assert_history(u1, ([], [a1], [])) + + def test_missing_but_removed_noflush(self): + u1, a1, s = self._persistent_fixture(autoflush=False) + + u1.addresses.remove(a1) + self._assert_history(u1, ([], [], []))