]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure backrefs accommodate for op_bulk_replace
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 Jan 2018 16:09:47 +0000 (11:09 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 25 Jan 2018 01:55:13 +0000 (20:55 -0500)
Fixed 1.2 regression regarding new bulk_replace event
where a backref would fail to remove an object from the
previous owner when a bulk-assignment assigned the
object to a new owner.

As this revisits the Event tokens associated with
AttributeImpl objects, remove the verbosity of the
"inline lazy init" pattern; the Event token is a simple
slotted object and should have minimal memory overhead.

Change-Id: Id188b4026fc2f3500186548008f4db8cdf7afecc
Fixes: #4171
doc/build/changelog/unreleased_12/4171.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/dynamic.py
test/orm/test_backref_mutations.py

diff --git a/doc/build/changelog/unreleased_12/4171.rst b/doc/build/changelog/unreleased_12/4171.rst
new file mode 100644 (file)
index 0000000..69a31b9
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 4171
+
+    Fixed 1.2 regression regarding new bulk_replace event
+    where a backref would fail to remove an object from the
+    previous owner when a bulk-assignment assigned the
+    object to a new owner.
index c9fa91b184b5003322fe1ce25026cf951a0e2d3d..b175297acef39c99092ea6c4f1f935d3ff6929cc 100644 (file)
@@ -465,7 +465,7 @@ class AttributeImpl(object):
             self.dispatch._active_history = True
 
         self.expire_missing = expire_missing
-        self._modified_token = None
+        self._modified_token = Event(self, OP_MODIFIED)
 
     __slots__ = (
         'class_', 'key', 'callable_', 'dispatch', 'trackparent',
@@ -473,10 +473,6 @@ class AttributeImpl(object):
         '_modified_token', 'accepts_scalar_loader'
     )
 
-    def _init_modified_token(self):
-        self._modified_token = Event(self, OP_MODIFIED)
-        return self._modified_token
-
     def __str__(self):
         return "%s.%s" % (self.class_.__name__, self.key)
 
@@ -666,23 +662,14 @@ class ScalarAttributeImpl(AttributeImpl):
     uses_objects = False
     supports_population = True
     collection = False
+    dynamic = False
 
     __slots__ = '_replace_token', '_append_token', '_remove_token'
 
     def __init__(self, *arg, **kw):
         super(ScalarAttributeImpl, self).__init__(*arg, **kw)
-        self._replace_token = self._append_token = None
-        self._remove_token = None
-
-    def _init_append_token(self):
         self._replace_token = self._append_token = Event(self, OP_REPLACE)
-        return self._replace_token
-
-    _init_append_or_replace_token = _init_append_token
-
-    def _init_remove_token(self):
         self._remove_token = Event(self, OP_REMOVE)
-        return self._remove_token
 
     def delete(self, state, dict_):
 
@@ -725,15 +712,12 @@ class ScalarAttributeImpl(AttributeImpl):
     def fire_replace_event(self, state, dict_, value, previous, initiator):
         for fn in self.dispatch.set:
             value = fn(
-                state, value, previous,
-                initiator or self._replace_token or
-                self._init_append_or_replace_token())
+                state, value, previous, initiator or self._replace_token)
         return value
 
     def fire_remove_event(self, state, dict_, value, initiator):
         for fn in self.dispatch.remove:
-            fn(state, value,
-               initiator or self._remove_token or self._init_remove_token())
+            fn(state, value, initiator or self._remove_token)
 
     @property
     def type(self):
@@ -757,9 +741,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     def delete(self, state, dict_):
         old = self.get(state, dict_)
-        self.fire_remove_event(
-            state, dict_, old,
-            self._remove_token or self._init_remove_token())
+        self.fire_remove_event(state, dict_, old, self._remove_token)
         del dict_[self.key]
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
@@ -836,8 +818,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             self.sethasparent(instance_state(value), state, False)
 
         for fn in self.dispatch.remove:
-            fn(state, value, initiator or
-               self._remove_token or self._init_remove_token())
+            fn(state, value, initiator or self._remove_token)
 
         state._modified_event(dict_, self, value)
 
@@ -849,8 +830,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
         for fn in self.dispatch.set:
             value = fn(
-                state, value, previous, initiator or
-                self._replace_token or self._init_append_or_replace_token())
+                state, value, previous, initiator or self._replace_token)
 
         state._modified_event(dict_, self, previous)
 
@@ -876,10 +856,11 @@ class CollectionAttributeImpl(AttributeImpl):
     uses_objects = True
     supports_population = True
     collection = True
+    dynamic = False
 
     __slots__ = (
         'copy', 'collection_factory', '_append_token', '_remove_token',
-        '_duck_typed_as'
+        '_bulk_replace_token', '_duck_typed_as'
     )
 
     def __init__(self, class_, key, callable_, dispatch,
@@ -898,8 +879,9 @@ class CollectionAttributeImpl(AttributeImpl):
             copy_function = self.__copy
         self.copy = copy_function
         self.collection_factory = typecallable
-        self._append_token = None
-        self._remove_token = None
+        self._append_token = Event(self, OP_APPEND)
+        self._remove_token = Event(self, OP_REMOVE)
+        self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
         self._duck_typed_as = util.duck_type_collection(
             self.collection_factory())
 
@@ -913,14 +895,6 @@ class CollectionAttributeImpl(AttributeImpl):
             def unlink(target, collection, collection_adapter):
                 collection._sa_linker(None)
 
-    def _init_append_token(self):
-        self._append_token = Event(self, OP_APPEND)
-        return self._append_token
-
-    def _init_remove_token(self):
-        self._remove_token = Event(self, OP_REMOVE)
-        return self._remove_token
-
     def __copy(self, item):
         return [y for y in collections.collection_adapter(item)]
 
@@ -966,8 +940,7 @@ class CollectionAttributeImpl(AttributeImpl):
     def fire_append_event(self, state, dict_, value, initiator):
         for fn in self.dispatch.append:
             value = fn(
-                state, value,
-                initiator or self._append_token or self._init_append_token())
+                state, value, initiator or self._append_token)
 
         state._modified_event(dict_, self, NEVER_SET, True)
 
@@ -984,8 +957,7 @@ class CollectionAttributeImpl(AttributeImpl):
             self.sethasparent(instance_state(value), state, False)
 
         for fn in self.dispatch.remove:
-            fn(state, value,
-               initiator or self._remove_token or self._init_remove_token())
+            fn(state, value, initiator or self._remove_token)
 
         state._modified_event(dict_, self, NEVER_SET, True)
 
@@ -1081,7 +1053,7 @@ class CollectionAttributeImpl(AttributeImpl):
                     iterable = iter(iterable)
         new_values = list(iterable)
 
-        evt = Event(self, OP_BULK_REPLACE)
+        evt = self._bulk_replace_token
 
         self.dispatch.bulk_replace(state, new_values, evt)
 
@@ -1156,7 +1128,13 @@ class CollectionAttributeImpl(AttributeImpl):
 def backref_listeners(attribute, key, uselist):
     """Apply listeners to synchronize a two-way relationship."""
 
-    # use easily recognizable names for stack traces
+    # use easily recognizable names for stack traces.
+
+    # in the sections marked "tokens to test for a recursive loop",
+    # this is somewhat brittle and very performance-sensitive logic
+    # that is specific to how we might arrive at each event.  a marker
+    # that can target us directly to arguments being invoked against
+    # the impl might be simpler, but could interfere with other systems.
 
     parent_token = attribute.impl.parent_token
     parent_impl = attribute.impl
@@ -1186,24 +1164,35 @@ def backref_listeners(attribute, key, uselist):
                 instance_dict(oldchild)
             impl = old_state.manager[key].impl
 
-            if initiator.impl is not impl or \
-                    initiator.op is OP_APPEND:
+            # tokens to test for a recursive loop.
+            if not impl.collection and not impl.dynamic:
+                check_recursive_token = impl._replace_token
+            else:
+                check_recursive_token = impl._remove_token
+
+            if initiator is not check_recursive_token:
                 impl.pop(old_state,
                          old_dict,
                          state.obj(),
-                         parent_impl._append_token or
-                            parent_impl._init_append_token(),
+                         parent_impl._append_token,
                          passive=PASSIVE_NO_FETCH)
 
         if child is not None:
             child_state, child_dict = instance_state(child),\
                 instance_dict(child)
             child_impl = child_state.manager[key].impl
+
             if initiator.parent_token is not parent_token and \
                     initiator.parent_token is not child_impl.parent_token:
                 _acceptable_key_err(state, initiator, child_impl)
-            elif initiator.impl is not child_impl or \
-                    initiator.op is OP_REMOVE:
+
+            # tokens to test for a recursive loop.
+            check_append_token = child_impl._append_token
+            check_bulk_replace_token = child_impl._bulk_replace_token \
+                if child_impl.collection else None
+
+            if initiator is not check_append_token and \
+                    initiator is not check_bulk_replace_token:
                 child_impl.append(
                     child_state,
                     child_dict,
@@ -1223,8 +1212,14 @@ def backref_listeners(attribute, key, uselist):
         if initiator.parent_token is not parent_token and \
                 initiator.parent_token is not child_impl.parent_token:
             _acceptable_key_err(state, initiator, child_impl)
-        elif initiator.impl is not child_impl or \
-                initiator.op is OP_REMOVE:
+
+        # tokens to test for a recursive loop.
+        check_append_token = child_impl._append_token
+        check_bulk_replace_token = child_impl._bulk_replace_token \
+            if child_impl.collection else None
+
+        if initiator is not check_append_token and \
+                initiator is not check_bulk_replace_token:
             child_impl.append(
                 child_state,
                 child_dict,
@@ -1238,8 +1233,18 @@ def backref_listeners(attribute, key, uselist):
             child_state, child_dict = instance_state(child),\
                 instance_dict(child)
             child_impl = child_state.manager[key].impl
-            if initiator.impl is not child_impl or \
-                    initiator.op is OP_APPEND:
+
+            # tokens to test for a recursive loop.
+            if not child_impl.collection and not child_impl.dynamic:
+                check_remove_token = child_impl._remove_token
+                check_replace_token = child_impl._replace_token
+            else:
+                check_remove_token = child_impl._remove_token
+                check_replace_token = child_impl._bulk_replace_token \
+                    if child_impl.collection else None
+
+            if initiator is not check_remove_token and \
+                    initiator is not check_replace_token:
                 child_impl.pop(
                     child_state,
                     child_dict,
@@ -1648,8 +1653,7 @@ def flag_modified(instance, key):
     """
     state, dict_ = instance_state(instance), instance_dict(instance)
     impl = state.manager[key].impl
-    impl.dispatch.modified(
-        state, impl._modified_token or impl._init_modified_token())
+    impl.dispatch.modified(state, impl._modified_token)
     state._modified_event(dict_, impl, NO_VALUE, is_userland=True)
 
 
index ffb4405aaafca92355b3c61a6b4a4d4b8bf71f7a..73d9ef3bb7fd5eec288a2e5e8011365fe061c2d2 100644 (file)
@@ -47,6 +47,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
     default_accepts_scalar_loader = False
     supports_population = False
     collection = False
+    dynamic = True
 
     def __init__(self, class_, key, typecallable,
                  dispatch,
index 7d7ba736c4d4f4f92fd1ce06a85d4d9ca545f350..f32365fd6134f190d7b983d2f7086641f0c7ce9c 100644 (file)
@@ -17,7 +17,7 @@ from sqlalchemy.orm import mapper, relationship, create_session, \
     class_mapper, backref, sessionmaker, Session
 from sqlalchemy.orm import attributes, exc as orm_exc
 from sqlalchemy import testing
-from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_, is_
 from sqlalchemy.testing import fixtures
 from test.orm import _fixtures
 
@@ -235,6 +235,63 @@ class O2MCollectionTest(_fixtures.FixtureTest):
         assert a1 not in u1.addresses
         assert a1 in u2.addresses
 
+    def test_collection_assignment_mutates_previous_one(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        u1 = User(name='jack')
+        u2 = User(name='ed')
+        a1 = Address(email_address='a1')
+        u1.addresses.append(a1)
+
+        is_(a1.user, u1)
+
+        u2.addresses = [a1]
+
+        eq_(u1.addresses, [])
+
+        is_(a1.user, u2)
+
+    def test_collection_assignment_mutates_previous_two(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        u1 = User(name='jack')
+        a1 = Address(email_address='a1')
+
+        u1.addresses.append(a1)
+
+        is_(a1.user, u1)
+
+        u1.addresses = []
+        is_(a1.user, None)
+
+    def test_del_from_collection(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        u1 = User(name='jack')
+        a1 = Address(email_address='a1')
+
+        u1.addresses.append(a1)
+
+        is_(a1.user, u1)
+
+        del u1.addresses[0]
+
+        is_(a1.user, None)
+
+    def test_del_from_scalar(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        u1 = User(name='jack')
+        a1 = Address(email_address='a1')
+
+        u1.addresses.append(a1)
+
+        is_(a1.user, u1)
+
+        del a1.user
+
+        assert a1 not in u1.addresses
+
 
 class O2OScalarBackrefMoveTest(_fixtures.FixtureTest):
     run_inserts = None
@@ -592,6 +649,20 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest):
         session.commit()
         eq_(k1.items, [i1])
 
+    def test_bulk_replace(self):
+        Item, Keyword = (self.classes.Item, self.classes.Keyword)
+
+        k1 = Keyword(name='k1')
+        k2 = Keyword(name='k2')
+        k3 = Keyword(name='k3')
+        i1 = Item(description='i1', keywords=[k1, k2])
+        i2 = Item(description='i2', keywords=[k3])
+
+        i1.keywords = [k2, k3]
+        assert i1 in k3.items
+        assert i2 in k3.items
+        assert i1 not in k1.items
+
 
 class M2MScalarMoveTest(_fixtures.FixtureTest):
     run_inserts = None