]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
remove should_nest behavior for contains_eager()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Sep 2022 19:17:57 +0000 (15:17 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Sep 2022 21:18:37 +0000 (17:18 -0400)
Fixed regression for 1.4 in :func:`_orm.contains_eager` where the "wrap in
subquery" logic of :func:`_orm.joinedload` would be inadvertently triggered
for use of the :func:`_orm.contains_eager` function with similar statements
(e.g. those that use ``distinct()``, ``limit()`` or ``offset()``). This is
not appropriate for :func:`_orm.contains_eager` which has always had the
contract that the user-defined SQL statement is unmodified with the
exception of adding the appropriate columns.

Also includes an adjustment to the assertion in Label._make_proxy()
which was there to prevent a fixed label name from being anonymized;
if the label is already anonymous, the change should proceed.
This logic was being hit before the contains_eager behavior was
adjusted. With the adjustment, this code is not used.

Fixes: #8569
Change-Id: I161e65041c0162fd2b83cbef40f57a50fcfaf0fd
(cherry picked from commit 57b400f07951f0ae8651ca38338ec5be1d222c7e)

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/elements.py
test/orm/test_core_compilation.py
test/orm/test_eager_relations.py
test/orm/test_query.py
test/sql/test_selectable.py

index d5a742cc57852666a4ab86133ef28156756da6fa..379b65ac7e956f8e097fcb7a38d96ef1c4adeadb 100644 (file)
@@ -397,6 +397,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
     _has_orm_entities = False
     multi_row_eager_loaders = False
+    eager_adding_joins = False
     compound_eager_adapter = None
 
     extra_criteria_entities = _EMPTY_DICT
@@ -592,6 +593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
     _has_orm_entities = False
     multi_row_eager_loaders = False
+    eager_adding_joins = False
     compound_eager_adapter = None
 
     correlate = None
@@ -900,7 +902,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         if self.order_by is False:
             self.order_by = None
 
-        if self.multi_row_eager_loaders and self._should_nest_selectable:
+        if (
+            self.multi_row_eager_loaders
+            and self.eager_adding_joins
+            and self._should_nest_selectable
+        ):
             self.statement = self._compound_eager_statement()
         else:
             self.statement = self._simple_statement()
index 944c114a64060655c7047fdfbecab8736be24ce9..a014b2f411578f5ca6a155604ba1561a6e1bb4ee 100644 (file)
@@ -1965,6 +1965,9 @@ class JoinedLoader(AbstractRelationshipLoader):
         )
 
         if user_defined_adapter is not False:
+
+            # setup an adapter but dont create any JOIN, assume it's already
+            # in the query
             (
                 clauses,
                 adapter,
@@ -1976,6 +1979,11 @@ class JoinedLoader(AbstractRelationshipLoader):
                 adapter,
                 user_defined_adapter,
             )
+
+            # don't do "wrap" for multi-row, we want to wrap
+            # limited/distinct SELECT,
+            # because we want to put the JOIN on the outside.
+
         else:
             # if not via query option, check for
             # a cycle
@@ -1986,6 +1994,7 @@ class JoinedLoader(AbstractRelationshipLoader):
                 elif path.contains_mapper(self.mapper):
                     return
 
+            # add the JOIN and create an adapter
             (
                 clauses,
                 adapter,
@@ -2002,6 +2011,10 @@ class JoinedLoader(AbstractRelationshipLoader):
                 chained_from_outerjoin,
             )
 
+            # for multi-row, we want to wrap limited/distinct SELECT,
+            # because we want to put the JOIN on the outside.
+            compile_state.eager_adding_joins = True
+
         with_poly_entity = path.get(
             compile_state.attributes, "path_with_polymorphic", None
         )
index 268c0d6ac4db8f6d4e73019db73ddf26477c271a..ace43b3a1d4ca7451de8526c6b5cc9669264d05e 100644 (file)
@@ -4636,7 +4636,7 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
         # when a label name conflicts with other columns and select()
         # is attempting to disambiguate an explicit label, which is not what
         # the user would want.   See issue #6090.
-        if key != self.name:
+        if key != self.name and not isinstance(self.name, _anonymous_label):
             raise exc.InvalidRequestError(
                 "Label name %s is being renamed to an anonymous label due "
                 "to disambiguation "
index 1457f873c5f6123f528151b0cfaee2130d795812..c0c530b4c07fd8ec3611d8f648a49927fc5c05a7 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import delete
 from sqlalchemy import exc
+from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import inspect
@@ -10,6 +11,7 @@ from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import or_
 from sqlalchemy import select
+from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import union
@@ -1023,20 +1025,35 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
         self.mapper_registry.map_imperatively(
             User,
             users,
+            properties={
+                "addresses": relationship(Address, back_populates="user")
+            },
         )
 
         self.mapper_registry.map_imperatively(
             Address,
             addresses,
             properties={
-                "user": relationship(
-                    User,
-                )
+                "user": relationship(User, back_populates="addresses")
             },
         )
 
         return User, Address
 
+    @testing.fixture
+    def hard_labeled_self_ref_fixture(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey("a.id"))
+            data = Column(String)
+            data_lower = column_property(func.lower(data).label("hardcoded"))
+
+            as_ = relationship("A")
+
+        return A
+
     def test_no_joinedload_embedded(self, plain_fixture):
         User, Address = plain_fixture
 
@@ -1145,22 +1162,84 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
             "ON users_1.id = addresses.user_id",
         )
 
-    def test_contains_eager_outermost(self, plain_fixture):
+    def test_joinedload_outermost_w_wrapping_elements(self, plain_fixture):
         User, Address = plain_fixture
 
         stmt = (
-            select(Address)
-            .join(Address.user)
-            .options(contains_eager(Address.user))
+            select(User)
+            .options(joinedload(User.addresses))
+            .limit(10)
+            .distinct()
         )
 
-        # render joined eager loads with stringify
         self.assert_compile(
             stmt,
-            "SELECT users.id, users.name, addresses.id AS id_1, "
-            "addresses.user_id, "
-            "addresses.email_address "
-            "FROM addresses JOIN users ON users.id = addresses.user_id",
+            "SELECT anon_1.id, anon_1.name, addresses_1.id AS id_1, "
+            "addresses_1.user_id, addresses_1.email_address FROM "
+            "(SELECT DISTINCT users.id AS id, users.name AS name FROM users "
+            "LIMIT :param_1) "
+            "AS anon_1 LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON anon_1.id = addresses_1.user_id",
+        )
+
+    def test_contains_eager_outermost_w_wrapping_elements(self, plain_fixture):
+        """test #8569"""
+
+        User, Address = plain_fixture
+
+        stmt = (
+            select(User)
+            .join(User.addresses)
+            .options(contains_eager(User.addresses))
+            .limit(10)
+            .distinct()
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT DISTINCT addresses.id, addresses.user_id, "
+            "addresses.email_address, users.id AS id_1, users.name "
+            "FROM users JOIN addresses ON users.id = addresses.user_id "
+            "LIMIT :param_1",
+        )
+
+    def test_joinedload_hard_labeled_selfref(
+        self, hard_labeled_self_ref_fixture
+    ):
+        """test #8569"""
+
+        A = hard_labeled_self_ref_fixture
+
+        stmt = select(A).options(joinedload(A.as_)).distinct()
+        self.assert_compile(
+            stmt,
+            "SELECT anon_1.hardcoded, anon_1.id, anon_1.a_id, anon_1.data, "
+            "lower(a_1.data) AS lower_1, a_1.id AS id_1, a_1.a_id AS a_id_1, "
+            "a_1.data AS data_1 FROM (SELECT DISTINCT lower(a.data) AS "
+            "hardcoded, a.id AS id, a.a_id AS a_id, a.data AS data FROM a) "
+            "AS anon_1 LEFT OUTER JOIN a AS a_1 ON anon_1.id = a_1.a_id",
+        )
+
+    def test_contains_eager_hard_labeled_selfref(
+        self, hard_labeled_self_ref_fixture
+    ):
+        """test #8569"""
+
+        A = hard_labeled_self_ref_fixture
+
+        a1 = aliased(A)
+        stmt = (
+            select(A)
+            .join(A.as_.of_type(a1))
+            .options(contains_eager(A.as_.of_type(a1)))
+            .distinct()
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT DISTINCT lower(a.data) AS hardcoded, "
+            "lower(a_1.data) AS hardcoded, a_1.id, a_1.a_id, a_1.data, "
+            "a.id AS id_1, a.a_id AS a_id_1, a.data AS data_1 "
+            "FROM a JOIN a AS a_1 ON a.id = a_1.a_id",
         )
 
     def test_column_properties(self, column_property_fixture):
index b2a5ed33f3908bec5ce3d15eb59f7385d27e123a..fb7550e0ea3ea316cc4955b421d8d9a97fb673e4 100644 (file)
@@ -3135,10 +3135,14 @@ class SelectUniqueTest(_fixtures.FixtureTest):
 
         eq_(result.scalars().all(), self.static.address_user_result)
 
-    def test_unique_error(self):
+    @testing.combinations(joinedload, contains_eager)
+    def test_unique_error(self, opt):
         User = self.classes.User
 
-        stmt = select(User).options(joinedload(User.addresses))
+        stmt = select(User).options(opt(User.addresses))
+        if opt is contains_eager:
+            stmt = stmt.join(User.addresses)
+
         s = fixture_session()
         result = s.execute(stmt)
 
index ddaa3c60dab963fd5f0bc567365bad50d1229a8b..9779462a2463cd91a4397cef02bd80f1880afb00 100644 (file)
@@ -5463,6 +5463,24 @@ class YieldTest(_fixtures.FixtureTest):
             q.all,
         )
 
+    def test_no_contains_eager_opt(self):
+        self._eagerload_mappings()
+
+        User = self.classes.User
+        sess = fixture_session()
+        q = (
+            sess.query(User)
+            .join(User.addresses)
+            .options(contains_eager(User.addresses))
+            .yield_per(1)
+        )
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "Can't use yield_per with eager loaders that require "
+            "uniquing or row buffering",
+            q.all,
+        )
+
     def test_no_subqueryload_opt(self):
         self._eagerload_mappings()
 
index e0113a7f1017e7de06878ce28d7aeaa13222fe52..a3f7b7c468206b5e179f5306c593de8d08d3fd9f 100644 (file)
@@ -384,8 +384,10 @@ class SelectableTest(
 
     @testing.combinations((True,), (False,))
     def test_broken_select_same_named_explicit_cols(self, use_anon):
-        # this is issue #6090.  the query is "wrong" and we dont know how
+        """test for #6090. the query is "wrong" and we dont know how
         # to render this right now.
+
+        """
         stmt = select(
             table1.c.col1,
             table1.c.col2,
@@ -412,6 +414,24 @@ class SelectableTest(
             ):
                 select(stmt.subquery()).compile()
 
+    def test_same_anon_named_explicit_cols(self):
+        """test for #8569.  This adjusts the change in #6090 to not apply
+        to anonymous labels.
+
+        """
+        lc = literal_column("col2").label(None)
+
+        subq1 = select(lc).subquery()
+
+        stmt2 = select(subq1, lc).subquery()
+
+        self.assert_compile(
+            select(stmt2),
+            "SELECT anon_1.col2_1, anon_1.col2_1_1 FROM "
+            "(SELECT anon_2.col2_1 AS col2_1, col2 AS col2_1 FROM "
+            "(SELECT col2 AS col2_1) AS anon_2) AS anon_1",
+        )
+
     def test_select_label_grouped_still_corresponds(self):
         label = select(table1.c.col1).label("foo")
         label2 = label.self_group()