]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more changes to traverse-and-clone; a particular element will only be cloned once...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Nov 2007 18:06:21 +0000 (18:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Nov 2007 18:06:21 +0000 (18:06 +0000)
then re-used.  the FROM calculation of a Select normalizes the list of hide_froms against all
previous incarnations of each FROM clause, using a tag attached from cloned clause to
previous.

lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/sql/generative.py

index 6276f33bd9fda8ee92137595af80a3ea689fc1e2..e066632afbe5505b59bb7660777368270af76847 100644 (file)
@@ -844,6 +844,14 @@ class ClauseElement(object):
         """
         c = self.__class__.__new__(self.__class__)
         c.__dict__ = self.__dict__.copy()
+        
+        # this is a marker that helps to "equate" clauses to each other
+        # when a Select returns its list of FROM clauses.  the cloning
+        # process leaves around a lot of remnants of the previous clause
+        # typically in the form of column expressions still attached to the
+        # old table.
+        c._is_clone_of = self
+        
         return c
 
     def _get_from_objects(self, **modifiers):
@@ -2212,7 +2220,7 @@ class Join(FromClause):
         self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
 
     def description(self):
-        return "Join object on %s and %s" % (self.left.description, self.right.description)
+        return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right))
     description = property(description)
     
     primary_key = property(lambda s:s.__primary_key)
@@ -2394,15 +2402,6 @@ class Alias(FromClause):
         #return self.selectable._exportable_columns()
         return self.selectable.columns
 
-    def _clone(self):
-        # TODO: need test coverage to assert ClauseAdapter behavior
-        # here; must identify non-ORM failure cases when a. _clone() returns 'self' in all 
-        # cases and b. when _clone() does an actual _clone() in all cases.
-        if isinstance(self.selectable, TableClause):
-            return self
-        else:
-            return super(Alias, self)._clone()
-
     def _copy_internals(self, clone=_clone):
        self._clone_from_clause()
        self.selectable = _clone(self.selectable)
@@ -2996,6 +2995,9 @@ class Select(_SelectBaseMixin, FromClause):
         for col in self._raw_columns:
             for f in col._hide_froms():
                 hide_froms.add(f)
+                while hasattr(f, '_is_clone_of'):
+                    hide_froms.add(f._is_clone_of)
+                    f = f._is_clone_of
             for f in col._get_from_objects():
                 froms.add(f)
 
@@ -3007,17 +3009,26 @@ class Select(_SelectBaseMixin, FromClause):
             froms.add(elem)
             for f in elem._get_from_objects():
                 froms.add(f)
-
+        
         for elem in froms:
             for f in elem._hide_froms():
                 hide_froms.add(f)
-
+                while hasattr(f, '_is_clone_of'):
+                    hide_froms.add(f._is_clone_of)
+                    f = f._is_clone_of
+        
         froms = froms.difference(hide_froms)
-
+        
         if len(froms) > 1:
             corr = self.__correlate
             if self._should_correlate and existing_froms is not None:
                 corr = existing_froms.union(corr)
+                
+            for f in list(corr):
+                while hasattr(f, '_is_clone_of'):
+                    corr.add(f._is_clone_of)
+                    f = f._is_clone_of
+                
             f = froms.difference(corr)
             if len(f) == 0:
                 raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
@@ -3070,8 +3081,8 @@ class Select(_SelectBaseMixin, FromClause):
 
     def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
-        self._raw_columns = [clone(c) for c in self._raw_columns]
         self._recorrelate_froms([(f, clone(f)) for f in self._froms])
+        self._raw_columns = [clone(c) for c in self._raw_columns]
         for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
             if getattr(self, attr) is not None:
                 setattr(self, attr, clone(getattr(self, attr)))
index ecf4f3c16311117298131089a387543fbbdef68a..81d28ac7ed6d62ff0d291fc80fc80b0f24c0b365 100644 (file)
@@ -148,7 +148,7 @@ class AbstractClauseProcessor(object):
             list_[i] = self.traverse(list_[i], stop_on=stop_on)
         return list_
 
-    def _convert_element(self, elem, stop_on):
+    def _convert_element(self, elem, stop_on, cloned):
         v = self
         while v is not None:
             newelem = v.convert_element(elem)
@@ -156,25 +156,32 @@ class AbstractClauseProcessor(object):
                 stop_on.add(newelem)
                 return newelem
             v = getattr(v, '_next_acp', None)
-        return elem._clone()
         
-    def traverse(self, elem, clone=True, stop_on=None, _clone_toplevel=True):
+        if elem not in cloned:
+            # the full traversal will only make a clone of a particular element
+            # once.
+            cloned[elem] = elem._clone()
+        return cloned[elem]
+        
+    def traverse(self, elem, clone=True, stop_on=None):
         if not clone:
             raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
-            
+        
         if stop_on is None:
             stop_on = util.Set()
-            
+        return self._traverse(elem, stop_on, {}, _clone_toplevel=True)
+        
+    def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
         if elem in stop_on:
             return elem
         
         if _clone_toplevel:
-            elem = self._convert_element(elem, stop_on)
+            elem = self._convert_element(elem, stop_on, cloned)
             if elem in stop_on:
                 return elem
             
         def clone(element):
-            return self._convert_element(element, stop_on)
+            return self._convert_element(element, stop_on, cloned)
         elem._copy_internals(clone=clone)
         
         v = getattr(self, '_next', None)
@@ -186,7 +193,7 @@ class AbstractClauseProcessor(object):
         
         for e in elem.get_children(**self.__traverse_options__):
             if e not in stop_on:
-                self.traverse(e, stop_on=stop_on, _clone_toplevel=False)
+                self._traverse(e, stop_on, cloned)
         return elem
         
 class ClauseAdapter(AbstractClauseProcessor):
index 9bc5d2479fd2adafc1d9daf9847a7cb1453c0328..1a0629a17dcd3145475cfec3359e2a261330f9ec 100644 (file)
@@ -47,10 +47,19 @@ class ClauseVisitor(object):
                 traversal.insert(0, t)
                 for c in t.get_children(**self.__traverse_options__):
                     stack.append(c)
-
+    
     def traverse(self, obj, stop_on=None, clone=False):
+        
         if clone:
-            obj = obj._clone()
+            cloned = {}
+            def do_clone(obj):
+                # the full traversal will only make a clone of a particular element
+                # once.
+                if obj not in cloned:
+                    cloned[obj] = obj._clone()
+                return cloned[obj]
+            
+            obj = do_clone(obj)
             
         stack = [obj]
         traversal = []
@@ -59,7 +68,7 @@ class ClauseVisitor(object):
             if stop_on is None or t not in stop_on:
                 traversal.insert(0, t)
                 if clone:
-                    t._copy_internals()
+                    t._copy_internals(clone=do_clone)
                 for c in t.get_children(**self.__traverse_options__):
                     stack.append(c)
         for target in traversal:
index 2d1f3ccf91f71221a337aab0baba5c1a8a068840..1497ecde3d93971ec3c8f682bb9bb9d538c13d25 100644 (file)
@@ -1,6 +1,7 @@
 import testbase
 from sqlalchemy import *
 from sqlalchemy.sql import table, column, ClauseElement
+from sqlalchemy.sql.expression import  _clone
 from testlib import *
 from sqlalchemy.sql.visitors import *
 from sqlalchemy import util
@@ -56,8 +57,8 @@ class TraversalTest(AssertMixin):
                         return True
                 return False
             
-            def _copy_internals(self):    
-                self.items = [i._clone() for i in self.items]
+            def _copy_internals(self, clone=_clone):    
+                self.items = [clone(i) for i in self.items]
 
             def get_children(self, **kwargs):
                 return self.items
@@ -223,7 +224,23 @@ class ClauseTest(SQLCompileTest):
         print str(s5)
         assert str(s5) == s5_assert
         assert str(s4) == s4_assert
-    
+
+    def test_alias(self):
+        subq = t2.select().alias('subq')
+        s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)])
+        orig = str(s)
+        s2 = ClauseVisitor().traverse(s, clone=True)
+        assert orig == str(s) == str(s2)
+
+        s4 = ClauseVisitor().traverse(s2, clone=True)
+        assert orig == str(s) == str(s2) == str(s4)
+
+        s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True)
+        assert orig == str(s) == str(s3)
+
+        s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True)
+        assert orig == str(s) == str(s3) == str(s4)
+        
     def test_correlated_select(self):
         s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
         class Vis(ClauseVisitor):