]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Track a second from_linter for lateral subqueries
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Feb 2021 18:38:28 +0000 (13:38 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Feb 2021 18:54:37 +0000 (13:54 -0500)
Fixed bug where the "cartesian product" assertion was not correctly
accommodating for joins between tables that relied upon the use of LATERAL
to connect from a subquery to another subquery in the enclosing context.

Additionally, enabled from_linting for the base assert_compile(),
however it remains off by default; to enable by default we would
have to make sure it isn't set for DDL compiles and there's also
a lot of tests that would also need to turn it off, so leaving
this off for expediency.

Fixes: #5924
Change-Id: I22604baf572f8c4d96befcc610b3dcb79c13fc4a

doc/build/changelog/unreleased_14/5924.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/assertions.py
test/sql/test_from_linter.py
test/sql/test_lateral.py

diff --git a/doc/build/changelog/unreleased_14/5924.rst b/doc/build/changelog/unreleased_14/5924.rst
new file mode 100644 (file)
index 0000000..f0ec874
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5924
+
+    Fixed bug where the "cartesian product" assertion was not correctly
+    accommodating for joins between tables that relied upon the use of LATERAL
+    to connect from a subquery to another subquery in the enclosing context.
\ No newline at end of file
index 696b38e649ae3a38862d936b1d48574d3dfbabdd..f22e8614bc6d4dd5259b9213f04b75a909a634a4 100644 (file)
@@ -1996,14 +1996,24 @@ class SQLCompiler(Compiled):
         override_operator=None,
         eager_grouping=False,
         from_linter=None,
+        lateral_from_linter=None,
         **kw
     ):
         if from_linter and operators.is_comparison(binary.operator):
-            from_linter.edges.update(
-                itertools.product(
-                    binary.left._from_objects, binary.right._from_objects
+            if lateral_from_linter is not None:
+                enclosing_lateral = kw["enclosing_lateral"]
+                lateral_from_linter.edges.update(
+                    itertools.product(
+                        binary.left._from_objects + [enclosing_lateral],
+                        binary.right._from_objects + [enclosing_lateral],
+                    )
+                )
+            else:
+                from_linter.edges.update(
+                    itertools.product(
+                        binary.left._from_objects, binary.right._from_objects
+                    )
                 )
-            )
 
         # don't allow "? = ?" to render
         if (
@@ -2027,7 +2037,11 @@ class SQLCompiler(Compiled):
                 )
             else:
                 return self._generate_generic_binary(
-                    binary, opstring, from_linter=from_linter, **kw
+                    binary,
+                    opstring,
+                    from_linter=from_linter,
+                    lateral_from_linter=lateral_from_linter,
+                    **kw
                 )
 
     def visit_function_as_comparison_op_binary(self, element, operator, **kw):
@@ -2570,6 +2584,24 @@ class SQLCompiler(Compiled):
         from_linter=None,
         **kwargs
     ):
+
+        if lateral:
+            if "enclosing_lateral" not in kwargs:
+                # if lateral is set and enclosing_lateral is not
+                # present, we assume we are being called directly
+                # from visit_lateral() and we need to set enclosing_lateral.
+                assert alias._is_lateral
+                kwargs["enclosing_lateral"] = alias
+
+            # for lateral objects, we track a second from_linter that is...
+            # lateral!  to the level above us.
+            if (
+                from_linter
+                and "lateral_from_linter" not in kwargs
+                and "enclosing_lateral" in kwargs
+            ):
+                kwargs["lateral_from_linter"] = from_linter
+
         if enclosing_alias is not None and enclosing_alias.element is alias:
             inner = alias.element._compiler_dispatch(
                 self,
index 79125e1f16af904eaafc0bf36f279e56535aa885..44d7405b12e7372f09437e180e6c91ce90aaeafd 100644 (file)
@@ -394,6 +394,7 @@ class AssertsCompiledSQL(object):
         schema_translate_map=None,
         render_schema_translate=False,
         default_schema_name=None,
+        from_linting=False,
     ):
         if use_default_dialect:
             dialect = default.DefaultDialect()
@@ -438,6 +439,9 @@ class AssertsCompiledSQL(object):
         if render_schema_translate:
             kw["render_schema_translate"] = True
 
+        if from_linting or getattr(self, "assert_from_linting", False):
+            kw["linting"] = sql.FROM_LINTING
+
         from sqlalchemy import orm
 
         if isinstance(clause, orm.dynamic.AppenderQuery):
index b0bcee18e242ed6fdd4b714b2f3cf4a37fc2af9b..9e0ededecc3e5bae67f7438f2a5f1b6e03974076 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import sql
+from sqlalchemy import testing
 from sqlalchemy import true
 from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
@@ -58,6 +59,108 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
         assert start == self.b
         assert froms == {self.a}
 
+    @testing.combinations(("lateral",), ("cartesian",), ("join",))
+    def test_lateral_subqueries(self, control):
+        """
+        ::
+
+            test=> create table a (id integer);
+            CREATE TABLE
+            test=> create table b (id integer);
+            CREATE TABLE
+            test=> insert into a(id) values (1), (2), (3);
+            INSERT 0 3
+            test=> insert into b(id) values (1), (2), (3);
+            INSERT 0 3
+
+            test=> select * from (select id from a) as a1,
+            lateral (select id from b where id=a1.id) as b1;
+            id | id
+            ----+----
+            1 |  1
+            2 |  2
+            3 |  3
+            (3 rows)
+
+        """
+        p1 = select(self.a).subquery()
+
+        p2 = select(self.b).where(self.b.c.col_b == p1.c.col_a).subquery()
+
+        if control == "lateral":
+            p2 = p2.lateral()
+
+        query = select(p1, p2)
+
+        if control == "join":
+            query = query.join_from(p1, p2, p1.c.col_a == p2.c.col_b)
+
+        froms, start = find_unmatching_froms(query, p1)
+
+        if control == "cartesian":
+            assert start is p1
+            assert froms == {p2}
+        else:
+            assert start is None
+            assert froms is None
+
+        froms, start = find_unmatching_froms(query, p2)
+
+        if control == "cartesian":
+            assert start is p2
+            assert froms == {p1}
+        else:
+            assert start is None
+            assert froms is None
+
+    def test_lateral_subqueries_w_joins(self):
+        p1 = select(self.a).subquery()
+        p2 = (
+            select(self.b)
+            .where(self.b.c.col_b == p1.c.col_a)
+            .subquery()
+            .lateral()
+        )
+        p3 = (
+            select(self.c)
+            .where(self.c.c.col_c == p1.c.col_a)
+            .subquery()
+            .lateral()
+        )
+
+        query = select(p1, p2, p3).join_from(p1, p2, true()).join(p3, true())
+
+        for p in (p1, p2, p3):
+            froms, start = find_unmatching_froms(query, p)
+            assert start is None
+            assert froms is None
+
+    def test_lateral_subqueries_ok_do_we_still_find_cartesians(self):
+        p1 = select(self.a).subquery()
+
+        p3 = select(self.a).subquery()
+
+        p2 = select(self.b).where(self.b.c.col_b == p3.c.col_a).subquery()
+
+        p2 = p2.lateral()
+
+        query = select(p1, p2, p3)
+
+        froms, start = find_unmatching_froms(query, p1)
+
+        assert start is p1
+        assert froms == {p2, p3}
+
+        froms, start = find_unmatching_froms(query, p2)
+
+        assert start is p2
+        assert froms == {p1}
+
+        froms, start = find_unmatching_froms(query, p3)
+
+        assert start is p3
+        assert froms == {p1}
+
     def test_count_non_eq_comparison_operators(self):
         query = select(self.a).where(self.a.c.col_a > self.b.c.col_b)
         froms, start = find_unmatching_froms(query, self.a)
index a80ad7083e642d94b49ae13ebd98cf08f0ef8b6e..6723be8504a49b9c6efcc42c1c84bfbaaf632f8d 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import lateral
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import table
+from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
 from sqlalchemy.engine import default
@@ -21,6 +22,8 @@ from sqlalchemy.testing import fixtures
 class LateralTest(fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = default.DefaultDialect(supports_native_boolean=True)
 
+    assert_from_linting = True
+
     run_setup_bind = None
 
     run_create_tables = None
@@ -234,6 +237,65 @@ class LateralTest(fixtures.TablesTest, AssertsCompiledSQL):
             "WHERE people.people_id = books.book_owner_id) AS anon_1 ON true",
         )
 
+    @testing.combinations((True,), (False,))
+    def test_join_lateral_subquery_twolevel(self, use_twolevel):
+        people, books, bookcases = self.tables("people", "books", "bookcases")
+
+        p1 = select(
+            books.c.book_id, books.c.bookcase_id, books.c.book_owner_id
+        ).subquery()
+        p2 = (
+            select(bookcases.c.bookcase_id, bookcases.c.bookcase_owner_id)
+            .where(bookcases.c.bookcase_id == p1.c.bookcase_id)
+            .subquery()
+            .lateral()
+        )
+        p3 = (
+            select(people.c.people_id)
+            .where(p1.c.book_owner_id == people.c.people_id)
+            .subquery()
+            .lateral()
+        )
+
+        onelevel = (
+            select(p1.c.book_id, p2.c.bookcase_id)
+            .select_from(p1)
+            .join(p2, true())
+        )
+
+        if use_twolevel:
+            twolevel = onelevel.add_columns(p3.c.people_id).join(p3, true())
+
+            self.assert_compile(
+                twolevel,
+                "SELECT anon_1.book_id, anon_2.bookcase_id, anon_3.people_id "
+                "FROM (SELECT books.book_id AS book_id, books.bookcase_id AS "
+                "bookcase_id, books.book_owner_id AS book_owner_id "
+                "FROM books) "
+                "AS anon_1 JOIN LATERAL (SELECT bookcases.bookcase_id AS "
+                "bookcase_id, "
+                "bookcases.bookcase_owner_id AS bookcase_owner_id "
+                "FROM bookcases "
+                "WHERE bookcases.bookcase_id = anon_1.bookcase_id) "
+                "AS anon_2 ON true JOIN LATERAL "
+                "(SELECT people.people_id AS people_id FROM people "
+                "WHERE anon_1.book_owner_id = people.people_id) AS anon_3 "
+                "ON true",
+            )
+        else:
+            self.assert_compile(
+                onelevel,
+                "SELECT anon_1.book_id, anon_2.bookcase_id FROM "
+                "(SELECT books.book_id AS book_id, books.bookcase_id "
+                "AS bookcase_id, books.book_owner_id AS book_owner_id "
+                "FROM books) AS anon_1 JOIN LATERAL "
+                "(SELECT bookcases.bookcase_id AS bookcase_id, "
+                "bookcases.bookcase_owner_id AS bookcase_owner_id "
+                "FROM bookcases "
+                "WHERE bookcases.bookcase_id = anon_1.bookcase_id) AS anon_2 "
+                "ON true",
+            )
+
     def test_join_lateral_w_select_implicit_subquery(self):
         table1 = self.tables.people
         table2 = self.tables.books