]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure discarded collection removed from empty collections
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Aug 2019 00:19:43 +0000 (20:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Aug 2019 18:37:48 +0000 (14:37 -0400)
A bulk replace operation was not attending to the previous
list still present in the "_empty_collections" dictionary
which was added as part of #4519.

Fixes: #4519
Change-Id: I3f99f8647c0fb8140b3dfb03686a5d3b90da633f

lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_attributes.py
test/orm/test_eager_relations.py

index d47740e3dbac9a8e0b3585564543c2f292a2032d..2f54fcd329fadc3e1d6d77b2ecd77f5461e5d325 100644 (file)
@@ -671,6 +671,11 @@ class AttributeImpl(object):
     def _default_value(self, state, dict_):
         """Produce an empty value for an uninitialized scalar attribute."""
 
+        assert self.key not in dict_, (
+            "_default_value should only be invoked for an "
+            "uninitialized or expired attribute"
+        )
+
         value = None
         for fn in self.dispatch.init_scalar:
             ret = fn(state, value, dict_)
@@ -1201,6 +1206,11 @@ class CollectionAttributeImpl(AttributeImpl):
     def _default_value(self, state, dict_):
         """Produce an empty collection for an un-initialized attribute"""
 
+        assert self.key not in dict_, (
+            "_default_value should only be invoked for an "
+            "uninitialized or expired attribute"
+        )
+
         if self.key in state._empty_collections:
             return state._empty_collections[self.key]
 
@@ -1321,8 +1331,18 @@ class CollectionAttributeImpl(AttributeImpl):
             new_values, old_collection, new_collection, initiator=evt
         )
 
-        del old._sa_adapter
-        self.dispatch.dispose_collection(state, old, old_collection)
+        self._dispose_previous_collection(state, old, old_collection, True)
+
+    def _dispose_previous_collection(
+        self, state, collection, adapter, fire_event
+    ):
+        del collection._sa_adapter
+
+        # discarding old collection make sure it is not referenced in empty
+        # collections.
+        state._empty_collections.pop(self.key, None)
+        if fire_event:
+            self.dispatch.dispose_collection(state, collection, adapter)
 
     def _invalidate_collection(self, collection):
         adapter = getattr(collection, "_sa_adapter")
@@ -1360,7 +1380,9 @@ class CollectionAttributeImpl(AttributeImpl):
     ):
         """Retrieve the CollectionAdapter associated with the given state.
 
-        Creates a new CollectionAdapter if one does not exist.
+        if user_data is None, retrieves it from the state using normal
+        "get()" rules, which will fire lazy callables or return the "empty"
+        collection value.
 
         """
         if user_data is None:
@@ -1368,7 +1390,7 @@ class CollectionAttributeImpl(AttributeImpl):
             if user_data is PASSIVE_NO_RESULT:
                 return user_data
 
-        return getattr(user_data, "_sa_adapter")
+        return user_data._sa_adapter
 
 
 def backref_listeners(attribute, key, uselist):
@@ -1905,12 +1927,22 @@ def init_collection(obj, key):
 
 
 def init_state_collection(state, dict_, key):
-    """Initialize a collection attribute and return the collection adapter."""
+    """Initialize a collection attribute and return the collection adapter.
+
+    Discards any existing collection which may be there.
 
+    """
     attr = state.manager[key].impl
+
+    old = dict_.pop(key, None)  # discard old collection
+    if old is not None:
+        old_collection = old._sa_adapter
+        attr._dispose_previous_collection(state, old, old_collection, False)
+
     user_data = attr._default_value(state, dict_)
     adapter = attr.get_collection(state, dict_, user_data)
     adapter._reset_empty()
+
     return adapter
 
 
index 8e8242c6693fc3b81fc44f5efe0f56b44fd1cf7d..fc86076b1fe2f45a8f6061a50f810ac137c75be0 100644 (file)
@@ -1981,6 +1981,9 @@ class JoinedLoader(AbstractRelationshipLoader):
 
     def _create_collection_loader(self, context, key, _instance, populators):
         def load_collection_from_joined_new_row(state, dict_, row):
+            # note this must unconditionally clear out any existing collection.
+            # an existing collection would be present only in the case of
+            # populate_existing().
             collection = attributes.init_state_collection(state, dict_, key)
             result_list = util.UniqueAppender(
                 collection, "append_without_event"
index 4498a6ab022827aa0a84e3327930bf157d9162c9..cd9d63c67811bc0a492869a5da5af65a7e71c6e6 100644 (file)
@@ -14,7 +14,9 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_not_
 from sqlalchemy.testing import is_true
+from sqlalchemy.testing import not_in_
 from sqlalchemy.testing.mock import call
 from sqlalchemy.testing.mock import Mock
 from sqlalchemy.testing.util import all_partial_orderings
@@ -3659,6 +3661,65 @@ class EventPropagateTest(fixtures.TestBase):
             canary[:] = []
 
 
+class CollectionInitTest(fixtures.TestBase):
+    def setUp(self):
+        class A(object):
+            pass
+
+        class B(object):
+            pass
+
+        self.A = A
+        self.B = B
+        instrumentation.register_class(A)
+        instrumentation.register_class(B)
+        attributes.register_attribute(A, "bs", uselist=True, useobject=True)
+
+    def test_bulk_replace_resets_empty(self):
+        A = self.A
+        a1 = A()
+        state = attributes.instance_state(a1)
+
+        existing = a1.bs
+
+        is_(state._empty_collections["bs"], existing)
+        is_not_(existing._sa_adapter, None)
+
+        a1.bs = []  # replaces previous "empty" collection
+        not_in_("bs", state._empty_collections)  # empty is replaced
+        is_(existing._sa_adapter, None)
+
+    def test_assert_false_on_default_value(self):
+        A = self.A
+        a1 = A()
+        state = attributes.instance_state(a1)
+
+        attributes.init_state_collection(state, state.dict, "bs")
+
+        assert_raises(
+            AssertionError, A.bs.impl._default_value, state, state.dict
+        )
+
+    def test_loader_inits_collection_already_exists(self):
+        A, B = self.A, self.B
+        a1 = A()
+        b1, b2 = B(), B()
+        a1.bs = [b1, b2]
+        eq_(a1.__dict__["bs"], [b1, b2])
+
+        old = a1.__dict__["bs"]
+        is_not_(old._sa_adapter, None)
+        state = attributes.instance_state(a1)
+
+        # this occurs during a load with populate_existing
+        adapter = attributes.init_state_collection(state, state.dict, "bs")
+
+        new = a1.__dict__["bs"]
+        eq_(new, [])
+        is_(new._sa_adapter, adapter)
+        is_(old._sa_adapter, None)
+
+
 class TestUnlink(fixtures.TestBase):
     def setUp(self):
         class A(object):
index cd20342c1ace5a8ff2aebf68369e792f378c9166..70ca993b79c2f9046f9699ad4f359b72909cddd7 100644 (file)
@@ -3479,6 +3479,34 @@ class LoadOnExistingTest(_fixtures.FixtureTest):
         self.assert_sql_count(testing.db, go, 1)
         assert "addresses" not in u1.__dict__
 
+    def test_populate_existing_propagate(self):
+        # both SelectInLoader and SubqueryLoader receive the loaded collection
+        # at once and use attributes.set_committed_value().  However
+        # joinedloader receives the collection per-row, so has an initial
+        # step where it invokes init_state_collection().  This has to clear
+        # out an existing collection to function correctly with
+        # populate_existing.
+        User, Address, sess = self._eager_config_fixture()
+        u1 = sess.query(User).get(8)
+        u1.addresses[2].email_address = "foofoo"
+        del u1.addresses[1]
+        u1 = sess.query(User).populate_existing().filter_by(id=8).one()
+        # collection is reverted
+        eq_(len(u1.addresses), 3)
+
+        # attributes on related items reverted
+        eq_(u1.addresses[2].email_address, "ed@lala.com")
+
+    def test_no_crash_on_existing(self):
+        User, Address, sess = self._eager_config_fixture()
+        u1 = User(id=12, name="u", addresses=[])
+        sess.add(u1)
+        sess.commit()
+
+        sess.query(User).filter(User.id == 12).options(
+            joinedload(User.addresses)
+        ).first()
+
     def test_loads_second_level_collection_to_scalar(self):
         User, Address, Dingaling, sess = self._collection_to_scalar_fixture()