From 700b48b091311e5e5c6a4b836a70a73a33206f47 Mon Sep 17 00:00:00 2001 From: Alessio Bogon Date: Thu, 29 Aug 2019 20:53:27 +0200 Subject: [PATCH] Split warn_for_cartesian in 2 different functions This will help with testing *a lot* --- lib/sqlalchemy/ext/linter.py | 27 ++++--- test/ext/test_linter.py | 153 ++++++++++++++++++++++++----------- 2 files changed, 125 insertions(+), 55 deletions(-) diff --git a/lib/sqlalchemy/ext/linter.py b/lib/sqlalchemy/ext/linter.py index f265316a59..549127f8b8 100644 --- a/lib/sqlalchemy/ext/linter.py +++ b/lib/sqlalchemy/ext/linter.py @@ -13,14 +13,9 @@ def before_execute_hook(conn, clauseelement, multiparams, params): else: raise NotImplementedError -# from sqlalchemy.ext.compiler import compiles -# @compiles(Select) -# def select_warn_for_cartesian_compiler(element, compiler, **kw): -# warn_for_cartesian(element) -# return compiler.visit_select(element, **kw) - -def warn_for_cartesian(element): +def find_unmatching_froms(element, start_with=None): + # TODO: It would be nicer to use OrderedSet, but it seems to not be too much optimize, so let's skip for now froms = set(element.froms) if not froms: return @@ -64,11 +59,16 @@ def warn_for_cartesian(element): # take any element from the list of FROMS. # then traverse all the edges and ensure we can reach # all other FROMS - start_with = froms.pop() + if start_with is not None: + assert start_with in froms + else: + start_with = next(iter(froms)) + froms.remove(start_with) the_rest = froms stack = collections.deque([start_with]) while stack and the_rest: node = stack.popleft() + # the_rest.pop(node, None) the_rest.discard(node) for edge in list(edges): if edge not in edges: @@ -82,11 +82,18 @@ def warn_for_cartesian(element): # FROMS left over? boom if the_rest: + return the_rest, start_with + else: + return None, None + +def warn_for_cartesian(element): + froms, start_with = find_unmatching_froms(element) + if froms: util.warn( 'for stmt %s FROM elements %s are not joined up to FROM element "%r"' % ( id(element), # defeat the warnings filter - ", ".join('"%r"' % f for f in the_rest), + ", ".join('"%r"' % f for f in froms), start_with, ) - ) \ No newline at end of file + ) diff --git a/test/ext/test_linter.py b/test/ext/test_linter.py index fa5c215004..da9cce1748 100644 --- a/test/ext/test_linter.py +++ b/test/ext/test_linter.py @@ -1,10 +1,11 @@ from sqlalchemy import select, Integer, event, testing from sqlalchemy.ext import linter +from sqlalchemy.ext.linter import find_unmatching_froms from sqlalchemy.testing import fixtures, expect_warnings from sqlalchemy.testing.schema import Table, Column -class TestLinter(fixtures.TablesTest): +class TestFinder(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) @@ -17,10 +18,6 @@ class TestLinter(fixtures.TablesTest): self.b = self.tables.table_b self.c = self.tables.table_c self.d = self.tables.table_d - event.listen(testing.db, 'before_execute', linter.before_execute_hook) - - def teardown(self): - event.remove(testing.db, 'before_execute', linter.before_execute_hook) def test_everything_is_connected(self): query = ( @@ -32,20 +29,25 @@ class TestLinter(fixtures.TablesTest): .where(self.c.c.col_c == self.d.c.col_d) .where(self.c.c.col_c == 5) ) - with testing.db.connect() as conn: - conn.execute(query) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms def test_plain_cartesian(self): query = ( select([self.a]) .where(self.b.c.col_b == 5) ) - with expect_warnings( - r"for stmt .* FROM elements .*table_b.*col_b.* " - r"are not joined up to FROM element .*table_a.*col_a.*" - ): - with testing.db.connect() as conn: - conn.execute(query) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + froms, start = find_unmatching_froms(query, self.b) + assert start == self.b + assert froms == {self.a} def test_disconnect_between_ab_cd(self): query = ( @@ -56,13 +58,14 @@ class TestLinter(fixtures.TablesTest): .where(self.c.c.col_c == self.d.c.col_d) .where(self.c.c.col_c == 5) ) - with expect_warnings( - # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed) - r"for stmt .* FROM elements .* " - r"are not joined up to FROM element .*" - ): - with testing.db.connect() as conn: - conn.execute(query) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + for start in self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.a, self.b} def test_c_and_d_both_disconnected(self): query = ( @@ -71,13 +74,18 @@ class TestLinter(fixtures.TablesTest): .where(self.c.c.col_c == 5) .where(self.d.c.col_d == 10) ) - with expect_warnings( - # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed) - r"for stmt .* FROM elements .* " - r"are not joined up to FROM element .*" - ): - with testing.db.connect() as conn: - conn.execute(query) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + + froms, start = find_unmatching_froms(query, self.c) + assert start == self.c + assert froms == {self.a, self.b, self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} def test_now_connected(self): query = ( @@ -88,25 +96,35 @@ class TestLinter(fixtures.TablesTest): .where(self.c.c.col_c == 5) .where(self.d.c.col_d == 10) ) - with testing.db.connect() as conn: - conn.execute(query) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms def test_disconnected_subquery(self): subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() stmt = select([self.c]).select_from(subq) - with expect_warnings( - # TODO: Fix subquery not being displayed properly - r"for stmt .* FROM elements .* " - r"are not joined up to FROM element .*table_c.*col_c.*" - ): - with testing.db.connect() as conn: - conn.execute(stmt) + + froms, start = find_unmatching_froms(stmt, self.c) + assert start == self.c + assert froms == {subq} + + froms, start = find_unmatching_froms(stmt, subq) + assert start == subq + assert froms == {self.c} def test_now_connect_it(self): subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() stmt = select([self.c]).select_from(subq).where(self.c.c.col_c == subq.c.col_a) - with testing.db.connect() as conn: - conn.execute(stmt) + + froms, start = find_unmatching_froms(stmt) + assert not froms + + for start in self.c, subq: + froms, start = find_unmatching_froms(stmt, start) + assert not froms def test_right_nested_join_without_issue(self): query = ( @@ -115,19 +133,64 @@ class TestLinter(fixtures.TablesTest): self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b) ) ) - with testing.db.connect() as conn: - conn.execute(query) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert not froms def test_right_nested_join_with_an_issue(self): query = ( select([self.a]) - .select_from(self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b)) + .select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ), + ) .where(self.d.c.col_d == 5) ) + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + +class TestLinter(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) + Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) + + def setup(self): + self.a = self.tables.table_a + self.b = self.tables.table_b + self.c = self.tables.table_c + self.d = self.tables.table_d + event.listen(testing.db, 'before_execute', linter.before_execute_hook) + + def test_integration(self): + query = ( + select([self.a]) + .where(self.b.c.col_b == 5) + ) + # TODO: + # - make it a unit by mocking or spying "find_unmatching_froms" + # - Make error string proper with expect_warnings( - # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed) - r"for stmt .* FROM elements .* " - r"are not joined up to FROM element .*" + r"for stmt .* FROM elements .*table_.*col_.* " + r"are not joined up to FROM element .*table_.*col_.*" ): with testing.db.connect() as conn: - conn.execute(query) \ No newline at end of file + conn.execute(query) + + def teardown(self): + event.remove(testing.db, 'before_execute', linter.before_execute_hook) -- 2.47.2