]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
removed AbstractClauseProcessor, merged its copy-and-visit behavior into ClauseVisitor
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Mar 2008 23:55:21 +0000 (23:55 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Mar 2008 23:55:21 +0000 (23:55 +0000)
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/sql/generative.py

index 9954811d644afe13687eae5a1712746a62dc64ff..d4163b73b62bfc370baa1ac7249492ebf60a792b 100644 (file)
@@ -147,102 +147,7 @@ class ColumnsInClause(visitors.ClauseVisitor):
         if self.selectable.c.get(column.key) is column:
             self.result = True
 
-class AbstractClauseProcessor(object):
-    """Traverse and copy a ClauseElement, replacing selected elements based on rules.
-
-    This class implements its own visit-and-copy strategy but maintains the
-    same public interface as visitors.ClauseVisitor.
-    
-    The convert_element() method receives the *un-copied* version of each element.
-    It can return a new element or None for no change.  If None, the element
-    will be cloned afterwards and added to the new structure.  Note this is the
-    opposite behavior of visitors.traverse(clone=True), where visitors receive
-    the cloned element so that it can be mutated.
-    """
-
-    __traverse_options__ = {'column_collections':False}
-
-    def __init__(self, stop_on=None):
-        self.stop_on = stop_on
-
-    def convert_element(self, elem):
-        """Define the *conversion* method for this ``AbstractClauseProcessor``."""
-
-        raise NotImplementedError()
-
-    def chain(self, visitor):
-        # chaining AbstractClauseProcessor and other ClauseVisitor
-        # objects separately.  All the ACP objects are chained on
-        # their convert_element() method whereas regular visitors
-        # chain on their visit_XXX methods.
-        if isinstance(visitor, AbstractClauseProcessor):
-            attr = '_next_acp'
-        else:
-            attr = '_next'
-
-        tail = self
-        while getattr(tail, attr, None) is not None:
-            tail = getattr(tail, attr)
-        setattr(tail, attr, visitor)
-        return self
-
-    def copy_and_process(self, list_):
-        """Copy the given list to a new list, with each element traversed individually."""
-
-        list_ = list(list_)
-        stop_on = util.Set(self.stop_on or [])
-        cloned = {}
-        for i in range(0, len(list_)):
-            list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True)
-        return list_
-
-    def _convert_element(self, elem, stop_on, cloned):
-        v = self
-        while v is not None:
-            newelem = v.convert_element(elem)
-            if newelem:
-                stop_on.add(newelem)
-                return newelem
-            v = getattr(v, '_next_acp', None)
-
-        if elem not in cloned:
-            # the full traversal will only make a clone of a particular element
-            # once.
-            cloned[elem] = elem._clone()
-        return cloned[elem]
-
-    def traverse(self, elem, clone=True):
-        if not clone:
-            raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
-
-        return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True)
-
-    def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
-        if elem in stop_on:
-            return elem
-
-        if _clone_toplevel:
-            elem = self._convert_element(elem, stop_on, cloned)
-            if elem in stop_on:
-                return elem
-
-        def clone(element):
-            return self._convert_element(element, stop_on, cloned)
-        elem._copy_internals(clone=clone)
-
-        v = getattr(self, '_next', None)
-        while v is not None:
-            meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
-            if meth:
-                meth(elem)
-            v = getattr(v, '_next', None)
-
-        for e in elem.get_children(**self.__traverse_options__):
-            if e not in stop_on:
-                self._traverse(e, stop_on, cloned)
-        return elem
-
-class ClauseAdapter(AbstractClauseProcessor):
+class ClauseAdapter(visitors.ClauseVisitor):
     """Given a clause (like as in a WHERE criterion), locate columns
     which are embedded within a given selectable, and changes those
     columns to be that of the selectable.
@@ -270,13 +175,21 @@ class ClauseAdapter(AbstractClauseProcessor):
       s.c.col1 == table2.c.col1
     """
 
+    __traverse_options__ = {'column_collections':False}
+
     def __init__(self, selectable, include=None, exclude=None, equivalents=None):
-        AbstractClauseProcessor.__init__(self, [selectable])
+        self.__traverse_options__ = self.__traverse_options__.copy()
+        self.__traverse_options__['stop_on'] = [selectable]
         self.selectable = selectable
         self.include = include
         self.exclude = exclude
         self.equivalents = equivalents
-
+    
+    def traverse(self, obj, clone=True):
+        if not clone:
+            raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
+        return visitors.ClauseVisitor.traverse(self, obj, clone=True)
+        
     def copy_and_chain(self, adapter):
         """create a copy of this adapter and chain to the given adapter.
 
@@ -289,14 +202,14 @@ class ClauseAdapter(AbstractClauseProcessor):
         if adapter is None:
             return self
 
-        if hasattr(self, '_next_acp') or hasattr(self, '_next'):
+        if hasattr(self, '_next'):
             raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
 
         ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
-        ca._next_acp = adapter
+        ca._next = adapter
         return ca
 
-    def convert_element(self, col):
+    def before_clone(self, col):
         if isinstance(col, expression.FromClause):
             if self.selectable.is_derived_from(col):
                 return self.selectable
index bb63ab09c92d547e74c4e5b277782e25b860ad88..57dfb4b96db1b61b65ed04251fffec5fb2a03c7b 100644 (file)
@@ -1,7 +1,9 @@
+from sqlalchemy import util
+
 class ClauseVisitor(object):
     """Traverses and visits ``ClauseElement`` structures.
     
-    Calls visit_XXX() methods dynamically generated for each particular
+    Calls visit_XXX() methods for each particular
     ``ClauseElement`` subclass encountered.  Traversal of a
     hierarchy of ``ClauseElements`` is achieved via the
     ``traverse()`` method, which is passed the lead
@@ -25,19 +27,18 @@ class ClauseVisitor(object):
     __traverse_options__ = {}
     
     def traverse_single(self, obj, **kwargs):
-        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
-        if meth:
-            return meth(obj, **kwargs)
-
-    def traverse_chained(self, obj, **kwargs):
-        v = self
-        while v is not None:
-            meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+        """visit a single element, without traversing its child elements."""
+        
+        for v in self._iterate_visitors:
+            meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
             if meth:
-                meth(obj, **kwargs)
-            v = getattr(v, '_next', None)
+                return meth(obj, **kwargs)
+    
+    traverse_chained = traverse_single
         
     def iterate(self, obj):
+        """traverse the given expression structure, and return an iterator of all elements."""
+        
         stack = [obj]
         traversal = []
         while len(stack) > 0:
@@ -48,39 +49,118 @@ class ClauseVisitor(object):
                 stack.append(c)
     
     def traverse(self, obj, clone=False):
+        """traverse the given expression structure.
+        
+        Returns the structure given, or a copy of the structure if
+        clone=True.
         
+        When the copy operation takes place, the before_clone() method
+        will receive each element before it is copied.  If the method
+        returns a non-None value, the return value is taken as the 
+        "copied" element and traversal will not descend further.  
+        
+        The visit_XXX() methods receive the element *after* it's been
+        copied.  To compare an element to another regardless of
+        one element being a cloned copy of the original, the 
+        '_cloned_set' attribute of ClauseElement can be used for the compare, 
+        i.e.::
+        
+            original in copied._cloned_set
+            
+        
+        """
         if clone:
-            cloned = {}
-            def do_clone(obj):
-                # the full traversal will only make a clone of a particular element
-                # once.
-                if obj not in cloned:
-                    cloned[obj] = obj._clone()
-                return cloned[obj]
+            return self._cloned_traversal(obj)
+        else:
+            return self._non_cloned_traversal(obj)
+
+    def copy_and_process(self, list_):
+        """Apply cloned traversal to the given list of elements, and return the new list."""
+
+        return [self._cloned_traversal(x) for x in list_]
+
+    def before_clone(self, elem):
+        """receive pre-copied elements during a cloning traversal.
+        
+        If the method returns a new element, the element is used 
+        instead of creating a simple copy of the element.  Traversal 
+        will halt on the newly returned element if it is re-encountered.
+        """
+        return None
+    
+    def _clone_element(self, elem, stop_on, cloned):
+        for v in self._iterate_visitors:
+            newelem = v.before_clone(elem)
+            if newelem:
+                stop_on.add(newelem)
+                return newelem
+
+        if elem not in cloned:
+            # the full traversal will only make a clone of a particular element
+            # once.
+            cloned[elem] = elem._clone()
+        return cloned[elem]
             
-            obj = do_clone(obj)
+    def _cloned_traversal(self, obj):
+        """a recursive traversal which creates copies of elements, returning the new structure."""
+        
+        stop_on = self.__traverse_options__.get('stop_on', [])
+        return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True)
+        
+    def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False):
+        if elem in stop_on:
+            return elem
+
+        if _clone_toplevel:
+            elem = self._clone_element(elem, stop_on, cloned)
+            if elem in stop_on:
+                return elem
+
+        def clone(element):
+            return self._clone_element(element, stop_on, cloned)
+        elem._copy_internals(clone=clone)
+
+        for v in self._iterate_visitors:
+            meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
+            if meth:
+                meth(elem)
+
+        for e in elem.get_children(**self.__traverse_options__):
+            if e not in stop_on:
+                self._cloned_traversal_impl(e, stop_on, cloned)
+        return elem
+        
+    def _non_cloned_traversal(self, obj):
+        """a non-recursive, non-cloning traversal."""
+        
         stack = [obj]
         traversal = []
         while len(stack) > 0:
             t = stack.pop()
             traversal.insert(0, t)
-            if clone:
-                t._copy_internals(clone=do_clone)
             for c in t.get_children(**self.__traverse_options__):
                 stack.append(c)
         for target in traversal:
-            v = self
-            while v is not None:
+            for v in self._iterate_visitors:
                 meth = getattr(v, "visit_%s" % target.__visit_name__, None)
                 if meth:
                     meth(target)
-                v = getattr(v, '_next', None)
         return obj
 
+    def _iterate_visitors(self):
+        """iterate through this visitor and each 'chained' visitor."""
+        
+        v = self
+        while v is not None:
+            yield v
+            v = getattr(v, '_next', None)
+    _iterate_visitors = property(_iterate_visitors)
+
     def chain(self, visitor):
         """'chain' an additional ClauseVisitor onto this ClauseVisitor.
         
-        the chained visitor will receive all visit events after this one."""
+        the chained visitor will receive all visit events after this one.
+        """
         tail = self
         while getattr(tail, '_next', None) is not None:
             tail = tail._next
@@ -96,14 +176,16 @@ class NoColumnVisitor(ClauseVisitor):
     
     __traverse_options__ = {'column_collections':False}
 
+
 def traverse(clause, **kwargs):
+    """traverse the given clause, applying visit functions passed in as keyword arguments."""
+    
     clone = kwargs.pop('clone', False)
     class Vis(ClauseVisitor):
         __traverse_options__ = kwargs.pop('traverse_options', {})
-        def __getattr__(self, key):
-            if key in kwargs:
-                return kwargs[key]
-            else:
-                return None
-    return Vis().traverse(clause, clone=clone)
+    vis = Vis()
+    for key in kwargs:
+        if key.startswith('visit_'):
+            setattr(vis, key, kwargs[key])
+    return vis.traverse(clause, clone=clone)
 
index 3d7c88972e8dd694469e7cf7c7003db9222bf2fd..0994491d98fceac9750db6bdb968ea690b8f8666 100644 (file)
@@ -92,7 +92,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
         s2 = vis.traverse(struct, clone=True)
         assert struct == s2
         assert not struct.is_other(s2)
-
+    
     def test_no_clone(self):
         struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
 
@@ -430,7 +430,38 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
             "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\
             "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo  LIMIT 5 OFFSET 10) AS anon_1 "\
             "LEFT OUTER JOIN table1 AS bar ON anon_1.col1 = bar.col1")
+    
+    def test_recursive(self):
+        metadata = MetaData()
+        a = Table('a', metadata,
+            Column('id', Integer, primary_key=True))
+        b = Table('b', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('aid', Integer, ForeignKey('a.id')),
+            )
+        c = Table('c', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('bid', Integer, ForeignKey('b.id')),
+            )
+
+        d = Table('d', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('aid', Integer, ForeignKey('a.id')),
+            )
 
+        u = union(
+            a.join(b).select().apply_labels(),
+            a.join(d).select().apply_labels()
+        ).alias()    
+        
+        self.assert_compile(
+            sql_util.ClauseAdapter(u).traverse(select([c.c.bid]).where(c.c.bid==u.c.b_aid)),
+            "SELECT c.bid "\
+            "FROM c, (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid "\
+            "FROM a JOIN b ON a.id = b.aid UNION SELECT a.id AS a_id, d.id AS d_id, d.aid AS d_aid "\
+            "FROM a JOIN d ON a.id = d.aid) AS anon_1 "\
+            "WHERE c.bid = anon_1.b_aid"
+        )
 
 class SelectTest(TestBase, AssertsCompiledSQL):
     """tests the generative capability of Select"""