From: Mike Bayer Date: Fri, 5 Feb 2021 18:38:28 +0000 (-0500) Subject: Track a second from_linter for lateral subqueries X-Git-Tag: rel_1_4_0b3~22^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bc9221bf781adfffdddf12860d4eed7650457a0a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Track a second from_linter for lateral subqueries Fixed bug where the "cartesian product" assertion was not correctly accommodating for joins between tables that relied upon the use of LATERAL to connect from a subquery to another subquery in the enclosing context. Additionally, enabled from_linting for the base assert_compile(), however it remains off by default; to enable by default we would have to make sure it isn't set for DDL compiles and there's also a lot of tests that would also need to turn it off, so leaving this off for expediency. Fixes: #5924 Change-Id: I22604baf572f8c4d96befcc610b3dcb79c13fc4a --- diff --git a/doc/build/changelog/unreleased_14/5924.rst b/doc/build/changelog/unreleased_14/5924.rst new file mode 100644 index 0000000000..f0ec874c6d --- /dev/null +++ b/doc/build/changelog/unreleased_14/5924.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, sql + :tickets: 5924 + + Fixed bug where the "cartesian product" assertion was not correctly + accommodating for joins between tables that relied upon the use of LATERAL + to connect from a subquery to another subquery in the enclosing context. \ No newline at end of file diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 696b38e649..f22e8614bc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1996,14 +1996,24 @@ class SQLCompiler(Compiled): override_operator=None, eager_grouping=False, from_linter=None, + lateral_from_linter=None, **kw ): if from_linter and operators.is_comparison(binary.operator): - from_linter.edges.update( - itertools.product( - binary.left._from_objects, binary.right._from_objects + if lateral_from_linter is not None: + enclosing_lateral = kw["enclosing_lateral"] + lateral_from_linter.edges.update( + itertools.product( + binary.left._from_objects + [enclosing_lateral], + binary.right._from_objects + [enclosing_lateral], + ) + ) + else: + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) ) - ) # don't allow "? = ?" to render if ( @@ -2027,7 +2037,11 @@ class SQLCompiler(Compiled): ) else: return self._generate_generic_binary( - binary, opstring, from_linter=from_linter, **kw + binary, + opstring, + from_linter=from_linter, + lateral_from_linter=lateral_from_linter, + **kw ) def visit_function_as_comparison_op_binary(self, element, operator, **kw): @@ -2570,6 +2584,24 @@ class SQLCompiler(Compiled): from_linter=None, **kwargs ): + + if lateral: + if "enclosing_lateral" not in kwargs: + # if lateral is set and enclosing_lateral is not + # present, we assume we are being called directly + # from visit_lateral() and we need to set enclosing_lateral. + assert alias._is_lateral + kwargs["enclosing_lateral"] = alias + + # for lateral objects, we track a second from_linter that is... + # lateral! to the level above us. + if ( + from_linter + and "lateral_from_linter" not in kwargs + and "enclosing_lateral" in kwargs + ): + kwargs["lateral_from_linter"] = from_linter + if enclosing_alias is not None and enclosing_alias.element is alias: inner = alias.element._compiler_dispatch( self, diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 79125e1f16..44d7405b12 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -394,6 +394,7 @@ class AssertsCompiledSQL(object): schema_translate_map=None, render_schema_translate=False, default_schema_name=None, + from_linting=False, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -438,6 +439,9 @@ class AssertsCompiledSQL(object): if render_schema_translate: kw["render_schema_translate"] = True + if from_linting or getattr(self, "assert_from_linting", False): + kw["linting"] = sql.FROM_LINTING + from sqlalchemy import orm if isinstance(clause, orm.dynamic.AppenderQuery): diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index b0bcee18e2..9e0ededecc 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -1,6 +1,7 @@ from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import sql +from sqlalchemy import testing from sqlalchemy import true from sqlalchemy.testing import config from sqlalchemy.testing import engines @@ -58,6 +59,108 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): assert start == self.b assert froms == {self.a} + @testing.combinations(("lateral",), ("cartesian",), ("join",)) + def test_lateral_subqueries(self, control): + """ + :: + + test=> create table a (id integer); + CREATE TABLE + test=> create table b (id integer); + CREATE TABLE + test=> insert into a(id) values (1), (2), (3); + INSERT 0 3 + test=> insert into b(id) values (1), (2), (3); + INSERT 0 3 + + test=> select * from (select id from a) as a1, + lateral (select id from b where id=a1.id) as b1; + id | id + ----+---- + 1 | 1 + 2 | 2 + 3 | 3 + (3 rows) + + """ + p1 = select(self.a).subquery() + + p2 = select(self.b).where(self.b.c.col_b == p1.c.col_a).subquery() + + if control == "lateral": + p2 = p2.lateral() + + query = select(p1, p2) + + if control == "join": + query = query.join_from(p1, p2, p1.c.col_a == p2.c.col_b) + + froms, start = find_unmatching_froms(query, p1) + + if control == "cartesian": + assert start is p1 + assert froms == {p2} + else: + assert start is None + assert froms is None + + froms, start = find_unmatching_froms(query, p2) + + if control == "cartesian": + assert start is p2 + assert froms == {p1} + else: + assert start is None + assert froms is None + + def test_lateral_subqueries_w_joins(self): + p1 = select(self.a).subquery() + p2 = ( + select(self.b) + .where(self.b.c.col_b == p1.c.col_a) + .subquery() + .lateral() + ) + p3 = ( + select(self.c) + .where(self.c.c.col_c == p1.c.col_a) + .subquery() + .lateral() + ) + + query = select(p1, p2, p3).join_from(p1, p2, true()).join(p3, true()) + + for p in (p1, p2, p3): + froms, start = find_unmatching_froms(query, p) + assert start is None + assert froms is None + + def test_lateral_subqueries_ok_do_we_still_find_cartesians(self): + p1 = select(self.a).subquery() + + p3 = select(self.a).subquery() + + p2 = select(self.b).where(self.b.c.col_b == p3.c.col_a).subquery() + + p2 = p2.lateral() + + query = select(p1, p2, p3) + + froms, start = find_unmatching_froms(query, p1) + + assert start is p1 + assert froms == {p2, p3} + + froms, start = find_unmatching_froms(query, p2) + + assert start is p2 + assert froms == {p1} + + froms, start = find_unmatching_froms(query, p3) + + assert start is p3 + assert froms == {p1} + def test_count_non_eq_comparison_operators(self): query = select(self.a).where(self.a.c.col_a > self.b.c.col_b) froms, start = find_unmatching_froms(query, self.a) diff --git a/test/sql/test_lateral.py b/test/sql/test_lateral.py index a80ad7083e..6723be8504 100644 --- a/test/sql/test_lateral.py +++ b/test/sql/test_lateral.py @@ -7,6 +7,7 @@ from sqlalchemy import lateral from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import table +from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true from sqlalchemy.engine import default @@ -21,6 +22,8 @@ from sqlalchemy.testing import fixtures class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = default.DefaultDialect(supports_native_boolean=True) + assert_from_linting = True + run_setup_bind = None run_create_tables = None @@ -234,6 +237,65 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL): "WHERE people.people_id = books.book_owner_id) AS anon_1 ON true", ) + @testing.combinations((True,), (False,)) + def test_join_lateral_subquery_twolevel(self, use_twolevel): + people, books, bookcases = self.tables("people", "books", "bookcases") + + p1 = select( + books.c.book_id, books.c.bookcase_id, books.c.book_owner_id + ).subquery() + p2 = ( + select(bookcases.c.bookcase_id, bookcases.c.bookcase_owner_id) + .where(bookcases.c.bookcase_id == p1.c.bookcase_id) + .subquery() + .lateral() + ) + p3 = ( + select(people.c.people_id) + .where(p1.c.book_owner_id == people.c.people_id) + .subquery() + .lateral() + ) + + onelevel = ( + select(p1.c.book_id, p2.c.bookcase_id) + .select_from(p1) + .join(p2, true()) + ) + + if use_twolevel: + twolevel = onelevel.add_columns(p3.c.people_id).join(p3, true()) + + self.assert_compile( + twolevel, + "SELECT anon_1.book_id, anon_2.bookcase_id, anon_3.people_id " + "FROM (SELECT books.book_id AS book_id, books.bookcase_id AS " + "bookcase_id, books.book_owner_id AS book_owner_id " + "FROM books) " + "AS anon_1 JOIN LATERAL (SELECT bookcases.bookcase_id AS " + "bookcase_id, " + "bookcases.bookcase_owner_id AS bookcase_owner_id " + "FROM bookcases " + "WHERE bookcases.bookcase_id = anon_1.bookcase_id) " + "AS anon_2 ON true JOIN LATERAL " + "(SELECT people.people_id AS people_id FROM people " + "WHERE anon_1.book_owner_id = people.people_id) AS anon_3 " + "ON true", + ) + else: + self.assert_compile( + onelevel, + "SELECT anon_1.book_id, anon_2.bookcase_id FROM " + "(SELECT books.book_id AS book_id, books.bookcase_id " + "AS bookcase_id, books.book_owner_id AS book_owner_id " + "FROM books) AS anon_1 JOIN LATERAL " + "(SELECT bookcases.bookcase_id AS bookcase_id, " + "bookcases.bookcase_owner_id AS bookcase_owner_id " + "FROM bookcases " + "WHERE bookcases.bookcase_id = anon_1.bookcase_id) AS anon_2 " + "ON true", + ) + def test_join_lateral_w_select_implicit_subquery(self): table1 = self.tables.people table2 = self.tables.books