]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
remove should_nest behavior for contains_eager()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Sep 2022 19:17:57 +0000 (15:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Sep 2022 21:17:02 +0000 (17:17 -0400)
Fixed regression for 1.4 in :func:`_orm.contains_eager` where the "wrap in
subquery" logic of :func:`_orm.joinedload` would be inadvertently triggered
for use of the :func:`_orm.contains_eager` function with similar statements
(e.g. those that use ``distinct()``, ``limit()`` or ``offset()``). This is
not appropriate for :func:`_orm.contains_eager` which has always had the
contract that the user-defined SQL statement is unmodified with the
exception of adding the appropriate columns.

Also includes an adjustment to the assertion in Label._make_proxy()
which was there to prevent a fixed label name from being anonymized;
if the label is already anonymous, the change should proceed.
This logic was being hit before the contains_eager behavior was
adjusted. With the adjustment, this code is not used.

Fixes: #8569
Change-Id: I161e65041c0162fd2b83cbef40f57a50fcfaf0fd

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/elements.py
test/orm/test_core_compilation.py
test/orm/test_eager_relations.py
test/orm/test_query.py
test/sql/test_selectable.py

index ff0cdd68079996968fce761d3cb3fd524b9ae2a9..4f24103df25eab2182e3cf491066c9939d12821b 100644 (file)
@@ -529,6 +529,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
     _has_orm_entities = False
     multi_row_eager_loaders = False
+    eager_adding_joins = False
     compound_eager_adapter = None
 
     extra_criteria_entities = _EMPTY_DICT
@@ -794,6 +795,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
     _has_orm_entities = False
     multi_row_eager_loaders = False
+    eager_adding_joins = False
     compound_eager_adapter = None
 
     correlate = None
@@ -1106,7 +1108,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         if self.order_by is False:
             self.order_by = None
 
-        if self.multi_row_eager_loaders and self._should_nest_selectable:
+        if (
+            self.multi_row_eager_loaders
+            and self.eager_adding_joins
+            and self._should_nest_selectable
+        ):
             self.statement = self._compound_eager_statement()
         else:
             self.statement = self._simple_statement()
index 41e598c38e0f8cdc70d0ad6c77a5779ffb3f66ec..19c6493db43c8b1321d41fc7dd4de29c71082052 100644 (file)
@@ -2108,6 +2108,9 @@ class JoinedLoader(AbstractRelationshipLoader):
         )
 
         if user_defined_adapter is not False:
+
+            # setup an adapter but dont create any JOIN, assume it's already
+            # in the query
             (
                 clauses,
                 adapter,
@@ -2119,6 +2122,11 @@ class JoinedLoader(AbstractRelationshipLoader):
                 adapter,
                 user_defined_adapter,
             )
+
+            # don't do "wrap" for multi-row, we want to wrap
+            # limited/distinct SELECT,
+            # because we want to put the JOIN on the outside.
+
         else:
             # if not via query option, check for
             # a cycle
@@ -2129,6 +2137,7 @@ class JoinedLoader(AbstractRelationshipLoader):
                 elif path.contains_mapper(self.mapper):
                     return
 
+            # add the JOIN and create an adapter
             (
                 clauses,
                 adapter,
@@ -2145,6 +2154,10 @@ class JoinedLoader(AbstractRelationshipLoader):
                 chained_from_outerjoin,
             )
 
+            # for multi-row, we want to wrap limited/distinct SELECT,
+            # because we want to put the JOIN on the outside.
+            compile_state.eager_adding_joins = True
+
         with_poly_entity = path.get(
             compile_state.attributes, "path_with_polymorphic", None
         )
index cfbf24f3c4c2cb5f3eb162c6470aeea1c26d76f2..8167dc7e45fde5173bc677135a230aacdc5ad89f 100644 (file)
@@ -4461,7 +4461,7 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
         # when a label name conflicts with other columns and select()
         # is attempting to disambiguate an explicit label, which is not what
         # the user would want.   See issue #6090.
-        if key != self.name:
+        if key != self.name and not isinstance(self.name, _anonymous_label):
             raise exc.InvalidRequestError(
                 "Label name %s is being renamed to an anonymous label due "
                 "to disambiguation "
index 5807b619f35ac3c7ef69fa0ca8cc44e9fb26123e..efa4c773a53dda0c49eba0f4f6ece59b8deae96d 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import delete
 from sqlalchemy import exc
+from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import inspect
@@ -10,6 +11,7 @@ from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import or_
 from sqlalchemy import select
+from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import union
@@ -1068,20 +1070,35 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
         self.mapper_registry.map_imperatively(
             User,
             users,
+            properties={
+                "addresses": relationship(Address, back_populates="user")
+            },
         )
 
         self.mapper_registry.map_imperatively(
             Address,
             addresses,
             properties={
-                "user": relationship(
-                    User,
-                )
+                "user": relationship(User, back_populates="addresses")
             },
         )
 
         return User, Address
 
+    @testing.fixture
+    def hard_labeled_self_ref_fixture(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey("a.id"))
+            data = Column(String)
+            data_lower = column_property(func.lower(data).label("hardcoded"))
+
+            as_ = relationship("A")
+
+        return A
+
     def test_no_joinedload_embedded(self, plain_fixture):
         User, Address = plain_fixture
 
@@ -1190,22 +1207,84 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
             "ON users_1.id = addresses.user_id",
         )
 
-    def test_contains_eager_outermost(self, plain_fixture):
+    def test_joinedload_outermost_w_wrapping_elements(self, plain_fixture):
         User, Address = plain_fixture
 
         stmt = (
-            select(Address)
-            .join(Address.user)
-            .options(contains_eager(Address.user))
+            select(User)
+            .options(joinedload(User.addresses))
+            .limit(10)
+            .distinct()
         )
 
-        # render joined eager loads with stringify
         self.assert_compile(
             stmt,
-            "SELECT users.id, users.name, addresses.id AS id_1, "
-            "addresses.user_id, "
-            "addresses.email_address "
-            "FROM addresses JOIN users ON users.id = addresses.user_id",
+            "SELECT anon_1.id, anon_1.name, addresses_1.id AS id_1, "
+            "addresses_1.user_id, addresses_1.email_address FROM "
+            "(SELECT DISTINCT users.id AS id, users.name AS name FROM users "
+            "LIMIT :param_1) "
+            "AS anon_1 LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON anon_1.id = addresses_1.user_id",
+        )
+
+    def test_contains_eager_outermost_w_wrapping_elements(self, plain_fixture):
+        """test #8569"""
+
+        User, Address = plain_fixture
+
+        stmt = (
+            select(User)
+            .join(User.addresses)
+            .options(contains_eager(User.addresses))
+            .limit(10)
+            .distinct()
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT DISTINCT addresses.id, addresses.user_id, "
+            "addresses.email_address, users.id AS id_1, users.name "
+            "FROM users JOIN addresses ON users.id = addresses.user_id "
+            "LIMIT :param_1",
+        )
+
+    def test_joinedload_hard_labeled_selfref(
+        self, hard_labeled_self_ref_fixture
+    ):
+        """test #8569"""
+
+        A = hard_labeled_self_ref_fixture
+
+        stmt = select(A).options(joinedload(A.as_)).distinct()
+        self.assert_compile(
+            stmt,
+            "SELECT anon_1.hardcoded, anon_1.id, anon_1.a_id, anon_1.data, "
+            "lower(a_1.data) AS lower_1, a_1.id AS id_1, a_1.a_id AS a_id_1, "
+            "a_1.data AS data_1 FROM (SELECT DISTINCT lower(a.data) AS "
+            "hardcoded, a.id AS id, a.a_id AS a_id, a.data AS data FROM a) "
+            "AS anon_1 LEFT OUTER JOIN a AS a_1 ON anon_1.id = a_1.a_id",
+        )
+
+    def test_contains_eager_hard_labeled_selfref(
+        self, hard_labeled_self_ref_fixture
+    ):
+        """test #8569"""
+
+        A = hard_labeled_self_ref_fixture
+
+        a1 = aliased(A)
+        stmt = (
+            select(A)
+            .join(A.as_.of_type(a1))
+            .options(contains_eager(A.as_.of_type(a1)))
+            .distinct()
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT DISTINCT lower(a.data) AS hardcoded, "
+            "lower(a_1.data) AS hardcoded, a_1.id, a_1.a_id, a_1.data, "
+            "a.id AS id_1, a.a_id AS a_id_1, a.data AS data_1 "
+            "FROM a JOIN a AS a_1 ON a.id = a_1.a_id",
         )
 
     def test_column_properties(self, column_property_fixture):
index a5bd4c75a28935e0a9094b3c5140b3b06c306552..ef6d4b684f7bc3100144c807e8bb0592e71d0a2b 100644 (file)
@@ -3152,10 +3152,14 @@ class SelectUniqueTest(_fixtures.FixtureTest):
 
         eq_(result.scalars().all(), self.static.address_user_result)
 
-    def test_unique_error(self):
+    @testing.combinations(joinedload, contains_eager)
+    def test_unique_error(self, opt):
         User = self.classes.User
 
-        stmt = select(User).options(joinedload(User.addresses))
+        stmt = select(User).options(opt(User.addresses))
+        if opt is contains_eager:
+            stmt = stmt.join(User.addresses)
+
         s = fixture_session()
         result = s.execute(stmt)
 
index 2e3d176c95f4f8906aed8bbed0e08a36fc29ff57..559f4ed9d51900abdd95e1202b3a346871679a4e 100644 (file)
@@ -5487,6 +5487,24 @@ class YieldTest(_fixtures.FixtureTest):
             q.all,
         )
 
+    def test_no_contains_eager_opt(self):
+        self._eagerload_mappings()
+
+        User = self.classes.User
+        sess = fixture_session()
+        q = (
+            sess.query(User)
+            .join(User.addresses)
+            .options(contains_eager(User.addresses))
+            .yield_per(1)
+        )
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "Can't use yield_per with eager loaders that require "
+            "uniquing or row buffering",
+            q.all,
+        )
+
     def test_no_subqueryload_opt(self):
         self._eagerload_mappings()
 
index 3ecbfca274a6686f084eae410fc14ffb8607e494..64ff2e421e371a413110205ff05fdd91888e8778 100644 (file)
@@ -404,8 +404,10 @@ class SelectableTest(
 
     @testing.combinations((True,), (False,))
     def test_broken_select_same_named_explicit_cols(self, use_anon):
-        # this is issue #6090.  the query is "wrong" and we dont know how
+        """test for #6090. the query is "wrong" and we dont know how
         # to render this right now.
+
+        """
         stmt = select(
             table1.c.col1,
             table1.c.col2,
@@ -432,6 +434,24 @@ class SelectableTest(
             ):
                 select(stmt.subquery()).compile()
 
+    def test_same_anon_named_explicit_cols(self):
+        """test for #8569.  This adjusts the change in #6090 to not apply
+        to anonymous labels.
+
+        """
+        lc = literal_column("col2").label(None)
+
+        subq1 = select(lc).subquery()
+
+        stmt2 = select(subq1, lc).subquery()
+
+        self.assert_compile(
+            select(stmt2),
+            "SELECT anon_1.col2_1, anon_1.col2_1_1 FROM "
+            "(SELECT anon_2.col2_1 AS col2_1, col2 AS col2_1 FROM "
+            "(SELECT col2 AS col2_1) AS anon_2) AS anon_1",
+        )
+
     def test_correlate_none_arg_error(self):
         stmt = select(table1)
         with expect_raises_message(