]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed bug which was preventing UNIONS from being cloneable,
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Mar 2008 18:20:09 +0000 (18:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Mar 2008 18:20:09 +0000 (18:20 +0000)
[ticket:986]

CHANGES
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/sql/generative.py

diff --git a/CHANGES b/CHANGES
index de8c168eaaf69ddad28b65ce82a4b96ade41dca7..e53812101dfbcc507a83de20ae5ff22fad54162f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -19,6 +19,9 @@ CHANGES
     - implemented two-phase API for "threadlocal" engine, 
       via engine.begin_twophase(), engine.prepare()
       [ticket:936]
+
+    - fixed bug which was preventing UNIONS from being cloneable,
+      [ticket:986]
       
 - orm
     - any(), has(), contains(), attribute level == and != now
index 56629a6ca38984c1a04eb4b5726c3020263b6edb..812c70c2d84bdda49e9bb09c499d817c19e9ef6f 100644 (file)
@@ -2925,8 +2925,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
             else:
                 self.selects.append(s)
 
-        self._col_map = {}
-
         _SelectBaseMixin.__init__(self, **kwargs)
 
         for s in self.selects:
@@ -2942,11 +2940,15 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
                 yield c
 
     def _proxy_column(self, column):
-        selectable = column.table
-        col_ordering = self._col_map.get(selectable, None)
-        if col_ordering is None:
-            self._col_map[selectable] = col_ordering = []
-
+        if not hasattr(self, '_col_map'):
+            self._col_map = dict([(s, []) for s in self.selects])
+            for s in self.selects:
+                for c in s.c + [s.oid_column]:
+                    self._col_map[c] = s
+        
+        selectable = self._col_map[column]
+        col_ordering = self._col_map[selectable]
+        
         if selectable is self.selects[0]:
             if self.use_labels:
                 col = column._make_proxy(self, name=column._label)
@@ -2962,8 +2964,9 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
 
     def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
-        self._col_map = {}
         self.selects = [clone(s) for s in self.selects]
+        if hasattr(self, '_col_map'):
+            del self._col_map
         for attr in ('_order_by_clause', '_group_by_clause'):
             if getattr(self, attr) is not None:
                 setattr(self, attr, clone(getattr(self, attr)))
index 70a1dcc96416a20df65a2dc1cc162bcbe47c7f27..9954811d644afe13687eae5a1712746a62dc64ff 100644 (file)
@@ -152,6 +152,12 @@ class AbstractClauseProcessor(object):
 
     This class implements its own visit-and-copy strategy but maintains the
     same public interface as visitors.ClauseVisitor.
+    
+    The convert_element() method receives the *un-copied* version of each element.
+    It can return a new element or None for no change.  If None, the element
+    will be cloned afterwards and added to the new structure.  Note this is the
+    opposite behavior of visitors.traverse(clone=True), where visitors receive
+    the cloned element so that it can be mutated.
     """
 
     __traverse_options__ = {'column_collections':False}
index 831c2e2873660d2081045682a327a613bc48574a..5e6b3b7e6c4c62a477034acb3640fb3c565f6ecf 100644 (file)
@@ -224,7 +224,29 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
         print str(s5)
         assert str(s5) == s5_assert
         assert str(s4) == s4_assert
-
+    
+    def test_union(self):
+        u = union(t1.select(), t2.select())
+        u2 = ClauseVisitor().traverse(u, clone=True)
+        assert str(u) == str(u2)
+        assert [str(c) for c in u2.c] == [str(c) for c in u.c]
+
+        u = union(t1.select(), t2.select())
+        cols = [str(c) for c in u.c]
+        u2 = ClauseVisitor().traverse(u, clone=True)
+        assert str(u) == str(u2)
+        assert [str(c) for c in u2.c] == cols
+        
+        s1 = select([t1], t1.c.col1 == bindparam('id_param'))
+        s2 = select([t2])
+        u = union(s1, s2)
+        
+        u2 = u.params(id_param=7)
+        u3 = u.params(id_param=10)
+        assert str(u) == str(u2) == str(u3)
+        assert u2.compile().params == {'id_param':7}
+        assert u3.compile().params == {'id_param':10}
+        
     def test_binds(self):
         """test that unique bindparams change their name upon clone() to prevent conflicts"""