From: Mike Bayer Date: Thu, 19 Jul 2007 20:36:51 +0000 (+0000) Subject: - changed set used to generate FROM list to an ordered set; may fix [ticket:669] X-Git-Tag: rel_0_4_6~69 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e9da3425e4cb0173154d8e4c42c9c7afca63786f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - changed set used to generate FROM list to an ordered set; may fix [ticket:669] - improvements to select generative capability, ClauseAdapter - one select test is failing, but not from this checkin --- diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 2e5be6ee71..ed9843ac7e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -2450,10 +2450,13 @@ class Alias(FromClause): def is_derived_from(self, fromclause): x = self.selectable - while isinstance(x, Alias): - x = x.selectable + while True: if x is fromclause: return True + if isinstance(x, Alias): + x = x.selectable + else: + break return False def supports_execution(self): @@ -2937,7 +2940,7 @@ class Select(_SelectBaseMixin, FromClause): _calculate_correlations() method. """ - froms = util.Set() + froms = util.OrderedSet() hide_froms = util.Set() for col in self._raw_columns: @@ -3072,7 +3075,7 @@ class Select(_SelectBaseMixin, FromClause): def _copy_internals(self): self._clone_from_clause() self._raw_columns = [c._clone() for c in self._raw_columns] - self._recorrelate_froms([f._clone() for f in self._froms]) + self._recorrelate_froms([(f, f._clone()) for f in self._froms]) for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: setattr(self, attr, getattr(self, attr)._clone()) @@ -3084,12 +3087,17 @@ class Select(_SelectBaseMixin, FromClause): def _recorrelate_froms(self, froms): newcorrelate = util.Set() - for f in froms: - if f in self.__correlate: - newcorrelate.add(cl) - self.__correlate.remove(f) + newfroms = util.Set() + oldfroms = util.Set(self._froms) + for old, new in froms: + if old in self.__correlate: + newcorrelate.add(new) + self.__correlate.remove(old) + if old in oldfroms: + newfroms.add(new) + oldfroms.remove(old) self.__correlate = self.__correlate.union(newcorrelate) - self._froms = froms + self._froms = [f for f in oldfroms.union(newfroms)] def column(self, column): s = self._generate() @@ -3116,7 +3124,7 @@ class Select(_SelectBaseMixin, FromClause): s.append_from(fromclause) return s - def correlate_to(self, fromclause): + def correlate(self, fromclause): s = self._generate() s.append_correlation(fromclause) return s diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 36d127c98c..96dae5e0a2 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -167,10 +167,8 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): fr = util.OrderedSet() for elem in select._froms: n = self.convert_element(elem) - if n is None: - fr.add(elem) - else: - fr.add(n) + if n is not None: + fr.add((elem, n)) select._recorrelate_froms(fr) col = [] diff --git a/test/sql/generative.py b/test/sql/generative.py index cb8f4c6fac..5172e5b5b0 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -1,4 +1,6 @@ import testbase +from sql import select as selecttests + from sqlalchemy import * class TraversalTest(testbase.AssertMixin): @@ -130,7 +132,7 @@ class TraversalTest(testbase.AssertMixin): assert struct != s3 assert struct3 == s3 -class ClauseTest(testbase.AssertMixin): +class ClauseTest(selecttests.SQLTest): """test copy-in-place behavior of various ClauseElements.""" def setUpAll(self): @@ -203,6 +205,61 @@ class ClauseTest(testbase.AssertMixin): print str(s5) assert str(s5) == s5_assert assert str(s4) == s4_assert + + def test_correlated_select(self): + s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2) + class Vis(ClauseVisitor): + def visit_select(self, select): + select.append_whereclause(t1.c.col2==7) + + self.runtest(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2") + + def test_clause_adapter(self): + from sqlalchemy import sql_util + + t1alias = t1.alias('t1alias') + + vis = sql_util.ClauseAdapter(t1alias) + self.runtest(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2") + + t2alias = t2.alias('t2alias') + vis.chain(sql_util.ClauseAdapter(t2alias)) + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") + self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2") + +class SelectTest(selecttests.SQLTest): + """tests the generative capability of Select""" + + def setUpAll(self): + global t1, t2 + t1 = table("table1", + column("col1"), + column("col2"), + column("col3"), + ) + t2 = table("table2", + column("col1"), + column("col2"), + column("col3"), + ) + + def test_select(self): + self.runtest(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :table1_col1 ORDER BY table1.col3") + + self.runtest(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3") + + s = select([t2], t2.c.col1==t1.c.col1, correlate=False) + s = s.correlate(t1).order_by(t2.c.col3) + self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3") + + + if __name__ == '__main__':