else:
raise NotImplementedError
-# from sqlalchemy.ext.compiler import compiles
-# @compiles(Select)
-# def select_warn_for_cartesian_compiler(element, compiler, **kw):
-# warn_for_cartesian(element)
-# return compiler.visit_select(element, **kw)
-
-def warn_for_cartesian(element):
+def find_unmatching_froms(element, start_with=None):
+ # TODO: It would be nicer to use OrderedSet, but it seems to not be too much optimize, so let's skip for now
froms = set(element.froms)
if not froms:
return
# take any element from the list of FROMS.
# then traverse all the edges and ensure we can reach
# all other FROMS
- start_with = froms.pop()
+ if start_with is not None:
+ assert start_with in froms
+ else:
+ start_with = next(iter(froms))
+ froms.remove(start_with)
the_rest = froms
stack = collections.deque([start_with])
while stack and the_rest:
node = stack.popleft()
+ # the_rest.pop(node, None)
the_rest.discard(node)
for edge in list(edges):
if edge not in edges:
# FROMS left over? boom
if the_rest:
+ return the_rest, start_with
+ else:
+ return None, None
+
+def warn_for_cartesian(element):
+ froms, start_with = find_unmatching_froms(element)
+ if froms:
util.warn(
'for stmt %s FROM elements %s are not joined up to FROM element "%r"'
% (
id(element), # defeat the warnings filter
- ", ".join('"%r"' % f for f in the_rest),
+ ", ".join('"%r"' % f for f in froms),
start_with,
)
- )
\ No newline at end of file
+ )
from sqlalchemy import select, Integer, event, testing
from sqlalchemy.ext import linter
+from sqlalchemy.ext.linter import find_unmatching_froms
from sqlalchemy.testing import fixtures, expect_warnings
from sqlalchemy.testing.schema import Table, Column
-class TestLinter(fixtures.TablesTest):
+class TestFinder(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
self.b = self.tables.table_b
self.c = self.tables.table_c
self.d = self.tables.table_d
- event.listen(testing.db, 'before_execute', linter.before_execute_hook)
-
- def teardown(self):
- event.remove(testing.db, 'before_execute', linter.before_execute_hook)
def test_everything_is_connected(self):
query = (
.where(self.c.c.col_c == self.d.c.col_d)
.where(self.c.c.col_c == 5)
)
- with testing.db.connect() as conn:
- conn.execute(query)
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ for start in self.a, self.b, self.c, self.d:
+ froms, start = find_unmatching_froms(query, start)
+ assert not froms
def test_plain_cartesian(self):
query = (
select([self.a])
.where(self.b.c.col_b == 5)
)
- with expect_warnings(
- r"for stmt .* FROM elements .*table_b.*col_b.* "
- r"are not joined up to FROM element .*table_a.*col_a.*"
- ):
- with testing.db.connect() as conn:
- conn.execute(query)
+ froms, start = find_unmatching_froms(query, self.a)
+ assert start == self.a
+ assert froms == {self.b}
+
+ froms, start = find_unmatching_froms(query, self.b)
+ assert start == self.b
+ assert froms == {self.a}
def test_disconnect_between_ab_cd(self):
query = (
.where(self.c.c.col_c == self.d.c.col_d)
.where(self.c.c.col_c == 5)
)
- with expect_warnings(
- # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
- r"for stmt .* FROM elements .* "
- r"are not joined up to FROM element .*"
- ):
- with testing.db.connect() as conn:
- conn.execute(query)
+ for start in self.a, self.b:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.c, self.d}
+ for start in self.c, self.d:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.a, self.b}
def test_c_and_d_both_disconnected(self):
query = (
.where(self.c.c.col_c == 5)
.where(self.d.c.col_d == 10)
)
- with expect_warnings(
- # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
- r"for stmt .* FROM elements .* "
- r"are not joined up to FROM element .*"
- ):
- with testing.db.connect() as conn:
- conn.execute(query)
+ for start in self.a, self.b:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.c, self.d}
+
+ froms, start = find_unmatching_froms(query, self.c)
+ assert start == self.c
+ assert froms == {self.a, self.b, self.d}
+
+ froms, start = find_unmatching_froms(query, self.d)
+ assert start == self.d
+ assert froms == {self.a, self.b, self.c}
def test_now_connected(self):
query = (
.where(self.c.c.col_c == 5)
.where(self.d.c.col_d == 10)
)
- with testing.db.connect() as conn:
- conn.execute(query)
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ for start in self.a, self.b, self.c, self.d:
+ froms, start = find_unmatching_froms(query, start)
+ assert not froms
def test_disconnected_subquery(self):
subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
stmt = select([self.c]).select_from(subq)
- with expect_warnings(
- # TODO: Fix subquery not being displayed properly
- r"for stmt .* FROM elements .* "
- r"are not joined up to FROM element .*table_c.*col_c.*"
- ):
- with testing.db.connect() as conn:
- conn.execute(stmt)
+
+ froms, start = find_unmatching_froms(stmt, self.c)
+ assert start == self.c
+ assert froms == {subq}
+
+ froms, start = find_unmatching_froms(stmt, subq)
+ assert start == subq
+ assert froms == {self.c}
def test_now_connect_it(self):
subq = select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
stmt = select([self.c]).select_from(subq).where(self.c.c.col_c == subq.c.col_a)
- with testing.db.connect() as conn:
- conn.execute(stmt)
+
+ froms, start = find_unmatching_froms(stmt)
+ assert not froms
+
+ for start in self.c, subq:
+ froms, start = find_unmatching_froms(stmt, start)
+ assert not froms
def test_right_nested_join_without_issue(self):
query = (
self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b)
)
)
- with testing.db.connect() as conn:
- conn.execute(query)
+ froms, start = find_unmatching_froms(query)
+ assert not froms
+
+ for start in self.a, self.b, self.c:
+ froms, start = find_unmatching_froms(query, start)
+ assert not froms
def test_right_nested_join_with_an_issue(self):
query = (
select([self.a])
- .select_from(self.a.join(self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), self.a.c.col_a == self.b.c.col_b))
+ .select_from(
+ self.a.join(
+ self.b.join(self.c, self.b.c.col_b == self.c.c.col_c),
+ self.a.c.col_a == self.b.c.col_b,
+ ),
+ )
.where(self.d.c.col_d == 5)
)
+
+ for start in self.a, self.b, self.c:
+ froms, start = find_unmatching_froms(query, start)
+ assert start == start
+ assert froms == {self.d}
+
+ froms, start = find_unmatching_froms(query, self.d)
+ assert start == self.d
+ assert froms == {self.a, self.b, self.c}
+
+
+class TestLinter(fixtures.TablesTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+ Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+ Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
+ Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
+
+ def setup(self):
+ self.a = self.tables.table_a
+ self.b = self.tables.table_b
+ self.c = self.tables.table_c
+ self.d = self.tables.table_d
+ event.listen(testing.db, 'before_execute', linter.before_execute_hook)
+
+ def test_integration(self):
+ query = (
+ select([self.a])
+ .where(self.b.c.col_b == 5)
+ )
+ # TODO:
+ # - make it a unit by mocking or spying "find_unmatching_froms"
+ # - Make error string proper
with expect_warnings(
- # TODO: Fix FROM element parts being undeterministic (impl uses set, no order is guaranteed)
- r"for stmt .* FROM elements .* "
- r"are not joined up to FROM element .*"
+ r"for stmt .* FROM elements .*table_.*col_.* "
+ r"are not joined up to FROM element .*table_.*col_.*"
):
with testing.db.connect() as conn:
- conn.execute(query)
\ No newline at end of file
+ conn.execute(query)
+
+ def teardown(self):
+ event.remove(testing.db, 'before_execute', linter.before_execute_hook)