]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure propagate_to_loaders honored at the sub-loader level
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 May 2021 01:21:35 +0000 (21:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 May 2021 02:11:53 +0000 (22:11 -0400)
Fixed additional regression caused by "eager loaders run on unexpire"
feature :ticket:`1763` where the feature would run for a
``contains_eager()`` eagerload option in the case that the
``contains_eager()`` were chained to an additional eager loader option,
which would then produce an incorrect query as the original query-bound
join criteria were no longer present.

The contains_eager() option correctly included
propagate_to_loaders=False however this would not be considered
if the contains_eager() were chained and therefore bundled inside
of an enclosing loader.  We don't want to turn off propagation
completely in that case because we still want the other
loaders inside to be handled individually, so add a check
as each option is moved into the query context.

Fixes: #6449
Change-Id: Icd1d6611095c20ae44ff5d2df734c24770fc8812

doc/build/changelog/unreleased_14/6449.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategy_options.py
test/orm/test_expire.py

diff --git a/doc/build/changelog/unreleased_14/6449.rst b/doc/build/changelog/unreleased_14/6449.rst
new file mode 100644 (file)
index 0000000..56dd781
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6449
+
+    Fixed additional regression caused by "eager loaders run on unexpire"
+    feature :ticket:`1763` where the feature would run for a
+    ``contains_eager()`` eagerload option in the case that the
+    ``contains_eager()`` were chained to an additional eager loader option,
+    which would then produce an incorrect query as the original query-bound
+    join criteria were no longer present.
index 8602f37b60c4ee057664d8846e9b8d397202d7f9..f61cf835d276b46fc4b5b94ccc986dc8e40c6731 100644 (file)
@@ -185,9 +185,12 @@ class Load(Generative, LoaderOption):
         self._process(compile_state, not bool(compile_state.current_path))
 
     def _process(self, compile_state, raiseerr):
+        is_refresh = compile_state.compile_options._for_refresh_state
         current_path = compile_state.current_path
         if current_path:
             for (token, start_path), loader in self.context.items():
+                if is_refresh and not loader.propagate_to_loaders:
+                    continue
                 chopped_start_path = self._chop_path(start_path, current_path)
                 if chopped_start_path is not None:
                     compile_state.attributes[
@@ -705,9 +708,12 @@ class _UnboundLoad(Load):
 
     def _process(self, compile_state, raiseerr):
         dedupes = compile_state.attributes["_unbound_load_dedupes"]
+        is_refresh = compile_state.compile_options._for_refresh_state
         for val in self._to_bind:
             if val not in dedupes:
                 dedupes.add(val)
+                if is_refresh and not val.propagate_to_loaders:
+                    continue
                 val._bind_loader(
                     [
                         ent.entity_zero
index 63ed7a3a71a608073c860bbf75e25bc75c44058c..4ef585f271913886b423296cbd71eb92dcceca8d 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import attributes
+from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm import defer
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import exc as orm_exc
@@ -25,6 +26,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.assertsql import CountStatements
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -695,6 +697,115 @@ class ExpireTest(_fixtures.FixtureTest):
         eq_(a1.email_address, "foo")
         assert a1 in sess.dirty
 
+    @testing.combinations(
+        ("contains,joined",),
+        ("contains,contains",),
+    )
+    def test_unexpire_eager_dont_include_contains_eager(self, case):
+        """test #6449
+
+        testing that contains_eager is downgraded to lazyload during
+        a refresh, including if additional eager loaders are off the
+        contains_eager
+
+        """
+        orders, Order, users, Address, addresses, User = (
+            self.tables.orders,
+            self.classes.Order,
+            self.tables.users,
+            self.classes.Address,
+            self.tables.addresses,
+            self.classes.User,
+        )
+
+        mapper(
+            User,
+            users,
+            properties={"orders": relationship(Order, order_by=orders.c.id)},
+        )
+        mapper(Address, addresses, properties={"user": relationship(User)})
+        mapper(Order, orders)
+
+        sess = fixture_session(autoflush=False)
+
+        with self.sql_execution_asserter(testing.db) as asserter:
+
+            if case == "contains,joined":
+                a1 = (
+                    sess.query(Address)
+                    .join(Address.user)
+                    .options(
+                        contains_eager(Address.user).joinedload(User.orders)
+                    )
+                    .filter(Address.id == 1)
+                    .one()
+                )
+            elif case == "contains,contains":
+                # legacy query.first() can't be used here because it sets
+                # limit 1 without the correct query wrapping.   1.3 has
+                # the same problem though it renders differently
+                a1 = (
+                    sess.query(Address)
+                    .join(Address.user)
+                    .join(User.orders)
+                    .order_by(Order.id)
+                    .options(
+                        contains_eager(Address.user).contains_eager(
+                            User.orders
+                        )
+                    )
+                    .filter(Address.id == 1)
+                    .one()
+                )
+
+            eq_(
+                a1,
+                Address(
+                    id=1,
+                    user=User(
+                        id=7, orders=[Order(id=1), Order(id=3), Order(id=5)]
+                    ),
+                ),
+            )
+
+        # ensure load with either contains_eager().joinedload() or
+        # contains_eager().contains_eager() worked as expected
+        asserter.assert_(CountStatements(1))
+
+        sess.expire(a1)
+
+        # assert behavior on unexpire
+        with self.sql_execution_asserter(testing.db) as asserter:
+            a1.user
+            assert "user" in a1.__dict__
+
+            if case == "contains,joined":
+                # joinedload took place
+                assert "orders" in a1.user.__dict__
+            elif case == "contains,contains":
+                # contains eager is downgraded to a lazy load
+                assert "orders" not in a1.user.__dict__
+
+            eq_(
+                a1,
+                Address(
+                    id=1,
+                    user=User(
+                        id=7, orders=[Order(id=1), Order(id=3), Order(id=5)]
+                    ),
+                ),
+            )
+
+        if case == "contains,joined":
+            # the joinedloader for Address->User works,
+            # so we get refresh(Address).lazyload(Address.user).
+            # joinedload(User.order)
+            asserter.assert_(CountStatements(2))
+        elif case == "contains,contains":
+            # both contains_eagers become normal loads so we get
+            # refresh(Address).lazyload(Address.user).lazyload(User.order]
+            asserter.assert_(CountStatements(3))
+
     def test_relationship_changes_preserved(self):
         users, Address, addresses, User = (
             self.tables.users,