From: Alessio Bogon Date: Mon, 22 Jul 2019 18:37:07 +0000 (+0200) Subject: Initial version of cartesian product linter X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e6e847cc5fb2197e91d7f5cbf7826c0c592cc7be;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Initial version of cartesian product linter --- diff --git a/lib/sqlalchemy/ext/linter.py b/lib/sqlalchemy/ext/linter.py new file mode 100644 index 0000000000..f265316a59 --- /dev/null +++ b/lib/sqlalchemy/ext/linter.py @@ -0,0 +1,92 @@ +import collections +import itertools + +from sqlalchemy import util + +from sqlalchemy.sql import visitors +from sqlalchemy.sql.expression import Select + + +def before_execute_hook(conn, clauseelement, multiparams, params): + if isinstance(clauseelement, Select): + warn_for_cartesian(clauseelement) + else: + raise NotImplementedError + +# from sqlalchemy.ext.compiler import compiles +# @compiles(Select) +# def select_warn_for_cartesian_compiler(element, compiler, **kw): +# warn_for_cartesian(element) +# return compiler.visit_select(element, **kw) + + +def warn_for_cartesian(element): + froms = set(element.froms) + if not froms: + return + edges = set() + + # find all "a b", add that as edges + def visit_binary(binary_element): + edges.update( + itertools.product( + binary_element.left._from_objects, + binary_element.right._from_objects, + ) + ) + + # find all "a JOIN b", add "a" and "b" as froms + def visit_join(join_element): + if join_element in froms: + froms.remove(join_element) + froms.update((join_element.left, join_element.right)) + + # unwrap "FromGrouping" objects, e.g. parentheized froms + def visit_grouping(grouping_element): + if grouping_element in froms: + froms.remove(grouping_element) + + # the enclosed element will often be a JOIN. The visitors.traverse + # does a depth-first outside-in traversal so the next + # call will be visit_join() of this element :) + froms.add(grouping_element.element) + + visitors.traverse( + element, + {}, + { + "binary": visit_binary, + "join": visit_join, + "grouping": visit_grouping, + }, + ) + + # take any element from the list of FROMS. + # then traverse all the edges and ensure we can reach + # all other FROMS + start_with = froms.pop() + the_rest = froms + stack = collections.deque([start_with]) + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + for edge in list(edges): + if edge not in edges: + continue + elif edge[0] is node: + edges.remove(edge) + stack.appendleft(edge[1]) + elif edge[1] is node: + edges.remove(edge) + stack.appendleft(edge[0]) + + # FROMS left over? boom + if the_rest: + util.warn( + 'for stmt %s FROM elements %s are not joined up to FROM element "%r"' + % ( + id(element), # defeat the warnings filter + ", ".join('"%r"' % f for f in the_rest), + start_with, + ) + ) \ No newline at end of file diff --git a/test/ext/test_linter.py b/test/ext/test_linter.py new file mode 100644 index 0000000000..fa5c215004 --- /dev/null +++ b/test/ext/test_linter.py @@ -0,0 +1,133 @@ +from sqlalchemy import select, Integer, event, testing +from sqlalchemy.ext import linter +from sqlalchemy.testing import fixtures, expect_warnings +from sqlalchemy.testing.schema import Table, Column + + +class TestLinter(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) + Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) + + def setup(self): + self.a = self.tables.table_a + self.b = self.tables.table_b + self.c = self.tables.table_c + self.d = self.tables.table_d + event.listen(testing.db, 'before_execute', linter.before_execute_hook) + + def teardown(self): + event.remove(testing.db, 'before_execute', linter.before_execute_hook) + + def test_everything_is_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.d.c.col_d == self.b.c.col_b) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + with testing.db.connect() as conn: + conn.execute(query) + + def test_plain_cartesian(self): + query = ( + select([self.a]) + .where(self.b.c.col_b == 5) + ) + with expect_warnings( + r"for stmt .* FROM elements .*table_b.*col_b.* " + r"are not joined up to FROM element .*table_a.*col_a.*" + ): + with testing.db.connect() as conn: + conn.execute(query) + + def test_disconnect_between_ab_cd(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + with expect_warnings( + # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed) + r"for stmt .* FROM elements .* " + r"are not joined up to FROM element .*" + ): + with testing.db.connect() as conn: + conn.execute(query) + + def test_c_and_d_both_disconnected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + with expect_warnings( + # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed) + r"for stmt .* FROM elements .* " + r"are not joined up to FROM element .*" + ): + with testing.db.connect() as conn: + conn.execute(query) + + def test_now_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d)) + .where(self.c.c.col_c == self.b.c.col_b) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + with testing.db.connect() as conn: + conn.execute(query) + + def test_disconnected_subquery(self): + subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + stmt = select([self.c]).select_from(subq) + with expect_warnings( + # TODO: Fix subquery not being displayed properly + r"for stmt .* FROM elements .* " + r"are not joined up to FROM element .*table_c.*col_c.*" + ): + with testing.db.connect() as conn: + conn.execute(stmt) + + def test_now_connect_it(self): + subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + stmt = select([self.c]).select_from(subq).where(self.c.c.col_c == subq.c.col_a) + with testing.db.connect() as conn: + conn.execute(stmt) + + def test_right_nested_join_without_issue(self): + query = ( + select([self.a]) + .select_from( + self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b) + ) + ) + with testing.db.connect() as conn: + conn.execute(query) + + def test_right_nested_join_with_an_issue(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b)) + .where(self.d.c.col_d == 5) + ) + with expect_warnings( + # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed) + r"for stmt .* FROM elements .* " + r"are not joined up to FROM element .*" + ): + with testing.db.connect() as conn: + conn.execute(query) \ No newline at end of file