]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- significantly rework the approach to collection events and history within DynamicAt...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2012 22:53:57 +0000 (17:53 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2012 22:53:57 +0000 (17:53 -0500)
- Fixes to the "dynamic" loader on :func:`.relationship`, includes
that backrefs will work properly even when autoflush is disabled,
history events are more accurate in scenarios where multiple add/remove
of the same object occurs, as can often be the case in conjunction
with the association proxy.  [ticket:2637]

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dynamic.py
test/orm/test_dynamic.py

index 7a4ab6fc7f147d92298e9b9eae0617d27d776dc3..4eb7fc7122a715631a5483ff443a45a5086eeefe 100644 (file)
@@ -7,6 +7,16 @@
     :version: 0.8.0b2
     :released: December 14, 2012
 
+    .. change::
+        :tags: orm, bug
+        :tickets: 2637
+
+      Fixes to the "dynamic" loader on :func:`.relationship`, includes
+      that backrefs will work properly even when autoflush is disabled,
+      history events are more accurate in scenarios where multiple add/remove
+      of the same object occurs, as can often be the case in conjunction
+      with the association proxy.
+
     .. change::
         :tags: sqlite, bug
         :tickets: 2568
index 80206011cbde13a2f780c734631d97320257e3fe..aef94f27c618eef1c56c21b5938d019b36dc32e8 100644 (file)
@@ -795,9 +795,10 @@ def bulk_replace(values, existing_adapter, new_adapter):
         values = list(values)
 
     idset = util.IdentitySet
-    constants = idset(existing_adapter or ()).intersection(values or ())
+    existing_idset = idset(existing_adapter or ())
+    constants = existing_idset.intersection(values or ())
     additions = idset(values or ()).difference(constants)
-    removals = idset(existing_adapter or ()).difference(constants)
+    removals = existing_idset.difference(constants)
 
     for member in values or ():
         if member in additions:
index 64353cfafa5e0f7d9fc90a2d302290422192e47c..28bddd6130fd40ea5d9941e3636aa60e1fd0427b 100644 (file)
@@ -15,7 +15,7 @@ from .. import log, util, exc
 from ..sql import operators
 from . import (
     attributes, object_session, util as orm_util, strategies,
-    object_mapper, exc as orm_exc, collections
+    object_mapper, exc as orm_exc
     )
 from .query import Query
 
@@ -31,10 +31,12 @@ class DynaLoader(strategies.AbstractRelationshipLoader):
         strategies._register_attribute(self,
             mapper,
             useobject=True,
+            uselist=True,
             impl_class=DynamicAttributeImpl,
             target_mapper=self.parent_property.mapper,
             order_by=self.parent_property.order_by,
-            query_class=self.parent_property.query_class
+            query_class=self.parent_property.query_class,
+            backref=self.parent_property.back_populates,
         )
 
 log.class_logger(DynaLoader)
@@ -74,11 +76,14 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
                     passive).added_items
         else:
             history = self._get_collection_history(state, passive)
-            return history.added_items + history.unchanged_items
+            return history.added_plus_unchanged
 
-    def fire_append_event(self, state, dict_, value, initiator):
-        collection_history = self._modified_event(state, dict_)
-        collection_history.added_items.append(value)
+    def fire_append_event(self, state, dict_, value, initiator,
+                                                    collection_history=None):
+        if collection_history is None:
+            collection_history = self._modified_event(state, dict_)
+
+        collection_history.add_added(value)
 
         for fn in self.dispatch.append:
             value = fn(state, value, initiator or self)
@@ -86,9 +91,12 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         if self.trackparent and value is not None:
             self.sethasparent(attributes.instance_state(value), state, True)
 
-    def fire_remove_event(self, state, dict_, value, initiator):
-        collection_history = self._modified_event(state, dict_)
-        collection_history.deleted_items.append(value)
+    def fire_remove_event(self, state, dict_, value, initiator,
+                                                    collection_history=None):
+        if collection_history is None:
+            collection_history = self._modified_event(state, dict_)
+
+        collection_history.add_removed(value)
 
         if self.trackparent and value is not None:
             self.sethasparent(attributes.instance_state(value), state, False)
@@ -121,16 +129,30 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         self._set_iterable(state, dict_, value)
 
     def _set_iterable(self, state, dict_, iterable, adapter=None):
-        collection_history = self._modified_event(state, dict_)
         new_values = list(iterable)
         if state.has_identity:
-            old_collection = list(self.get(state, dict_))
+            old_collection = util.IdentitySet(self.get(state, dict_))
+
+        collection_history = self._modified_event(state, dict_)
+        if not state.has_identity:
+            old_collection = collection_history.added_items
         else:
-            old_collection = []
-        collections.bulk_replace(new_values, DynCollectionAdapter(self,
-                                 state, old_collection),
-                                 DynCollectionAdapter(self, state,
-                                 new_values))
+            old_collection = old_collection.union(
+                                    collection_history.added_items)
+
+        idset = util.IdentitySet
+        constants = old_collection.intersection(new_values)
+        additions = idset(new_values).difference(constants)
+        removals = old_collection.difference(constants)
+
+        for member in new_values:
+            if member in additions:
+                self.fire_append_event(state, dict_, member, None,
+                                        collection_history=collection_history)
+
+        for member in removals:
+            self.fire_remove_event(state, dict_, member, None,
+                                        collection_history=collection_history)
 
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
@@ -141,8 +163,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
 
     def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
         c = self._get_collection_history(state, passive)
-        return attributes.History(c.added_items, c.unchanged_items,
-                                  c.deleted_items)
+        return c.as_history()
 
     def get_all_pending(self, state, dict_):
         c = self._get_collection_history(
@@ -150,7 +171,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         return [
                 (attributes.instance_state(x), x)
                 for x in
-                c.added_items + c.unchanged_items + c.deleted_items
+                c.all_items
             ]
 
     def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF):
@@ -159,9 +180,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         else:
             c = CollectionHistory(self, state)
 
-        # TODO: consider using a different flag here, possibly
-        # one local to dynamic
-        if passive & attributes.INIT_OK:
+        if state.has_identity:
             return CollectionHistory(self, state, apply_to=c)
         else:
             return c
@@ -177,29 +196,6 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
             self.fire_remove_event(state, dict_, value, initiator)
 
 
-class DynCollectionAdapter(object):
-    """the dynamic analogue to orm.collections.CollectionAdapter"""
-
-    def __init__(self, attr, owner_state, data):
-        self.attr = attr
-        self.state = owner_state
-        self.data = data
-
-    def __iter__(self):
-        return iter(self.data)
-
-    def append_with_event(self, item, initiator=None):
-        self.attr.append(self.state, self.state.dict, item, initiator)
-
-    def remove_with_event(self, item, initiator=None):
-        self.attr.remove(self.state, self.state.dict, item, initiator)
-
-    def append_without_event(self, item):
-        pass
-
-    def remove_without_event(self, item):
-        pass
-
 
 class AppenderMixin(object):
     query_class = None
@@ -220,7 +216,7 @@ class AppenderMixin(object):
         if self.attr.order_by:
             self._order_by = self.attr.order_by
 
-    def __session(self):
+    def session(self):
         sess = object_session(self.instance)
         if sess is not None and self.autoflush and sess.autoflush \
             and self.instance in sess:
@@ -229,13 +225,10 @@ class AppenderMixin(object):
             return None
         else:
             return sess
-
-    def session(self):
-        return self.__session()
     session = property(session, lambda s, x: None)
 
     def __iter__(self):
-        sess = self.__session()
+        sess = self.session
         if sess is None:
             return iter(self.attr._get_collection_history(
                 attributes.instance_state(self.instance),
@@ -244,17 +237,16 @@ class AppenderMixin(object):
             return iter(self._clone(sess))
 
     def __getitem__(self, index):
-        sess = self.__session()
+        sess = self.session
         if sess is None:
             return self.attr._get_collection_history(
                 attributes.instance_state(self.instance),
-                attributes.PASSIVE_NO_INITIALIZE).added_items.\
-                    __getitem__(index)
+                attributes.PASSIVE_NO_INITIALIZE).indexed(index)
         else:
             return self._clone(sess).__getitem__(index)
 
     def count(self):
-        sess = self.__session()
+        sess = self.session
         if sess is None:
             return len(self.attr._get_collection_history(
                 attributes.instance_state(self.instance),
@@ -318,14 +310,44 @@ class CollectionHistory(object):
 
     def __init__(self, attr, state, apply_to=None):
         if apply_to:
-            deleted = util.IdentitySet(apply_to.deleted_items)
-            added = apply_to.added_items
             coll = AppenderQuery(attr, state).autoflush(False)
-            self.unchanged_items = [o for o in util.IdentitySet(coll)
-                                    if o not in deleted]
+            self.unchanged_items = util.OrderedIdentitySet(coll)
             self.added_items = apply_to.added_items
             self.deleted_items = apply_to.deleted_items
         else:
-            self.deleted_items = []
-            self.added_items = []
-            self.unchanged_items = []
+            self.deleted_items = util.OrderedIdentitySet()
+            self.added_items = util.OrderedIdentitySet()
+            self.unchanged_items = util.OrderedIdentitySet()
+
+    @property
+    def added_plus_unchanged(self):
+        return list(self.added_items.union(self.unchanged_items))
+
+    @property
+    def all_items(self):
+        return list(self.added_items.union(
+                        self.unchanged_items).union(self.deleted_items))
+
+    def as_history(self):
+        added = self.added_items.difference(self.unchanged_items)
+        deleted = self.deleted_items.intersection(self.unchanged_items)
+        unchanged = self.unchanged_items.difference(deleted)
+
+        return attributes.History(
+                    list(added),
+                    list(unchanged),
+                    list(deleted),
+                )
+
+    def indexed(self, index):
+        return list(self.added_items)[index]
+
+    def add_added(self, value):
+        self.added_items.add(value)
+
+    def add_removed(self, value):
+        if value in self.added_items:
+            self.added_items.remove(value)
+        else:
+            self.deleted_items.add(value)
+
index a356a562bed38b06a587a576c9edb3355843c202..0eef8f5a53a6dcf82deb3cd2c103cfbad598a7df 100644 (file)
@@ -1,9 +1,9 @@
-from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_, is_
 from sqlalchemy.orm import backref, configure_mappers
 from sqlalchemy import testing
 from sqlalchemy import desc, select, func, exc
 from sqlalchemy.orm import mapper, relationship, create_session, Query, \
-                    attributes, exc as orm_exc
+                    attributes, exc as orm_exc, Session
 from sqlalchemy.orm.dynamic import AppenderMixin
 from sqlalchemy.testing import AssertsCompiledSQL, \
         assert_raises_message, assert_raises
@@ -591,6 +591,16 @@ class UOWTest(_DynamicFixture, _fixtures.FixtureTest):
     def test_backref_savead(self):
         self._backref_test(False, False)
 
+    def test_backref_events(self):
+        User, Address = self._user_address_fixture(addresses_args={
+                            "backref": "user",
+            })
+
+        u1 = User()
+        a1 = Address()
+        u1.addresses.append(a1)
+        is_(a1.user, u1)
+
     def test_no_deref(self):
         User, Address = self._user_address_fixture(addresses_args={
                             "backref": "user",
@@ -626,4 +636,162 @@ class UOWTest(_DynamicFixture, _fixtures.FixtureTest):
         eq_(query2(), [Address(email_address='joe@joesdomain.example')])
         eq_(query3(), [Address(email_address='joe@joesdomain.example')])
 
+class HistoryTest(_DynamicFixture, _fixtures.FixtureTest):
+    run_inserts = None
+
+    def _transient_fixture(self):
+        User, Address = self._user_address_fixture()
+
+        u1 = User()
+        a1 = Address()
+        return u1, a1
+
+    def _persistent_fixture(self, autoflush=True):
+        User, Address = self._user_address_fixture()
+
+        u1 = User(name='u1')
+        a1 = Address(email_address='a1')
+        s = Session(autoflush=autoflush)
+        s.add(u1)
+        s.flush()
+        return u1, a1, s
+
+    def _assert_history(self, obj, compare):
+        eq_(
+            attributes.get_history(obj, 'addresses'),
+            compare
+        )
+
+        eq_(
+            attributes.get_history(obj, 'addresses',
+                        attributes.LOAD_AGAINST_COMMITTED),
+            compare
+        )
+
+    def test_append_transient(self):
+        u1, a1 = self._transient_fixture()
+        u1.addresses.append(a1)
+
+        self._assert_history(u1,
+            ([a1], [], [])
+        )
+
+    def test_append_persistent(self):
+        u1, a1, s = self._persistent_fixture()
+        u1.addresses.append(a1)
+
+        self._assert_history(u1,
+            ([a1], [], [])
+        )
+
+    def test_remove_transient(self):
+        u1, a1 = self._transient_fixture()
+        u1.addresses.append(a1)
+        u1.addresses.remove(a1)
+
+        self._assert_history(u1,
+            ([], [], [])
+        )
+
+    def test_remove_persistent(self):
+        u1, a1, s = self._persistent_fixture()
+        u1.addresses.append(a1)
+        s.flush()
+        s.expire_all()
+
+        u1.addresses.remove(a1)
+
+        self._assert_history(u1,
+            ([], [], [a1])
+        )
+
+    def test_unchanged_persistent(self):
+        Address = self.classes.Address
+
+        u1, a1, s = self._persistent_fixture()
+        a2, a3 = Address(email_address='a2'), Address(email_address='a3')
+
+        u1.addresses.append(a1)
+        u1.addresses.append(a2)
+        s.flush()
+
+        u1.addresses.append(a3)
+        u1.addresses.remove(a2)
+
+        self._assert_history(u1,
+            ([a3], [a1], [a2])
+        )
+
+    def test_replace_transient(self):
+        Address = self.classes.Address
+
+        u1, a1 = self._transient_fixture()
+        a2, a3, a4, a5 = Address(email_address='a2'), \
+                            Address(email_address='a3'), \
+                            Address(email_address='a4'), \
+                            Address(email_address='a5')
+
+        u1.addresses = [a1, a2]
+        u1.addresses = [a2, a3, a4, a5]
+
+        self._assert_history(u1,
+            ([a2, a3, a4, a5], [], [])
+        )
+
+    def test_replace_persistent_noflush(self):
+        Address = self.classes.Address
+
+        u1, a1, s = self._persistent_fixture(autoflush=False)
+        a2, a3, a4, a5 = Address(email_address='a2'), \
+                            Address(email_address='a3'), \
+                            Address(email_address='a4'), \
+                            Address(email_address='a5')
+
+        u1.addresses = [a1, a2]
+        u1.addresses = [a2, a3, a4, a5]
+
+        self._assert_history(u1,
+            ([a2, a3, a4, a5], [], [])
+        )
+
+    def test_replace_persistent_autoflush(self):
+        Address = self.classes.Address
+
+        u1, a1, s = self._persistent_fixture(autoflush=True)
+        a2, a3, a4, a5 = Address(email_address='a2'), \
+                            Address(email_address='a3'), \
+                            Address(email_address='a4'), \
+                            Address(email_address='a5')
+
+        u1.addresses = [a1, a2]
+        u1.addresses = [a2, a3, a4, a5]
+
+        self._assert_history(u1,
+            ([a3, a4, a5], [a2], [a1])
+        )
+
+
+    def test_persistent_but_readded_noflush(self):
+        u1, a1, s = self._persistent_fixture(autoflush=False)
+        u1.addresses.append(a1)
+        s.flush()
+
+        u1.addresses.append(a1)
+
+        self._assert_history(u1, ([], [a1], []))
+
+    def test_persistent_but_readded_autoflush(self):
+        u1, a1, s = self._persistent_fixture(autoflush=True)
+        u1.addresses.append(a1)
+        s.flush()
+
+        u1.addresses.append(a1)
+
+        self._assert_history(u1, ([], [a1], []))
+
+    def test_missing_but_removed_noflush(self):
+        u1, a1, s = self._persistent_fixture(autoflush=False)
+
+        u1.addresses.remove(a1)
 
+        self._assert_history(u1, ([], [], []))