]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Initial version of cartesian product linter
author Alessio Bogon <youtux@gmail.com>
Mon, 22 Jul 2019 18:37:07 +0000 (20:37 +0200)
committer Alessio Bogon <youtux@gmail.com>
Thu, 29 Aug 2019 17:10:21 +0000 (19:10 +0200)
lib/sqlalchemy/ext/linter.py [new file with mode: 0644]
test/ext/test_linter.py [new file with mode: 0644]

diff --git a/lib/sqlalchemy/ext/linter.py b/lib/sqlalchemy/ext/linter.py
new file mode 100644 (file)
index 0000000..f265316
--- /dev/null
@@ -0,0 +1,92 @@
+import collections
+import itertools
+
+from sqlalchemy import util
+
+from sqlalchemy.sql import visitors
+from sqlalchemy.sql.expression import Select
+
+
+def before_execute_hook(conn, clauseelement, multiparams, params):
+    if isinstance(clauseelement, Select):
+        warn_for_cartesian(clauseelement)
+    else:
+        raise NotImplementedError
+
+# from sqlalchemy.ext.compiler import compiles
+# @compiles(Select)
+# def select_warn_for_cartesian_compiler(element, compiler, **kw):
+#     warn_for_cartesian(element)
+#     return compiler.visit_select(element, **kw)
+
+
+def warn_for_cartesian(element):
+    froms = set(element.froms)
+    if not froms:
+        return
+    edges = set()
+
+    # find all "a <operator> b", add that as edges
+    def visit_binary(binary_element):
+        edges.update(
+            itertools.product(
+                binary_element.left._from_objects,
+                binary_element.right._from_objects,
+            )
+        )
+
+    # find all "a JOIN b", add "a" and "b" as froms
+    def visit_join(join_element):
+        if join_element in froms:
+            froms.remove(join_element)
+            froms.update((join_element.left, join_element.right))
+
+    # unwrap "FromGrouping" objects, e.g. parentheized froms
+    def visit_grouping(grouping_element):
+        if grouping_element in froms:
+            froms.remove(grouping_element)
+
+            # the enclosed element will often be a JOIN.  The visitors.traverse
+            # does a depth-first outside-in traversal so the next
+            # call will be visit_join() of this element :)
+            froms.add(grouping_element.element)
+
+    visitors.traverse(
+        element,
+        {},
+        {
+            "binary": visit_binary,
+            "join": visit_join,
+            "grouping": visit_grouping,
+        },
+    )
+
+    # take any element from the list of FROMS.
+    # then traverse all the edges and ensure we can reach
+    # all other FROMS
+    start_with = froms.pop()
+    the_rest = froms
+    stack = collections.deque([start_with])
+    while stack and the_rest:
+        node = stack.popleft()
+        the_rest.discard(node)
+        for edge in list(edges):
+            if edge not in edges:
+                continue
+            elif edge[0] is node:
+                edges.remove(edge)
+                stack.appendleft(edge[1])
+            elif edge[1] is node:
+                edges.remove(edge)
+                stack.appendleft(edge[0])
+
+    # FROMS left over?  boom
+    if the_rest:
+        util.warn(
+            'for stmt %s FROM elements %s are not joined up to FROM element "%r"'
+            % (
+                id(element),  # defeat the warnings filter
+                ", ".join('"%r"' % f for f in the_rest),
+                start_with,
+            )
+        )
\ No newline at end of file
diff --git a/test/ext/test_linter.py b/test/ext/test_linter.py
new file mode 100644 (file)
index 0000000..fa5c215
--- /dev/null
@@ -0,0 +1,133 @@
+from sqlalchemy import select, Integer, event, testing
+from sqlalchemy.ext import linter
+from sqlalchemy.testing import fixtures, expect_warnings
+from sqlalchemy.testing.schema import Table, Column
+
+
+class TestLinter(fixtures.TablesTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+        Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+        Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
+        Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
+
+    def setup(self):
+        self.a = self.tables.table_a
+        self.b = self.tables.table_b
+        self.c = self.tables.table_c
+        self.d = self.tables.table_d
+        event.listen(testing.db, 'before_execute', linter.before_execute_hook)
+
+    def teardown(self):
+        event.remove(testing.db, 'before_execute', linter.before_execute_hook)
+
+    def test_everything_is_connected(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .select_from(self.c)
+            .select_from(self.d)
+            .where(self.d.c.col_d == self.b.c.col_b)
+            .where(self.c.c.col_c == self.d.c.col_d)
+            .where(self.c.c.col_c == 5)
+        )
+        with testing.db.connect() as conn:
+            conn.execute(query)
+
+    def test_plain_cartesian(self):
+        query = (
+            select([self.a])
+            .where(self.b.c.col_b == 5)
+        )
+        with expect_warnings(
+            r"for stmt .* FROM elements .*table_b.*col_b.* "
+            r"are not joined up to FROM element .*table_a.*col_a.*"
+        ):
+            with testing.db.connect() as conn:
+                conn.execute(query)
+
+    def test_disconnect_between_ab_cd(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .select_from(self.c)
+            .select_from(self.d)
+            .where(self.c.c.col_c == self.d.c.col_d)
+            .where(self.c.c.col_c == 5)
+        )
+        with expect_warnings(
+            # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
+            r"for stmt .* FROM elements .* "
+            r"are not joined up to FROM element .*"
+        ):
+            with testing.db.connect() as conn:
+                conn.execute(query)
+
+    def test_c_and_d_both_disconnected(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .where(self.c.c.col_c == 5)
+            .where(self.d.c.col_d == 10)
+        )
+        with expect_warnings(
+            # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
+            r"for stmt .* FROM elements .* "
+            r"are not joined up to FROM element .*"
+        ):
+            with testing.db.connect() as conn:
+                conn.execute(query)
+
+    def test_now_connected(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d))
+            .where(self.c.c.col_c == self.b.c.col_b)
+            .where(self.c.c.col_c == 5)
+            .where(self.d.c.col_d == 10)
+        )
+        with testing.db.connect() as conn:
+            conn.execute(query)
+
+    def test_disconnected_subquery(self):
+        subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
+        stmt = select([self.c]).select_from(subq)
+        with expect_warnings(
+            # TODO: Fix subquery not being displayed properly
+            r"for stmt .* FROM elements .* "
+            r"are not joined up to FROM element .*table_c.*col_c.*"
+        ):
+            with testing.db.connect() as conn:
+                conn.execute(stmt)
+
+    def test_now_connect_it(self):
+        subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
+        stmt = select([self.c]).select_from(subq).where(self.c.c.col_c == subq.c.col_a)
+        with testing.db.connect() as conn:
+            conn.execute(stmt)
+
+    def test_right_nested_join_without_issue(self):
+        query = (
+            select([self.a])
+            .select_from(
+                self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b)
+            )
+        )
+        with testing.db.connect() as conn:
+            conn.execute(query)
+
+    def test_right_nested_join_with_an_issue(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b))
+            .where(self.d.c.col_d == 5)
+        )
+        with expect_warnings(
+                # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
+            r"for stmt .* FROM elements .* "
+            r"are not joined up to FROM element .*"
+        ):
+            with testing.db.connect() as conn:
+                conn.execute(query)
\ No newline at end of file