]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Split warn_for_cartesian in 2 different functions
authorAlessio Bogon <youtux@gmail.com>
Thu, 29 Aug 2019 18:53:27 +0000 (20:53 +0200)
committerAlessio Bogon <youtux@gmail.com>
Thu, 29 Aug 2019 18:53:27 +0000 (20:53 +0200)
This will help with testing *a lot*

lib/sqlalchemy/ext/linter.py
test/ext/test_linter.py

index f265316a5918df170c1b57530dbffad281b59359..549127f8b8fefef393c2b4293e6445bf09c4ab6d 100644 (file)
@@ -13,14 +13,9 @@ def before_execute_hook(conn, clauseelement, multiparams, params):
     else:
         raise NotImplementedError
 
-# from sqlalchemy.ext.compiler import compiles
-# @compiles(Select)
-# def select_warn_for_cartesian_compiler(element, compiler, **kw):
-#     warn_for_cartesian(element)
-#     return compiler.visit_select(element, **kw)
 
-
-def warn_for_cartesian(element):
+def find_unmatching_froms(element, start_with=None):
+    # TODO: It would be nicer to use OrderedSet, but it seems to not be too much optimize, so let's skip for now
     froms = set(element.froms)
     if not froms:
         return
@@ -64,11 +59,16 @@ def warn_for_cartesian(element):
     # take any element from the list of FROMS.
     # then traverse all the edges and ensure we can reach
     # all other FROMS
-    start_with = froms.pop()
+    if start_with is not None:
+        assert start_with in froms
+    else:
+        start_with = next(iter(froms))
+    froms.remove(start_with)
     the_rest = froms
     stack = collections.deque([start_with])
     while stack and the_rest:
         node = stack.popleft()
+        # the_rest.pop(node, None)
         the_rest.discard(node)
         for edge in list(edges):
             if edge not in edges:
@@ -82,11 +82,18 @@ def warn_for_cartesian(element):
 
     # FROMS left over?  boom
     if the_rest:
+        return the_rest, start_with
+    else:
+        return None, None
+
+def warn_for_cartesian(element):
+    froms, start_with = find_unmatching_froms(element)
+    if froms:
         util.warn(
             'for stmt %s FROM elements %s are not joined up to FROM element "%r"'
             % (
                 id(element),  # defeat the warnings filter
-                ", ".join('"%r"' % f for f in the_rest),
+                ", ".join('"%r"' % f for f in froms),
                 start_with,
             )
-        )
\ No newline at end of file
+        )
index fa5c215004b8759058f24bf330b20176f11c726b..da9cce1748646a5a5c87eeb93f209a11257af490 100644 (file)
@@ -1,10 +1,11 @@
 from sqlalchemy import select, Integer, event, testing
 from sqlalchemy.ext import linter
+from sqlalchemy.ext.linter import find_unmatching_froms
 from sqlalchemy.testing import fixtures, expect_warnings
 from sqlalchemy.testing.schema import Table, Column
 
 
-class TestLinter(fixtures.TablesTest):
+class TestFinder(fixtures.TablesTest):
     @classmethod
     def define_tables(cls, metadata):
         Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
@@ -17,10 +18,6 @@ class TestLinter(fixtures.TablesTest):
         self.b = self.tables.table_b
         self.c = self.tables.table_c
         self.d = self.tables.table_d
-        event.listen(testing.db, 'before_execute', linter.before_execute_hook)
-
-    def teardown(self):
-        event.remove(testing.db, 'before_execute', linter.before_execute_hook)
 
     def test_everything_is_connected(self):
         query = (
@@ -32,20 +29,25 @@ class TestLinter(fixtures.TablesTest):
             .where(self.c.c.col_c == self.d.c.col_d)
             .where(self.c.c.col_c == 5)
         )
-        with testing.db.connect() as conn:
-            conn.execute(query)
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+        for start in self.a, self.b, self.c, self.d:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
 
     def test_plain_cartesian(self):
         query = (
             select([self.a])
             .where(self.b.c.col_b == 5)
         )
-        with expect_warnings(
-            r"for stmt .* FROM elements .*table_b.*col_b.* "
-            r"are not joined up to FROM element .*table_a.*col_a.*"
-        ):
-            with testing.db.connect() as conn:
-                conn.execute(query)
+        froms, start = find_unmatching_froms(query, self.a)
+        assert start == self.a
+        assert froms == {self.b}
+
+        froms, start = find_unmatching_froms(query, self.b)
+        assert start == self.b
+        assert froms == {self.a}
 
     def test_disconnect_between_ab_cd(self):
         query = (
@@ -56,13 +58,14 @@ class TestLinter(fixtures.TablesTest):
             .where(self.c.c.col_c == self.d.c.col_d)
             .where(self.c.c.col_c == 5)
         )
-        with expect_warnings(
-            # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
-            r"for stmt .* FROM elements .* "
-            r"are not joined up to FROM element .*"
-        ):
-            with testing.db.connect() as conn:
-                conn.execute(query)
+        for start in self.a, self.b:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.c, self.d}
+        for start in self.c, self.d:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.a, self.b}
 
     def test_c_and_d_both_disconnected(self):
         query = (
@@ -71,13 +74,18 @@ class TestLinter(fixtures.TablesTest):
             .where(self.c.c.col_c == 5)
             .where(self.d.c.col_d == 10)
         )
-        with expect_warnings(
-            # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
-            r"for stmt .* FROM elements .* "
-            r"are not joined up to FROM element .*"
-        ):
-            with testing.db.connect() as conn:
-                conn.execute(query)
+        for start in self.a, self.b:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.c, self.d}
+
+        froms, start = find_unmatching_froms(query, self.c)
+        assert start == self.c
+        assert froms == {self.a, self.b, self.d}
+
+        froms, start = find_unmatching_froms(query, self.d)
+        assert start == self.d
+        assert froms == {self.a, self.b, self.c}
 
     def test_now_connected(self):
         query = (
@@ -88,25 +96,35 @@ class TestLinter(fixtures.TablesTest):
             .where(self.c.c.col_c == 5)
             .where(self.d.c.col_d == 10)
         )
-        with testing.db.connect() as conn:
-            conn.execute(query)
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+        for start in self.a, self.b, self.c, self.d:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
 
     def test_disconnected_subquery(self):
         subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
         stmt = select([self.c]).select_from(subq)
-        with expect_warnings(
-            # TODO: Fix subquery not being displayed properly
-            r"for stmt .* FROM elements .* "
-            r"are not joined up to FROM element .*table_c.*col_c.*"
-        ):
-            with testing.db.connect() as conn:
-                conn.execute(stmt)
+
+        froms, start = find_unmatching_froms(stmt, self.c)
+        assert start == self.c
+        assert froms == {subq}
+
+        froms, start = find_unmatching_froms(stmt, subq)
+        assert start == subq
+        assert froms == {self.c}
 
     def test_now_connect_it(self):
         subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
         stmt = select([self.c]).select_from(subq).where(self.c.c.col_c == subq.c.col_a)
-        with testing.db.connect() as conn:
-            conn.execute(stmt)
+
+        froms, start = find_unmatching_froms(stmt)
+        assert not froms
+
+        for start in self.c, subq:
+            froms, start = find_unmatching_froms(stmt, start)
+            assert not froms
 
     def test_right_nested_join_without_issue(self):
         query = (
@@ -115,19 +133,64 @@ class TestLinter(fixtures.TablesTest):
                 self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b)
             )
         )
-        with testing.db.connect() as conn:
-            conn.execute(query)
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+        for start in self.a, self.b, self.c:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
 
     def test_right_nested_join_with_an_issue(self):
         query = (
             select([self.a])
-            .select_from(self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b))
+            .select_from(
+                self.a.join(
+                    self.b.join(self.c, self.b.c.col_b == self.c.c.col_c),
+                    self.a.c.col_a == self.b.c.col_b,
+                ),
+            )
             .where(self.d.c.col_d == 5)
         )
+
+        for start in self.a, self.b, self.c:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.d}
+
+        froms, start = find_unmatching_froms(query, self.d)
+        assert start == self.d
+        assert froms == {self.a, self.b, self.c}
+
+
+class TestLinter(fixtures.TablesTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+        Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+        Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
+        Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
+
+    def setup(self):
+        self.a = self.tables.table_a
+        self.b = self.tables.table_b
+        self.c = self.tables.table_c
+        self.d = self.tables.table_d
+        event.listen(testing.db, 'before_execute', linter.before_execute_hook)
+
+    def test_integration(self):
+        query = (
+            select([self.a])
+            .where(self.b.c.col_b == 5)
+        )
+        # TODO:
+        #  - make it a unit by mocking or spying "find_unmatching_froms"
+        #  - Make error string proper
         with expect_warnings(
-                # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
-            r"for stmt .* FROM elements .* "
-            r"are not joined up to FROM element .*"
+            r"for stmt .* FROM elements .*table_.*col_.* "
+            r"are not joined up to FROM element .*table_.*col_.*"
         ):
             with testing.db.connect() as conn:
-                conn.execute(query)
\ No newline at end of file
+                conn.execute(query)
+
+    def teardown(self):
+        event.remove(testing.db, 'before_execute', linter.before_execute_hook)