From 242867ec87c4d739011ee3cea9a53f33d9f05f2b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 16 Mar 2021 19:46:40 -0400 Subject: [PATCH] Adjust derivation rules for table vs. subquery against a join Fixed bug where ORM queries using a correlated subquery in conjunction with :func:`_orm.column_property` would fail to correlate correctly to an enclosing subquery or to a CTE when :meth:`_sql.Select.correlate_except` was used in the property to control correlation, in cases where the subquery contained the same selectables as ones within the correlated subquery that were intended to not be correlated. This is achieved by adding a limiting factor to ClauseAdapter which is to explicitly pass the selectables we will be adapting "from", which is then used by AliasedClass to limit "from" to the mappers represented by the AliasedClass. This did cause one test where an alias for a contains_eager() was missing to suddenly fail, and the test was corrected, however there may be some very edge cases like that one where the tighter criteria cause an existing use case that's relying on the more liberal aliasing to require modifications. Fixes: #6060 Change-Id: I8342042641886e1a220beafeb94fe45ea7aadb33 --- doc/build/changelog/unreleased_14/6060.rst | 10 ++ lib/sqlalchemy/orm/util.py | 7 ++ lib/sqlalchemy/sql/util.py | 18 +++ test/aaa_profiling/test_memusage.py | 2 +- test/orm/test_froms.py | 139 ++++++++++++++++++++- test/sql/test_external_traversal.py | 115 ++++++++++++++--- 6 files changed, 274 insertions(+), 17 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6060.rst diff --git a/doc/build/changelog/unreleased_14/6060.rst b/doc/build/changelog/unreleased_14/6060.rst new file mode 100644 index 0000000000..a133a24931 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6060.rst @@ -0,0 +1,10 @@ +.. 
change:: + :tags: bug, orm + :tickets: 6060 + + Fixed bug where ORM queries using a correlated subquery in conjunction with + :func:`_orm.column_property` would fail to correlate correctly to an + enclosing subquery or to a CTE when :meth:`_sql.Select.correlate_except` + was used in the property to control correlation, in cases where the + subquery contained the same selectables as ones within the correlated + subquery that were intended to not be correlated. diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 37be077be7..8179149112 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -684,7 +684,14 @@ class AliasedInsp( equivalents=mapper._equivalent_columns, adapt_on_names=adapt_on_names, anonymize_labels=True, + # make sure the adapter doesn't try to grab other tables that + # are not even the thing we are mapping, such as embedded + # selectables in subqueries or CTEs. See issue #6060 + adapt_from_selectables=[ + m.selectable for m in self.with_polymorphic_mappers + ], ) + if inspected.is_aliased_class: self._adapter = inspected._adapter.wrap(self._adapter) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 4300d8a298..4dec30a80c 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -813,6 +813,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): exclude_fn=None, adapt_on_names=False, anonymize_labels=False, + adapt_from_selectables=None, ): self.__traverse_options__ = { "stop_on": [selectable], @@ -823,6 +824,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): self.exclude_fn = exclude_fn self.equivalents = util.column_dict(equivalents or {}) self.adapt_on_names = adapt_on_names + self.adapt_from_selectables = adapt_from_selectables def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET @@ -850,6 +852,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): if isinstance(col, FromClause) and not isinstance( col, 
functions.FunctionElement ): + if self.adapt_from_selectables: + for adp in self.adapt_from_selectables: + if adp.is_derived_from(col): + break + else: + return None + if self.selectable.is_derived_from(col): return self.selectable elif isinstance(col, Alias) and isinstance( @@ -875,6 +884,13 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): if "adapt_column" in col._annotations: col = col._annotations["adapt_column"] + if self.adapt_from_selectables and col not in self.equivalents: + for adp in self.adapt_from_selectables: + if adp.c.corresponding_column(col, False) is not None: + break + else: + return None + if self.include_fn and not self.include_fn(col): return None elif self.exclude_fn and self.exclude_fn(col): @@ -924,6 +940,7 @@ class ColumnAdapter(ClauseAdapter): adapt_on_names=False, allow_label_resolve=True, anonymize_labels=False, + adapt_from_selectables=None, ): ClauseAdapter.__init__( self, @@ -933,6 +950,7 @@ class ColumnAdapter(ClauseAdapter): exclude_fn=exclude_fn, adapt_on_names=adapt_on_names, anonymize_labels=anonymize_labels, + adapt_from_selectables=adapt_from_selectables, ) self.columns = util.WeakPopulateDict(self._locate_col) diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index dd709965ba..b1dd29a7ee 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1197,7 +1197,7 @@ class CycleTest(_fixtures.FixtureTest): # unfortunately there's a lot of cycles with an aliased() # for now, however calling upon clause_element does not seem # to make it worse which is what this was looking to test - @assert_cycles(68) + @assert_cycles(69) def go(): a1 = aliased(Foo) a1.user_name.__clause_element__() diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 1464cfc284..5881c54c2a 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -12,11 +12,13 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from 
sqlalchemy import testing +from sqlalchemy import Text from sqlalchemy import text from sqlalchemy import true from sqlalchemy import union from sqlalchemy import util from sqlalchemy.engine import default +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import aliased from sqlalchemy.orm import backref from sqlalchemy.orm import clear_mappers @@ -1113,9 +1115,10 @@ class InstancesTest(QueryTest, AssertsCompiledSQL): q = sess.query(User) def go(): - ulist_alias = aliased(User, alias=query.alias("ulist")) + ulist = query.alias("ulist") + ulist_alias = aliased(User, alias=ulist) result = ( - q.options(contains_eager("addresses")) + q.options(contains_eager("addresses", alias=ulist)) .select_entity_from(ulist_alias) .all() ) @@ -3894,3 +3897,135 @@ class LabelCollideTest(fixtures.MappedTest): # all three columns are loaded independently without # overlap, no additional SQL to load all attributes self.assert_sql_count(testing.db, go, 0) + + +class CorrelateORMTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def mapping(self): + Base = declarative_base() + + def go(include_property, correlate_style, include_from): + class Address(Base): + __tablename__ = "addresses" + + id = Column(Integer, primary_key=True) + user_id = Column( + Integer, ForeignKey("users.id"), nullable=False + ) + city = Column(Text) + + class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + name = Column(Text) + + stmt = select(func.count(Address.id)).where( + Address.user_id == User.id + ) + if include_from: + stmt = stmt.select_from(Address) + + if include_property: + if correlate_style == "correlate": + User.total_addresses = column_property( + stmt.correlate(User).scalar_subquery() + ) + elif correlate_style == "correlate_except": + User.total_addresses = column_property( + stmt.correlate_except(Address).scalar_subquery() + ) + elif correlate_style is None: + User.total_addresses = 
column_property( + stmt.scalar_subquery() + ) + total_addresses = None + else: + + def total_addresses(cls): + stmt = select(func.count(Address.id)).where( + Address.user_id == cls.id + ) + + if correlate_style == "correlate": + stmt = stmt.correlate(cls) + elif correlate_style == "correlate_except": + stmt = stmt.correlate_except(Address) + + stmt = stmt.scalar_subquery() + + return stmt + + return User, Address, total_addresses + + yield go + Base.registry.dispose() + + def _combinations(fn): + + return testing.combinations( + (True,), (False,), argnames="include_property" + )( + testing.combinations( + ("correlate",), + ("correlate_except",), + (None,), + argnames="correlate_style", + )( + testing.combinations( + (True,), (False), argnames="include_from" + )(fn) + ) + ) + + @_combinations + def test_correlate_to_cte_legacy( + self, mapping, include_property, correlate_style, include_from + ): + User, Address, total_addresses = mapping( + include_property, correlate_style, include_from + ) + session = fixture_session() + + filtered_users = ( + session.query(User.id, User.name) + .join(Address) + .filter(Address.city == "somewhere") + .cte("filtered_users") + ) + + filtered_users_alias = aliased(User, filtered_users) + + paginated_users = ( + session.query(filtered_users_alias.id, filtered_users_alias.name) + .order_by(func.lower(filtered_users_alias.name).asc()) + .limit(25) + .cte("paginated_users") + ) + + paginated_users_alias = aliased(User, paginated_users) + + if total_addresses: + q = session.query( + paginated_users_alias, total_addresses(paginated_users_alias) + ) + else: + q = session.query(paginated_users_alias) + self.assert_compile( + q, + "WITH filtered_users AS " + "(SELECT users.id AS id, users.name AS name " + "FROM users JOIN addresses ON users.id = addresses.user_id " + "WHERE addresses.city = :city_1), " + "paginated_users AS (SELECT filtered_users.id AS id, " + "filtered_users.name AS name FROM filtered_users " + "ORDER BY 
lower(filtered_users.name) ASC LIMIT :param_1) " + "SELECT " + "paginated_users.id AS paginated_users_id, " + "paginated_users.name AS paginated_users_name, " + "(SELECT count(addresses.id) AS count_1 FROM addresses " + "WHERE addresses.user_id = paginated_users.id) AS anon_1 " + "FROM paginated_users", + ) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index b7e58dad9e..21b5b2d27b 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -1372,13 +1372,17 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): column_adapter = sql_util.ColumnAdapter(stmt2) is_(column_adapter.columns[expr], stmt2.selected_columns[3]) - def test_correlate_except_on_clone(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_correlate_except_on_clone(self, use_adapt_from): # test [ticket:4537]'s issue t1alias = t1.alias("t1alias") j = t1.join(t1alias, t1.c.col1 == t1alias.c.col2) - vis = sql_util.ClauseAdapter(j) + if use_adapt_from: + vis = sql_util.ClauseAdapter(j, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(j) # "control" subquery - uses correlate which has worked w/ adaption # for a long time @@ -1456,6 +1460,65 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "JOIN table2 ON table1.col1 = table2.col1", ) + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_correlate_except_with_mixed_tables(self, use_adapt_from): + # test [ticket:6060]'s issue + + stmt = select( + t1.c.col1, + select(func.count(t2.c.col1)) + .where(t2.c.col1 == t1.c.col1) + .correlate_except(t2) + .scalar_subquery(), + ) + self.assert_compile( + stmt, + "SELECT table1.col1, " + "(SELECT count(table2.col1) AS count_1 FROM table2 " + "WHERE table2.col1 = table1.col1) AS anon_1 " + "FROM table1", + ) + + subq = ( + select(t1) + .join(t2, t1.c.col1 == t2.c.col1) + .where(t2.c.col2 == "x") + .subquery() + ) + + if use_adapt_from: 
+ vis = sql_util.ClauseAdapter(subq, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(subq) + + if use_adapt_from: + self.assert_compile( + vis.traverse(stmt), + "SELECT anon_1.col1, " + "(SELECT count(table2.col1) AS count_1 FROM table2 WHERE " + "table2.col1 = anon_1.col1) AS anon_2 " + "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1 JOIN table2 ON table1.col1 = " + "table2.col1 WHERE table2.col2 = :col2_1) AS anon_1", + ) + else: + # here's the buggy version. table2 gets yanked out of the + # correlated subquery also. AliasedClass now uses + # adapt_from_selectables in all cases + self.assert_compile( + vis.traverse(stmt), + "SELECT anon_1.col1, " + "(SELECT count(table2.col1) AS count_1 FROM table2, " + "(SELECT table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1 JOIN table2 ON " + "table1.col1 = table2.col1 WHERE table2.col2 = :col2_1) AS " + "anon_1 WHERE table2.col1 = anon_1.col1) AS anon_2 " + "FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, " + "table1.col3 AS col3 FROM table1 JOIN table2 " + "ON table1.col1 = table2.col1 " + "WHERE table2.col2 = :col2_1) AS anon_1", + ) + @testing.fails_on_everything_except() def test_joins_dont_adapt(self): # adapting to a join, i.e. 
ClauseAdapter(t1.join(t2)), doesn't @@ -1483,24 +1546,36 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "addresses.user_id", ) - def test_table_to_alias_1(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_table_to_alias_1(self, use_adapt_from): t1alias = t1.alias("t1alias") - vis = sql_util.ClauseAdapter(t1alias) + if use_adapt_from: + vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(t1alias) ff = vis.traverse(func.count(t1.c.col1).label("foo")) assert list(_from_objects(ff)) == [t1alias] - def test_table_to_alias_2(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_table_to_alias_2(self, use_adapt_from): t1alias = t1.alias("t1alias") - vis = sql_util.ClauseAdapter(t1alias) + if use_adapt_from: + vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse(select(literal_column("*")).select_from(t1)), "SELECT * FROM table1 AS t1alias", ) - def test_table_to_alias_3(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_table_to_alias_3(self, use_adapt_from): t1alias = t1.alias("t1alias") - vis = sql_util.ClauseAdapter(t1alias) + if use_adapt_from: + vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse( select(literal_column("*")).where(t1.c.col1 == t2.c.col2) @@ -1509,9 +1584,13 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE t1alias.col1 = table2.col2", ) - def test_table_to_alias_4(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_table_to_alias_4(self, use_adapt_from): t1alias = t1.alias("t1alias") - vis = sql_util.ClauseAdapter(t1alias) + if use_adapt_from: + vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1]) + else: + 
vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( vis.traverse( select(literal_column("*")) @@ -1522,9 +1601,13 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE t1alias.col1 = table2.col2", ) - def test_table_to_alias_5(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_table_to_alias_5(self, use_adapt_from): t1alias = t1.alias("t1alias") - vis = sql_util.ClauseAdapter(t1alias) + if use_adapt_from: + vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( select(t1alias, t2).where( t1alias.c.col1 @@ -1543,9 +1626,13 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): "(SELECT * FROM table2 WHERE t1alias.col1 = table2.col2)", ) - def test_table_to_alias_6(self): + @testing.combinations((True,), (False,), argnames="use_adapt_from") + def test_table_to_alias_6(self, use_adapt_from): t1alias = t1.alias("t1alias") - vis = sql_util.ClauseAdapter(t1alias) + if use_adapt_from: + vis = sql_util.ClauseAdapter(t1alias, adapt_from_selectables=[t1]) + else: + vis = sql_util.ClauseAdapter(t1alias) self.assert_compile( select(t1alias, t2).where( t1alias.c.col1 -- 2.47.2