]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed generative behavior to copy collections, [ticket:752]
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Aug 2007 20:43:54 +0000 (20:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Aug 2007 20:43:54 +0000 (20:43 +0000)
lib/sqlalchemy/sql/expression.py
test/sql/generative.py

index ea87a8c4f4acc1f1db9060751760c7e5f3fd723b..41f2b741bb1e56612429c1886f8d0bb7bb0d0f67 100644 (file)
@@ -2905,6 +2905,11 @@ class Select(_SelectBaseMixin, FromClause):
 
         self._should_correlate = correlate
         self._distinct = distinct
+    
+        # NOTE: the _generate()
+        # operation creates a *shallow* copy of the object, so append_XXX() methods,
+        # usually called via a generative method, create a copy of each collection
+        # by default
 
         self._raw_columns = []
         self.__correlate = util.Set()
@@ -2915,11 +2920,11 @@ class Select(_SelectBaseMixin, FromClause):
 
         if columns is not None:
             for c in columns:
-                self.append_column(c)
+                self.append_column(c, copy_collection=False)
 
         if from_obj is not None:
             for f in from_obj:
-                self.append_from(f)
+                self.append_from(f, copy_collection=False)
 
         if whereclause is not None:
             self.append_whereclause(whereclause)
@@ -2929,7 +2934,7 @@ class Select(_SelectBaseMixin, FromClause):
 
         if prefixes is not None:
             for p in prefixes:
-                self.append_prefix(p)
+                self.append_prefix(p, copy_collection=False)
 
         _SelectBaseMixin.__init__(self, **kwargs)
 
@@ -3078,20 +3083,29 @@ class Select(_SelectBaseMixin, FromClause):
             s.append_correlation(fromclause)
         return s
 
-    def append_correlation(self, fromclause):
-        self.__correlate.add(fromclause)
+    def append_correlation(self, fromclause, copy_collection=True):
+        if not copy_collection:
+            self.__correlate.add(fromclause)
+        else:
+            self.__correlate = util.Set(list(self.__correlate) + [fromclause])
 
-    def append_column(self, column):
+    def append_column(self, column, copy_collection=True):
         column = _literal_as_column(column)
 
         if isinstance(column, _ScalarSelect):
             column = column.self_group(against=operators.comma_op)
+        
+        if not copy_collection:
+            self._raw_columns.append(column)
+        else:
+            self._raw_columns = self._raw_columns + [column]
 
-        self._raw_columns.append(column)
-
-    def append_prefix(self, clause):
+    def append_prefix(self, clause, copy_collection=True):
         clause = _literal_as_text(clause)
-        self._prefixes.append(clause)
+        if not copy_collection:
+            self._prefixes.append(clause)
+        else:
+            self._prefixes = self._prefixes + [clause]
 
     def append_whereclause(self, whereclause):
         if self._whereclause  is not None:
@@ -3105,10 +3119,14 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             self._having = _literal_as_text(having)
 
-    def append_from(self, fromclause):
+    def append_from(self, fromclause, copy_collection=True):
         if _is_literal(fromclause):
             fromclause = FromClause(fromclause)
-        self._froms.add(fromclause)
+            
+        if not copy_collection:
+            self._froms.add(fromclause)
+        else:
+            self._froms = util.Set(list(self._froms) + [fromclause])
 
     def _exportable_columns(self):
         return [c for c in self._raw_columns if isinstance(c, (Selectable, ColumnElement))]
index ece8a8d055ba849aab1ba5668ca50061c660c860..bcf5d6a5fa5d5e133bb225cfe8e2517c755f530c 100644 (file)
@@ -270,6 +270,53 @@ class SelectTest(SQLCompileTest):
         s = s.correlate(t1).order_by(t2.c.col3)
         self.assert_compile(t1.select().select_from(s).order_by(t1.c.col3), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, (SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 WHERE table2.col1 = table1.col1 ORDER BY table2.col3) ORDER BY table1.col3")
 
+    def test_columns(self):
+        s = t1.select()
+        self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1")
+        select_copy = s.column('yyy')
+        self.assert_compile(select_copy, "SELECT table1.col1, table1.col2, table1.col3, yyy FROM table1")
+        assert s.columns is not select_copy.columns 
+        assert s._columns is not select_copy._columns
+        assert s._raw_columns is not select_copy._raw_columns
+        self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1")
+
+    def test_froms(self):
+        s = t1.select()
+        self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1")
+        select_copy = s.select_from(t2)
+        self.assert_compile(select_copy, "SELECT table1.col1, table1.col2, table1.col3 FROM table1, table2")
+        assert s._froms is not select_copy._froms
+        self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1")
+
+    def test_correlation(self):
+        s = select([t2], t1.c.col1==t2.c.col1)
+        self.assert_compile(s, "SELECT table2.col1, table2.col2, table2.col3 FROM table2, table1 WHERE table1.col1 = table2.col1")
+        s2 = select([t1], t1.c.col2==s.c.col2)
+        self.assert_compile(s2, "SELECT table1.col1, table1.col2, table1.col3 FROM table1, "
+                "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 "
+                "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2")
+                
+        s3 = s.correlate(None)        
+        self.assert_compile(select([t1], t1.c.col2==s3.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, "
+                "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2, table1 "
+                "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2")
+        self.assert_compile(select([t1], t1.c.col2==s.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, "
+                "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 "
+                "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2")
+        s4 = s3.correlate(t1)
+        self.assert_compile(select([t1], t1.c.col2==s4.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, "
+                "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2 "
+                "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2")
+        self.assert_compile(select([t1], t1.c.col2==s3.c.col2), "SELECT table1.col1, table1.col2, table1.col3 FROM table1, "
+                "(SELECT table2.col1 AS col1, table2.col2 AS col2, table2.col3 AS col3 FROM table2, table1 "
+                "WHERE table1.col1 = table2.col1) WHERE table1.col2 = col2")
+    
+    def test_prefixes(self):
+        s = t1.select()
+        self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1")
+        select_copy = s.prefix_with("FOOBER")
+        self.assert_compile(select_copy, "SELECT FOOBER table1.col1, table1.col2, table1.col3 FROM table1")
+        self.assert_compile(s, "SELECT table1.col1, table1.col2, table1.col3 FROM table1")
 
 if __name__ == '__main__':
     testbase.main()