]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- changed set used to generate FROM list to an ordered set; may fix [ticket:669]
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jul 2007 20:36:51 +0000 (20:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Jul 2007 20:36:51 +0000 (20:36 +0000)
- improvements to select generative capability, ClauseAdapter
- one select test is failing, but not from this checkin

lib/sqlalchemy/sql.py
lib/sqlalchemy/sql_util.py
test/sql/generative.py

index 2e5be6ee718df4205992243f6092a830744d28da..ed9843ac7eeac5f9fa4ab605988cff3cdb90b311 100644 (file)
@@ -2450,10 +2450,13 @@ class Alias(FromClause):
 
     def is_derived_from(self, fromclause):
         x = self.selectable
-        while isinstance(x, Alias):
-            x = x.selectable
+        while True:
             if x is fromclause:
                 return True
+            if isinstance(x, Alias):
+                x = x.selectable
+            else:
+                break
         return False
 
     def supports_execution(self):
@@ -2937,7 +2940,7 @@ class Select(_SelectBaseMixin, FromClause):
         _calculate_correlations() method.  
         
         """
-        froms = util.Set()
+        froms = util.OrderedSet()
         hide_froms = util.Set()
         
         for col in self._raw_columns:
@@ -3072,7 +3075,7 @@ class Select(_SelectBaseMixin, FromClause):
     def _copy_internals(self):
         self._clone_from_clause()
         self._raw_columns = [c._clone() for c in self._raw_columns]
-        self._recorrelate_froms([f._clone() for f in self._froms])
+        self._recorrelate_froms([(f, f._clone()) for f in self._froms])
         for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
             if getattr(self, attr) is not None:
                 setattr(self, attr, getattr(self, attr)._clone())
@@ -3084,12 +3087,17 @@ class Select(_SelectBaseMixin, FromClause):
 
     def _recorrelate_froms(self, froms):
         newcorrelate = util.Set()
-        for f in froms:
-            if f in self.__correlate:
-                newcorrelate.add(cl)
-                self.__correlate.remove(f)
+        newfroms = util.Set()
+        oldfroms = util.Set(self._froms)
+        for old, new in froms:
+            if old in self.__correlate:
+                newcorrelate.add(new)
+                self.__correlate.remove(old)
+            if old in oldfroms:
+                newfroms.add(new)
+                oldfroms.remove(old)
         self.__correlate = self.__correlate.union(newcorrelate)
-        self._froms = froms
+        self._froms = [f for f in oldfroms.union(newfroms)]
         
     def column(self, column):
         s = self._generate()
@@ -3116,7 +3124,7 @@ class Select(_SelectBaseMixin, FromClause):
         s.append_from(fromclause)
         return s
     
-    def correlate_to(self, fromclause):
+    def correlate(self, fromclause):
         s = self._generate()
         s.append_correlation(fromclause)
         return s
index 36d127c98cbcb799548d504abf378b46957a5aa3..96dae5e0a2c73e3df6fcee097b0f4622fef5ef8b 100644 (file)
@@ -167,10 +167,8 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
         fr = util.OrderedSet()
         for elem in select._froms:
             n = self.convert_element(elem)
-            if n is None:
-                fr.add(elem)
-            else:
-                fr.add(n)
+            if n is not None:
+                fr.add((elem, n))
         select._recorrelate_froms(fr)
 
         col = []
index cb8f4c6faca282ba49d615124efccda57fbade7d..5172e5b5b005af34b5c59136cdf6407ff122bb97 100644 (file)
@@ -1,4 +1,6 @@
 import testbase
+from sql import select as selecttests
+
 from sqlalchemy import *
 
 class TraversalTest(testbase.AssertMixin):
@@ -130,7 +132,7 @@ class TraversalTest(testbase.AssertMixin):
         assert struct != s3
         assert struct3 == s3
 
-class ClauseTest(testbase.AssertMixin):
+class ClauseTest(selecttests.SQLTest):
     """test copy-in-place behavior of various ClauseElements."""
     
     def setUpAll(self):
@@ -203,6 +205,61 @@ class ClauseTest(testbase.AssertMixin):
         print str(s5)
         assert str(s5) == s5_assert
         assert str(s4) == s4_assert
+    
+    def test_correlated_select(self):
+        s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
+        class Vis(ClauseVisitor):
+            def visit_select(self, select):
+                select.append_whereclause(t1.c.col2==7)
+                
+        self.runtest(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
+
+    def test_clause_adapter(self):
+        from sqlalchemy import sql_util
+        
+        t1alias = t1.alias('t1alias')
+        
+        vis = sql_util.ClauseAdapter(t1alias)
+        self.runtest(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
+
+        t2alias = t2.alias('t2alias')
+        vis.chain(sql_util.ClauseAdapter(t2alias))
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+        self.runtest(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+        
+class SelectTest(selecttests.SQLTest):
+    """tests the generative capability of Select"""
+
+    def setUpAll(self):
+        global t1, t2
+        t1 = table("table1", 
+            column("col1"),
+            column("col2"),
+            column("col3"),
+            )
+        t2 = table("table2", 
+            column("col1"),
+            column("col2"),
+            column("col3"),
+            )
+    
+    def test_select(self):
+        self.runtest(t1.select().where(t1.c.col1==5).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1 WHERE table1.col1 = :table1_col1 ORDER BY table1.col3")
+    
+        self.runtest(t1.select().select_from(select([t2], t2.c.col1==t1.c.col1)).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1) ORDER BY table1.col3")
+        
+        s = select([t2], t2.c.col1==t1.c.col1, correlate=False)
+        s = s.correlate(t1).order_by(t2.c.col3)
+        self.runtest(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3")
+        
+        
+        
         
         
 if __name__ == '__main__':