git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Adjust derivation rules for table vs. subquery against a join
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Mar 2021 23:46:40 +0000 (19:46 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2021 15:02:18 +0000 (11:02 -0400)
Fixed bug where ORM queries using a correlated subquery in conjunction with
:func:`_orm.column_property` would fail to correlate correctly to an
enclosing subquery or to a CTE when :meth:`_sql.Select.correlate_except`
was used in the property to control correlation, in cases where the
subquery contained the same selectables as ones within the correlated
subquery that were intended to not be correlated.

This is achieved by adding a limiting factor to ClauseAdapter: the
selectables we will be adapting "from" can now be passed explicitly
via the new ``adapt_from_selectables`` parameter.  AliasedClass uses
this to limit "from" to the selectables of the mappers it represents.

This did cause one test, in which the alias for a contains_eager()
was missing, to suddenly fail; that test was corrected.  However,
there may be some rare edge cases like that one, where the tighter
criteria cause an existing use case that relies on the more
liberal aliasing to require modifications.

Fixes: #6060
Change-Id: I8342042641886e1a220beafeb94fe45ea7aadb33

doc/build/changelog/unreleased_14/6060.rst [new file with mode: 0644]
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/util.py
test/aaa_profiling/test_memusage.py
test/orm/test_froms.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/6060.rst b/doc/build/changelog/unreleased_14/6060.rst
new file mode 100644 (file)
index 0000000..a133a24
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6060
+
+    Fixed bug where ORM queries using a correlated subquery in conjunction with
+    :func:`_orm.column_property` would fail to correlate correctly to an
+    enclosing subquery or to a CTE when :meth:`_sql.Select.correlate_except`
+    was used in the property to control correlation, in cases where the
+    subquery contained the same selectables as ones within the correlated
+    subquery that were intended to not be correlated.
index 37be077be7dc671997f0e784530d9f4e1faeaf45..8179149112d71107b28d6493a3e4b14e53fdbb8c 100644 (file)
@@ -684,7 +684,14 @@ class AliasedInsp(
             equivalents=mapper._equivalent_columns,
             adapt_on_names=adapt_on_names,
             anonymize_labels=True,
+            # make sure the adapter doesn't try to grab other tables that
+            # are not even the thing we are mapping, such as embedded
+            # selectables in subqueries or CTEs.  See issue #6060
+            adapt_from_selectables=[
+                m.selectable for m in self.with_polymorphic_mappers
+            ],
         )
+
         if inspected.is_aliased_class:
             self._adapter = inspected._adapter.wrap(self._adapter)
 
index 4300d8a298b65c2f9796e5602d3d6251056c602f..4dec30a80cc3167d5f83defe9dc30953b42c0566 100644 (file)
@@ -813,6 +813,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
         exclude_fn=None,
         adapt_on_names=False,
         anonymize_labels=False,
+        adapt_from_selectables=None,
     ):
         self.__traverse_options__ = {
             "stop_on": [selectable],
@@ -823,6 +824,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
         self.exclude_fn = exclude_fn
         self.equivalents = util.column_dict(equivalents or {})
         self.adapt_on_names = adapt_on_names
+        self.adapt_from_selectables = adapt_from_selectables
 
     def _corresponding_column(
         self, col, require_embedded, _seen=util.EMPTY_SET
@@ -850,6 +852,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
         if isinstance(col, FromClause) and not isinstance(
             col, functions.FunctionElement
         ):
+            if self.adapt_from_selectables:
+                for adp in self.adapt_from_selectables:
+                    if adp.is_derived_from(col):
+                        break
+                else:
+                    return None
+
             if self.selectable.is_derived_from(col):
                 return self.selectable
             elif isinstance(col, Alias) and isinstance(
@@ -875,6 +884,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
         if "adapt_column" in col._annotations:
             col = col._annotations["adapt_column"]
 
+        if self.adapt_from_selectables and col not in self.equivalents:
+            for adp in self.adapt_from_selectables:
+                if adp.c.corresponding_column(col, False) is not None:
+                    break
+            else:
+                return None
+
         if self.include_fn and not self.include_fn(col):
             return None
         elif self.exclude_fn and self.exclude_fn(col):
@@ -924,6 +940,7 @@ class ColumnAdapter(ClauseAdapter):
         adapt_on_names=False,
         allow_label_resolve=True,
         anonymize_labels=False,
+        adapt_from_selectables=None,
     ):
         ClauseAdapter.__init__(
             self,
@@ -933,6 +950,7 @@ class ColumnAdapter(ClauseAdapter):
             exclude_fn=exclude_fn,
             adapt_on_names=adapt_on_names,
             anonymize_labels=anonymize_labels,
+            adapt_from_selectables=adapt_from_selectables,
         )
 
         self.columns = util.WeakPopulateDict(self._locate_col)
index dd709965bac904a28a1b6aeffa0c75016b1bfea8..b1dd29a7eeac42094471d2f3c436d917364b9d08 100644 (file)
@@ -1197,7 +1197,7 @@ class CycleTest(_fixtures.FixtureTest):
         # unfortunately there's a lot of cycles with an aliased()
         # for now, however calling upon clause_element does not seem
         # to make it worse which is what this was looking to test
-        @assert_cycles(68)
+        @assert_cycles(69)
         def go():
             a1 = aliased(Foo)
             a1.user_name.__clause_element__()
index 1464cfc28401b270aed8f9595b5579ded523e850..5881c54c2abdf1382030c4539c44de09c4e191ac 100644 (file)
@@ -12,11 +12,13 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy import Text
 from sqlalchemy import text
 from sqlalchemy import true
 from sqlalchemy import union
 from sqlalchemy import util
 from sqlalchemy.engine import default
+from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import clear_mappers
@@ -1113,9 +1115,10 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         q = sess.query(User)
 
         def go():
-            ulist_alias = aliased(User, alias=query.alias("ulist"))
+            ulist = query.alias("ulist")
+            ulist_alias = aliased(User, alias=ulist)
             result = (
-                q.options(contains_eager("addresses"))
+                q.options(contains_eager("addresses", alias=ulist))
                 .select_entity_from(ulist_alias)
                 .all()
             )
@@ -3894,3 +3897,135 @@ class LabelCollideTest(fixtures.MappedTest):
         # all three columns are loaded independently without
         # overlap, no additional SQL to load all attributes
         self.assert_sql_count(testing.db, go, 0)
+
+
+class CorrelateORMTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    @testing.fixture
+    def mapping(self):
+        Base = declarative_base()
+
+        def go(include_property, correlate_style, include_from):
+            class Address(Base):
+                __tablename__ = "addresses"
+
+                id = Column(Integer, primary_key=True)
+                user_id = Column(
+                    Integer, ForeignKey("users.id"), nullable=False
+                )
+                city = Column(Text)
+
+            class User(Base):
+                __tablename__ = "users"
+
+                id = Column(Integer, primary_key=True)
+                name = Column(Text)
+
+            stmt = select(func.count(Address.id)).where(
+                Address.user_id == User.id
+            )
+            if include_from:
+                stmt = stmt.select_from(Address)
+
+            if include_property:
+                if correlate_style == "correlate":
+                    User.total_addresses = column_property(
+                        stmt.correlate(User).scalar_subquery()
+                    )
+                elif correlate_style == "correlate_except":
+                    User.total_addresses = column_property(
+                        stmt.correlate_except(Address).scalar_subquery()
+                    )
+                elif correlate_style is None:
+                    User.total_addresses = column_property(
+                        stmt.scalar_subquery()
+                    )
+                total_addresses = None
+            else:
+
+                def total_addresses(cls):
+                    stmt = select(func.count(Address.id)).where(
+                        Address.user_id == cls.id
+                    )
+
+                    if correlate_style == "correlate":
+                        stmt = stmt.correlate(cls)
+                    elif correlate_style == "correlate_except":
+                        stmt = stmt.correlate_except(Address)
+
+                    stmt = stmt.scalar_subquery()
+
+                    return stmt
+
+            return User, Address, total_addresses
+
+        yield go
+        Base.registry.dispose()
+
+    def _combinations(fn):
+
+        return testing.combinations(
+            (True,), (False,), argnames="include_property"
+        )(
+            testing.combinations(
+                ("correlate",),
+                ("correlate_except",),
+                (None,),
+                argnames="correlate_style",
+            )(
+                testing.combinations(
+                    (True,), (False), argnames="include_from"
+                )(fn)
+            )
+        )
+
+    @_combinations
+    def test_correlate_to_cte_legacy(
+        self, mapping, include_property, correlate_style, include_from
+    ):
+        User, Address, total_addresses = mapping(
+            include_property, correlate_style, include_from
+        )
+        session = fixture_session()
+
+        filtered_users = (
+            session.query(User.id, User.name)
+            .join(Address)
+            .filter(Address.city == "somewhere")
+            .cte("filtered_users")
+        )
+
+        filtered_users_alias = aliased(User, filtered_users)
+
+        paginated_users = (
+            session.query(filtered_users_alias.id, filtered_users_alias.name)
+            .order_by(func.lower(filtered_users_alias.name).asc())
+            .limit(25)
+            .cte("paginated_users")
+        )
+
+        paginated_users_alias = aliased(User, paginated_users)
+
+        if total_addresses:
+            q = session.query(
+                paginated_users_alias, total_addresses(paginated_users_alias)
+            )
+        else:
+            q = session.query(paginated_users_alias)
+        self.assert_compile(
+            q,
+            "WITH filtered_users AS "
+            "(SELECT users.id AS id, users.name AS name "
+            "FROM users JOIN addresses ON users.id = addresses.user_id "
+            "WHERE addresses.city = :city_1), "
+            "paginated_users AS (SELECT filtered_users.id AS id, "
+            "filtered_users.name AS name FROM filtered_users "
+            "ORDER BY lower(filtered_users.name) ASC LIMIT :param_1) "
+            "SELECT "
+            "paginated_users.id AS paginated_users_id, "
+            "paginated_users.name AS paginated_users_name, "
+            "(SELECT count(addresses.id) AS count_1 FROM addresses "
+            "WHERE addresses.user_id = paginated_users.id) AS anon_1 "
+            "FROM paginated_users",
+        )
index b7e58dad9e82a642ca1bd635da6ff82110b49601..21b5b2d27b52e1a84d684707f82a84bf01a14c9a 100644 (file)
@@ -1372,13 +1372,17 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
         column_adapter = sql_util.ColumnAdapter(stmt2)
         is_(column_adapter.columns[expr], stmt2.selected_columns[3])
 
-    def test_correlate_except_on_clone(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_correlate_except_on_clone(self, use_adapt_from):
         # test [ticket:4537]'s issue
 
         t1alias = t1.alias("t1alias")
         j = t1.join(t1alias, t1.c.col1 == t1alias.c.col2)
 
-        vis = sql_util.ClauseAdapter(j)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(j, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(j)
 
         # "control" subquery - uses correlate which has worked w/ adaption
         # for a long time
@@ -1456,6 +1460,65 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "JOIN table2 ON table1.col1 = table2.col1",
         )
 
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_correlate_except_with_mixed_tables(self, use_adapt_from):
+        # test [ticket:6060]'s issue
+
+        stmt = select(
+            t1.c.col1,
+            select(func.count(t2.c.col1))
+            .where(t2.c.col1 == t1.c.col1)
+            .correlate_except(t2)
+            .scalar_subquery(),
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT table1.col1, "
+            "(SELECT count(table2.col1) AS count_1 FROM table2 "
+            "WHERE table2.col1 = table1.col1) AS anon_1 "
+            "FROM table1",
+        )
+
+        subq = (
+            select(t1)
+            .join(t2, t1.c.col1 == t2.c.col1)
+            .where(t2.c.col2 == "x")
+            .subquery()
+        )
+
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(subq, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(subq)
+
+        if use_adapt_from:
+            self.assert_compile(
+                vis.traverse(stmt),
+                "SELECT anon_1.col1, "
+                "(SELECT count(table2.col1) AS count_1 FROM table2 WHERE "
+                "table2.col1 = anon_1.col1) AS anon_2 "
+                "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
+                "table1.col3 AS col3 FROM table1 JOIN table2 ON table1.col1 = "
+                "table2.col1 WHERE table2.col2 = :col2_1) AS anon_1",
+            )
+        else:
+            # here's the buggy version.  table2 gets yanked out of the
+            # correlated subquery also.  AliasedClass now uses
+            # adapt_from_selectables in all cases
+            self.assert_compile(
+                vis.traverse(stmt),
+                "SELECT anon_1.col1, "
+                "(SELECT count(table2.col1) AS count_1 FROM table2, "
+                "(SELECT table1.col1 AS col1, table1.col2 AS col2, "
+                "table1.col3 AS col3 FROM table1 JOIN table2 ON "
+                "table1.col1 = table2.col1 WHERE table2.col2 = :col2_1) AS "
+                "anon_1 WHERE table2.col1 = anon_1.col1) AS anon_2 "
+                "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "
+                "table1.col3 AS col3 FROM table1 JOIN table2 "
+                "ON table1.col1 = table2.col1 "
+                "WHERE table2.col2 = :col2_1) AS anon_1",
+            )
+
     @testing.fails_on_everything_except()
     def test_joins_dont_adapt(self):
         # adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't
@@ -1483,24 +1546,36 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "addresses.user_id",
         )
 
-    def test_table_to_alias_1(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_table_to_alias_1(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
 
-        vis = sql_util.ClauseAdapter(t1alias)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(t1alias)
         ff = vis.traverse(func.count(t1.c.col1).label("foo"))
         assert list(_from_objects(ff)) == [t1alias]
 
-    def test_table_to_alias_2(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_table_to_alias_2(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
-        vis = sql_util.ClauseAdapter(t1alias)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(t1alias)
         self.assert_compile(
             vis.traverse(select(literal_column("*")).select_from(t1)),
             "SELECT * FROM table1 AS t1alias",
         )
 
-    def test_table_to_alias_3(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_table_to_alias_3(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
-        vis = sql_util.ClauseAdapter(t1alias)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(t1alias)
         self.assert_compile(
             vis.traverse(
                 select(literal_column("*")).where(t1.c.col1 == t2.c.col2)
@@ -1509,9 +1584,13 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE t1alias.col1 = table2.col2",
         )
 
-    def test_table_to_alias_4(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_table_to_alias_4(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
-        vis = sql_util.ClauseAdapter(t1alias)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(t1alias)
         self.assert_compile(
             vis.traverse(
                 select(literal_column("*"))
@@ -1522,9 +1601,13 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE t1alias.col1 = table2.col2",
         )
 
-    def test_table_to_alias_5(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_table_to_alias_5(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
-        vis = sql_util.ClauseAdapter(t1alias)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(t1alias)
         self.assert_compile(
             select(t1alias, t2).where(
                 t1alias.c.col1
@@ -1543,9 +1626,13 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
             "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)",
         )
 
-    def test_table_to_alias_6(self):
+    @testing.combinations((True,), (False,), argnames="use_adapt_from")
+    def test_table_to_alias_6(self, use_adapt_from):
         t1alias = t1.alias("t1alias")
-        vis = sql_util.ClauseAdapter(t1alias)
+        if use_adapt_from:
+            vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1])
+        else:
+            vis = sql_util.ClauseAdapter(t1alias)
         self.assert_compile(
             select(t1alias, t2).where(
                 t1alias.c.col1