]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Create new event for collection add w/o mutation
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 May 2021 13:26:03 +0000 (09:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 May 2021 13:26:03 +0000 (09:26 -0400)
Fixed issue when using :paramref:`_orm.relationship.cascade_backrefs`
parameter set to ``False``, which per :ref:`change_5150` is set to become
the standard behavior in SQLAlchemy 2.0, where adding the item to a
collection that uniquifies, such as ``set`` or ``dict`` would fail to fire
a cascade event if the object were already associated in that collection
via the backref. This fix represents a fundamental change in the collection
mechanics by introducing a new event state which can fire off for a
collection mutation even if there is no net change on the collection; the
action is now suited using a new event hook
:meth:`_orm.AttributeEvents.append_wo_mutation`.

Fixes: #6471
Change-Id: Ic50413f7e62440dad33ab84838098ea62ff4e815

doc/build/changelog/unreleased_14/6471.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_cascade.py
test/orm/test_collection.py

diff --git a/doc/build/changelog/unreleased_14/6471.rst b/doc/build/changelog/unreleased_14/6471.rst
new file mode 100644 (file)
index 0000000..26cf8db
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6471
+
+    Fixed issue when using :paramref:`_orm.relationship.cascade_backrefs`
+    parameter set to ``False``, which per :ref:`change_5150` is set to become
+    the standard behavior in SQLAlchemy 2.0, where adding the item to a
+    collection that uniquifies, such as ``set`` or ``dict`` would fail to fire
+    a cascade event if the object were already associated in that collection
+    via the backref. This fix represents a fundamental change in the collection
+    mechanics by introducing a new event state which can fire off for a
+    collection mutation even if there is no net change on the collection; the
+    action is now suited using a new event hook
+    :meth:`_orm.AttributeEvents.append_wo_mutation`.
+
+
index 05b12dda27f08ac0c4c30934ded87707f460b810..105a9cfd2daf4c00dbfc395c75f9615ccf1fdb2f 100644 (file)
@@ -1389,6 +1389,12 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return value
 
+    def fire_append_wo_mutation_event(self, state, dict_, value, initiator):
+        for fn in self.dispatch.append_wo_mutation:
+            value = fn(state, value, initiator or self._append_token)
+
+        return value
+
     def fire_pre_remove_event(self, state, dict_, initiator):
         """A special event used for pop() operations.
 
index 63278fb7e0a09a47074e5e983a4159a1569a9319..3874cd5f9a24b93dbea114e9000566e0634f58db 100644 (file)
@@ -708,6 +708,32 @@ class CollectionAdapter(object):
 
     __nonzero__ = __bool__
 
+    def fire_append_wo_mutation_event(self, item, initiator=None):
+        """Notify that a entity is entering the collection but is already
+        present.
+
+
+        Initiator is a token owned by the InstrumentedAttribute that
+        initiated the membership mutation, and should be left as None
+        unless you are passing along an initiator value from a chained
+        operation.
+
+        .. versionadded:: 1.4.15
+
+        """
+        if initiator is not False:
+            if self.invalidated:
+                self._warn_invalidated()
+
+            if self.empty:
+                self._reset_empty()
+
+            return self.attr.fire_append_wo_mutation_event(
+                self.owner_state, self.owner_state.dict, item, initiator
+            )
+        else:
+            return item
+
     def fire_append_event(self, item, initiator=None):
         """Notify that a entity has entered the collection.
 
@@ -1083,6 +1109,18 @@ def _instrument_membership_mutator(method, before, argument, after):
     return wrapper
 
 
+def __set_wo_mutation(collection, item, _sa_initiator=None):
+    """Run set wo mutation events.
+
+    The collection is not mutated.
+
+    """
+    if _sa_initiator is not False:
+        executor = collection._sa_adapter
+        if executor:
+            executor.fire_append_wo_mutation_event(item, _sa_initiator)
+
+
 def __set(collection, item, _sa_initiator=None):
     """Run set events.
 
@@ -1351,7 +1389,11 @@ def _dict_decorators():
                 self.__setitem__(key, default)
                 return default
             else:
-                return self.__getitem__(key)
+                value = self.__getitem__(key)
+                if value is default:
+                    __set_wo_mutation(self, value, None)
+
+                return value
 
         _tidy(setdefault)
         return setdefault
@@ -1363,13 +1405,19 @@ def _dict_decorators():
                     for key in list(__other):
                         if key not in self or self[key] is not __other[key]:
                             self[key] = __other[key]
+                        else:
+                            __set_wo_mutation(self, __other[key], None)
                 else:
                     for key, value in __other:
                         if key not in self or self[key] is not value:
                             self[key] = value
+                        else:
+                            __set_wo_mutation(self, value, None)
             for key in kw:
                 if key not in self or self[key] is not kw[key]:
                     self[key] = kw[key]
+                else:
+                    __set_wo_mutation(self, kw[key], None)
 
         _tidy(update)
         return update
@@ -1410,6 +1458,8 @@ def _set_decorators():
         def add(self, value, _sa_initiator=None):
             if value not in self:
                 value = __set(self, value, _sa_initiator)
+            else:
+                __set_wo_mutation(self, value, _sa_initiator)
             # testlib.pragma exempt:__hash__
             fn(self, value)
 
index 0824ae7dec96200486e43449cfa2c30bdef321c4..926c2dea7b1475c8c4888f838a2990c8b27dea61 100644 (file)
@@ -2295,6 +2295,36 @@ class AttributeEvents(event.Events):
 
         """
 
+    def append_wo_mutation(self, target, value, initiator):
+        """Receive a collection append event where the collection was not
+        actually mutated.
+
+        This event differs from :meth:`_orm.AttributeEvents.append` in that
+        it is fired off for de-duplicating collections such as sets and
+        dictionaries, when the object already exists in the target collection.
+        The event does not have a return value and the identity of the
+        given object cannot be changed.
+
+        The event is used for cascading objects into a :class:`_orm.Session`
+        when the collection has already been mutated via a backref event.
+
+        :param target: the object instance receiving the event.
+          If the listener is registered with ``raw=True``, this will
+          be the :class:`.InstanceState` object.
+        :param value: the value that would be appended if the object did not
+          already exist in the collection.
+        :param initiator: An instance of :class:`.attributes.Event`
+          representing the initiation of the event.  May be modified
+          from its original value by backref handlers in order to control
+          chained event propagation, as well as be inspected for information
+          about the source of the event.
+
+        :return: No return value is defined for this event.
+
+        .. versionadded:: 1.4.15
+
+        """
+
     def bulk_replace(self, target, values, initiator):
         """Receive a collection 'bulk replace' event.
 
index 77bbb47510148b6aa20c6f7109a195bd8dc24574..ae99da059c7d5a61d8eb6662de112e5770718104 100644 (file)
@@ -144,6 +144,7 @@ def track_cascade_events(descriptor, prop):
                     sess.expunge(oldvalue)
         return newvalue
 
+    event.listen(descriptor, "append_wo_mutation", append, raw=True)
     event.listen(descriptor, "append", append, raw=True, retval=True)
     event.listen(descriptor, "remove", remove, raw=True, retval=True)
     event.listen(descriptor, "set", set_, raw=True, retval=True)
index 7a1ac35576004b647990030cc5d3d5c1c0978f37..a7156be4a5726e3afca431f7977b5864e80c3242 100644 (file)
@@ -20,6 +20,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import util as orm_util
 from sqlalchemy.orm.attributes import instance_state
+from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.orm.decl_api import declarative_base
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -4439,3 +4440,88 @@ class ViewonlyFlagWarningTest(fixtures.MappedTest):
             cascade="all, delete, delete-orphan",
             viewonly=True,
         )
+
+
+class CollectionCascadesDespiteBackrefTest(fixtures.TestBase):
+    @testing.fixture
+    def cascade_fixture(self, registry):
+        def go(collection_class):
+            @registry.mapped
+            class A(object):
+                __tablename__ = "a"
+
+                id = Column(Integer, primary_key=True)
+                bs = relationship(
+                    "B", backref="a", collection_class=collection_class
+                )
+
+            @registry.mapped
+            class B(object):
+                __tablename__ = "b_"
+                id = Column(Integer, primary_key=True)
+                a_id = Column(ForeignKey("a.id"))
+                key = Column(String)
+
+            return A, B
+
+        yield go
+
+    @testing.combinations(
+        (set, "add"),
+        (list, "append"),
+        (attribute_mapped_collection("key"), "__setitem__"),
+        (attribute_mapped_collection("key"), "setdefault"),
+        (attribute_mapped_collection("key"), "update_dict"),
+        (attribute_mapped_collection("key"), "update_kw"),
+        argnames="collection_class,methname",
+    )
+    @testing.combinations((True,), (False,), argnames="future")
+    def test_cascades_on_collection(
+        self, cascade_fixture, collection_class, methname, future
+    ):
+        A, B = cascade_fixture(collection_class)
+
+        s = Session(future=future)
+
+        a1 = A()
+        s.add(a1)
+
+        b1 = B(key="b1")
+        b2 = B(key="b2")
+        b3 = B(key="b3")
+
+        b1.a = a1
+        b3.a = a1
+
+        if future:
+            assert b1 not in s
+            assert b3 not in s
+        else:
+            assert b1 in s
+            assert b3 in s
+
+        if methname == "__setitem__":
+            meth = getattr(a1.bs, methname)
+            meth(b1.key, b1)
+            meth(b2.key, b2)
+        elif methname == "setdefault":
+            meth = getattr(a1.bs, methname)
+            meth(b1.key, b1)
+            meth(b2.key, b2)
+        elif methname == "update_dict" and isinstance(a1.bs, dict):
+            a1.bs.update({b1.key: b1, b2.key: b2})
+        elif methname == "update_kw" and isinstance(a1.bs, dict):
+            a1.bs.update(b1=b1, b2=b2)
+        else:
+            meth = getattr(a1.bs, methname)
+            meth(b1)
+            meth(b2)
+
+        assert b1 in s
+        assert b2 in s
+
+        if future:
+            assert b3 not in s  # the event never triggers from reverse
+        else:
+            # old behavior
+            assert b3 in s
index 2a0aafbbcc63ca9106ce53052a1a716c45ea0cb3..6188af7690a5610def12629dab2613c5dae8d03a 100644 (file)
@@ -32,6 +32,7 @@ class Canary(object):
         self.data = set()
         self.added = set()
         self.removed = set()
+        self.appended_wo_mutation = set()
         self.dupe_check = True
 
     @contextlib.contextmanager
@@ -44,6 +45,7 @@ class Canary(object):
 
     def listen(self, attr):
         event.listen(attr, "append", self.append)
+        event.listen(attr, "append_wo_mutation", self.append_wo_mutation)
         event.listen(attr, "remove", self.remove)
         event.listen(attr, "set", self.set)
 
@@ -54,6 +56,11 @@ class Canary(object):
         self.data.add(value)
         return value
 
+    def append_wo_mutation(self, obj, value, initiator):
+        if self.dupe_check:
+            assert value in self.added
+            self.appended_wo_mutation.add(value)
+
     def remove(self, obj, value, initiator):
         if self.dupe_check:
             assert value not in self.removed
@@ -652,6 +659,48 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         self._test_list_bulk(ListIsh)
         self.assert_(getattr(ListIsh, "_sa_instrumented") == id(ListIsh))
 
+    def _test_set_wo_mutation(self, typecallable, creator=None):
+        if creator is None:
+            creator = self.entity_maker
+
+        class Foo(object):
+            pass
+
+        canary = Canary()
+        instrumentation.register_class(Foo)
+        d = attributes.register_attribute(
+            Foo,
+            "attr",
+            uselist=True,
+            typecallable=typecallable,
+            useobject=True,
+        )
+        canary.listen(d)
+
+        obj = Foo()
+
+        e = creator()
+
+        obj.attr.add(e)
+
+        assert e in canary.added
+        assert e not in canary.appended_wo_mutation
+
+        obj.attr.add(e)
+        assert e in canary.added
+        assert e in canary.appended_wo_mutation
+
+        e = creator()
+
+        obj.attr.update({e})
+
+        assert e in canary.added
+        assert e not in canary.appended_wo_mutation
+
+        obj.attr.update({e})
+        assert e in canary.added
+        assert e in canary.appended_wo_mutation
+
     def _test_set(self, typecallable, creator=None):
         if creator is None:
             creator = self.entity_maker
@@ -976,6 +1025,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         self._test_adapter(set)
         self._test_set(set)
         self._test_set_bulk(set)
+        self._test_set_wo_mutation(set)
 
     def test_set_subclass(self):
         class MySet(set):
@@ -1060,6 +1110,67 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         self._test_set_bulk(SetIsh)
         self.assert_(getattr(SetIsh, "_sa_instrumented") == id(SetIsh))
 
+    def _test_dict_wo_mutation(self, typecallable, creator=None):
+        if creator is None:
+            creator = self.dictable_entity
+
+        class Foo(object):
+            pass
+
+        canary = Canary()
+        instrumentation.register_class(Foo)
+        d = attributes.register_attribute(
+            Foo,
+            "attr",
+            uselist=True,
+            typecallable=typecallable,
+            useobject=True,
+        )
+        canary.listen(d)
+
+        obj = Foo()
+
+        e = creator()
+
+        obj.attr[e.a] = e
+        assert e in canary.added
+        assert e not in canary.appended_wo_mutation
+
+        with canary.defer_dupe_check():
+            # __setitem__ sets every time
+            obj.attr[e.a] = e
+            assert e in canary.added
+            assert e not in canary.appended_wo_mutation
+
+        if hasattr(obj.attr, "update"):
+            e = creator()
+            obj.attr.update({e.a: e})
+            assert e in canary.added
+            assert e not in canary.appended_wo_mutation
+
+            obj.attr.update({e.a: e})
+            assert e in canary.added
+            assert e in canary.appended_wo_mutation
+
+            e = creator()
+            obj.attr.update(**{e.a: e})
+            assert e in canary.added
+            assert e not in canary.appended_wo_mutation
+
+            obj.attr.update(**{e.a: e})
+            assert e in canary.added
+            assert e in canary.appended_wo_mutation
+
+        if hasattr(obj.attr, "setdefault"):
+            e = creator()
+            obj.attr.setdefault(e.a, e)
+            assert e in canary.added
+            assert e not in canary.appended_wo_mutation
+
+            obj.attr.setdefault(e.a, e)
+            assert e in canary.added
+            assert e in canary.appended_wo_mutation
+
     def _test_dict(self, typecallable, creator=None):
         if creator is None:
             creator = self.dictable_entity
@@ -1284,6 +1395,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         )
         self._test_dict(MyDict)
         self._test_dict_bulk(MyDict)
+        self._test_dict_wo_mutation(MyDict)
         self.assert_(getattr(MyDict, "_sa_instrumented") == id(MyDict))
 
     def test_dict_subclass2(self):
@@ -1296,6 +1408,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         )
         self._test_dict(MyEasyDict)
         self._test_dict_bulk(MyEasyDict)
+        self._test_dict_wo_mutation(MyEasyDict)
         self.assert_(getattr(MyEasyDict, "_sa_instrumented") == id(MyEasyDict))
 
     def test_dict_subclass3(self, ordered_dict_mro):
@@ -1309,6 +1422,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         )
         self._test_dict(MyOrdered)
         self._test_dict_bulk(MyOrdered)
+        self._test_dict_wo_mutation(MyOrdered)
         self.assert_(getattr(MyOrdered, "_sa_instrumented") == id(MyOrdered))
 
     def test_dict_duck(self):
@@ -1359,6 +1473,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         )
         self._test_dict(DictLike)
         self._test_dict_bulk(DictLike)
+        self._test_dict_wo_mutation(DictLike)
         self.assert_(getattr(DictLike, "_sa_instrumented") == id(DictLike))
 
     def test_dict_emulates(self):
@@ -1411,6 +1526,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
         )
         self._test_dict(DictIsh)
         self._test_dict_bulk(DictIsh)
+        self._test_dict_wo_mutation(DictIsh)
         self.assert_(getattr(DictIsh, "_sa_instrumented") == id(DictIsh))
 
     def _test_object(self, typecallable, creator=None):