]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moved "clone" conditional blocks into separate copy_internals() method; was a
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Jun 2007 03:14:01 +0000 (03:14 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Jun 2007 03:14:01 +0000 (03:14 +0000)
conflation of tasks having them in the same place like that.

lib/sqlalchemy/sql.py
test/sql/generative.py

index c52f5cdc4dc713748e648feb8f4870b35f4b9ad4..a7b88d158fea5adc33d8a22fed21f0ed8111e4c2 100644 (file)
@@ -898,8 +898,10 @@ class ClauseVisitor(object):
                     meth = getattr(v, "enter_%s" % obj.__visit_name__, None)
                     if meth:
                         meth(obj)
-
-            for c in obj.get_children(clone=clone, **self.__traverse_options__):
+            
+            if clone:
+                obj.copy_internals()
+            for c in obj.get_children(**self.__traverse_options__):
                 _trav(c)
 
             for v in visitors:
@@ -979,16 +981,19 @@ class ClauseElement(object):
 
         return self is other
 
-    def get_children(self, clone=False, **kwargs):
+    def copy_internals(self):
+        """reassign internal elements to be clones of themselves.
+        
+        called during a copy-and-traverse operation on newly 
+        shallow-copied elements to create a deep copy."""
+        
+        pass
+        
+    def get_children(self, **kwargs):
         """return immediate child elements of this ``ClauseElement``.
         
         this is used for visit traversal.
         
-        clone indicates child items should be _cloned(), replacing
-        the elements contained by this element, and the cloned
-        copy returned.  this allows modifying traversals
-        to take place.
-        
         \**kwargs may contain flags that change the collection
         that is returned, for example to return a subset of items
         in order to cut down on larger traversals, or to return 
@@ -1755,10 +1760,10 @@ class _TextClause(ClauseElement):
 
     columns = property(lambda s:[])
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.bindparams = [b._clone() for b in self.bindparams]
-            
+    def copy_internals(self):
+        self.bindparams = [b._clone() for b in self.bindparams]
+
+    def get_children(self, **kwargs):
         return self.bindparams.values()
 
     def _get_from_objects(self, **modifiers):
@@ -1815,10 +1820,10 @@ class ClauseList(ClauseElement):
         else:
             self.clauses.append(_literal_as_text(clause))
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.clauses = [clause._clone() for clause in self.clauses]
-            
+    def copy_internals(self):
+        self.clauses = [clause._clone() for clause in self.clauses]
+
+    def get_children(self, **kwargs):
         return self.clauses
 
     def _get_from_objects(self, **modifiers):
@@ -1870,9 +1875,10 @@ class _CalculatedClause(ColumnElement):
             
     key = property(lambda self:self.name or "_calc_")
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.clause_expr = self.clause_expr._clone()
+    def copy_internals(self):
+        self.clause_expr = self.clause_expr._clone()
+
+    def get_children(self, **kwargs):
         return self.clause_expr,
         
     def _get_from_objects(self, **modifiers):
@@ -1911,10 +1917,11 @@ class _Function(_CalculatedClause, FromClause):
 
     key = property(lambda self:self.name)
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self._clone_from_clause()
-        return _CalculatedClause.get_children(self, clone=clone, **kwargs)
+    def copy_internals(self):
+        self._clone_from_clause()
+
+    def get_children(self, **kwargs):
+        return _CalculatedClause.get_children(self, **kwargs)
         
     def append(self, clause):
         self.clauses.append(_literal_as_binds(clause, self.name))
@@ -1928,11 +1935,11 @@ class _Cast(ColumnElement):
         self.clause = clause
         self.typeclause = _TypeClause(self.type)
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.clause = self.clause._clone()
-            self.typeclause = self.typeclause._clone()
-            
+    def copy_internals(self):
+        self.clause = self.clause._clone()
+        self.typeclause = self.typeclause._clone()
+
+    def get_children(self, **kwargs):
         return self.clause, self.typeclause
 
     def _get_from_objects(self, **modifiers):
@@ -1960,9 +1967,10 @@ class _UnaryExpression(ColumnElement):
     def _get_from_objects(self, **modifiers):
         return self.element._get_from_objects(**modifiers)
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.element = self.element._clone()
+    def copy_internals(self):
+        self.element = self.element._clone()
+
+    def get_children(self, **kwargs):
         return self.element,
 
     def compare(self, other):
@@ -1994,11 +2002,11 @@ class _BinaryExpression(ColumnElement):
     def _get_from_objects(self, **modifiers):
         return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.left = self.left._clone()
-            self.right = self.right._clone()
-            
+    def copy_internals(self):
+        self.left = self.left._clone()
+        self.right = self.right._clone()
+
+    def get_children(self, **kwargs):
         return self.left, self.right
 
     def compare(self, other):
@@ -2078,14 +2086,15 @@ class Join(FromClause):
             self._foreign_keys.add(f)
         return column
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self._clone_from_clause()
-            self.left = self.left._clone()
-            self.right = self.right._clone()
-            self.onclause = self.onclause._clone()
-            self.__folded_equivalents = None
-            self._init_primary_key()
+    def copy_internals(self):
+        self._clone_from_clause()
+        self.left = self.left._clone()
+        self.right = self.right._clone()
+        self.onclause = self.onclause._clone()
+        self.__folded_equivalents = None
+        self._init_primary_key()
+
+    def get_children(self, **kwargs):
         return self.left, self.right, self.onclause
 
     def _match_primaries(self, primary, secondary):
@@ -2244,14 +2253,15 @@ class Alias(FromClause):
         #return self.selectable._exportable_columns()
         return self.selectable.columns
 
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self._clone_from_clause()
-            self.selectable = self.selectable._clone()
-            baseselectable = self.selectable
-            while isinstance(baseselectable, Alias):
-                baseselectable = baseselectable.selectable
-            self.original = baseselectable
+    def copy_internals(self):
+        self._clone_from_clause()
+        self.selectable = self.selectable._clone()
+        baseselectable = self.selectable
+        while isinstance(baseselectable, Alias):
+            baseselectable = baseselectable.selectable
+        self.original = baseselectable
+
+    def get_children(self, **kwargs):
         for c in self.c:
             yield c
         yield self.selectable
@@ -2273,9 +2283,10 @@ class _Grouping(ColumnElement):
     _label = property(lambda s: s.elem._label)
     orig_set = property(lambda s:s.elem.orig_set)
     
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.elem = self.elem._clone()
+    def copy_internals(self):
+        self.elem = self.elem._clone()
+
+    def get_children(self, **kwargs):
         return self.elem,
         
     def _hide_froms(self, **modifiers):
@@ -2309,9 +2320,10 @@ class _Label(ColumnElement):
     def _compare_self(self):
         return self.obj
     
-    def get_children(self, clone=False, **kwargs):
-        if clone:
-            self.obj = self.obj._clone()
+    def copy_internals(self):
+        self.obj = self.obj._clone()
+
+    def get_children(self, **kwargs):
         return self.obj,
 
     def _get_from_objects(self, **modifiers):
@@ -2637,15 +2649,15 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         col.orig_set = colset
         return col
 
-    def get_children(self, clone=False, column_collections=True, **kwargs):
-        if clone:
-            self._clone_from_clause()
-            self._col_map = {}
-            self.selects = [s._clone() for s in self.selects]
-            for attr in ('_order_by_clause', '_group_by_clause'):
-                if getattr(self, attr) is not None:
-                    setattr(self, attr, getattr(self, attr)._clone())
+    def copy_internals(self):
+        self._clone_from_clause()
+        self._col_map = {}
+        self.selects = [s._clone() for s in self.selects]
+        for attr in ('_order_by_clause', '_group_by_clause'):
+            if getattr(self, attr) is not None:
+                setattr(self, attr, getattr(self, attr)._clone())
 
+    def get_children(self, column_collections=True, **kwargs):
         return (column_collections and list(self.c) or []) + \
             [self._order_by_clause, self._group_by_clause] + list(self.selects)
             
@@ -2814,15 +2826,15 @@ class Select(_SelectBaseMixin, FromClause):
             
     inner_columns = property(_get_inner_columns)
     
-    def get_children(self, clone=False, column_collections=True, **kwargs):
-        if clone:
-            self._clone_from_clause()
-            self._raw_columns = [c._clone() for c in self._raw_columns]
-            self._recorrelate_froms([f._clone() for f in self._froms])
-            for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
-                if getattr(self, attr) is not None:
-                    setattr(self, attr, getattr(self, attr)._clone())
-        
+    def copy_internals(self):
+        self._clone_from_clause()
+        self._raw_columns = [c._clone() for c in self._raw_columns]
+        self._recorrelate_froms([f._clone() for f in self._froms])
+        for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
+            if getattr(self, attr) is not None:
+                setattr(self, attr, getattr(self, attr)._clone())
+
+    def get_children(self, column_collections=True, **kwargs):
         return (column_collections and list(self.columns) or []) + \
             list(self._froms) + \
             [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
index befa5f9221d0b85dc086edc901aec5526037a4a1..18fa02dcfe0d83137bea368edf49120d7f8d7530 100644 (file)
@@ -52,10 +52,11 @@ class TraversalTest(testbase.AssertMixin):
                     if i1 != i2:
                         return True
                 return False
-                
-            def get_children(self, clone=False, **kwargs):
-                if clone:
-                    self.items = [i._clone() for i in self.items]
+            
+            def copy_internals(self):    
+                self.items = [i._clone() for i in self.items]
+
+            def get_children(self, **kwargs):
                 return self.items
             
             def accept_visitor(self, visitor):