]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rewritten ClauseAdapter merged from the eager_minus_join branch; this is a much...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Nov 2007 22:13:17 +0000 (22:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Nov 2007 22:13:17 +0000 (22:13 +0000)
and "correct" version which will copy all elements exactly once, except for those which were
replaced with target elements.  It also can match a wider variety of target elements including
joins and selects on identity alone.

lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/profiling/compiler.py
test/profiling/zoomark.py
test/sql/generative.py
test/sql/selectable.py

index f4f8aa689f9e47ccef90fa931e7228e62b456761..0197900a581cf7afaf79573d017cfd1972bc49df 100644 (file)
@@ -644,7 +644,6 @@ class OracleCompiler(compiler.DefaultCompiler):
             orderby = self.process(select._order_by_clause)
             if not orderby:
                 orderby = select.oid_column
-                self.traverse(orderby)
                 orderby = self.process(orderby)
                 
             oldselect = select
index 660d546047b2d643c2e867ec3839ebe955562aeb..bd82f897d943e40c4f0d885bd31724fd7856cebd 100644 (file)
@@ -891,7 +891,7 @@ class Connection(Connectable):
     executors = {
         expression._Function : _execute_function,
         expression.ClauseElement : _execute_clauseelement,
-        visitors.ClauseVisitor : _execute_compiled,
+        Compiled : _execute_compiled,
         schema.SchemaItem:_execute_default,
         str.__mro__[-2] : _execute_text
     }
index 09a3a0f5b71dd9889b68a7160c52c000487a7e64..b6200fee52cecb142e4d1564845e34d409ff0ecf 100644 (file)
@@ -800,7 +800,7 @@ class Query(object):
             # adapt the given WHERECLAUSE to adjust instances of this query's mapped 
             # table to be that of our select_table,
             # which may be the "polymorphic" selectable used by our mapper.
-            sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table]))
+            whereclause = sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table]))
 
             # if extra entities, adapt the criterion to those as well
             for m in self._entities:
index fa4ac5a9f4a7c1e42f5864adda2f341efd00018f..ef66ffd5a64be0320b199b3562f09a376513aa2d 100644 (file)
@@ -87,10 +87,11 @@ OPERATORS =  {
     operators.isnot : 'IS NOT'
 }
 
-class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
+class DefaultCompiler(engine.Compiled):
     """Default implementation of Compiled.
 
-    Compiles ClauseElements into SQL strings.
+    Compiles ClauseElements into SQL strings.   Uses a similar visit
+    paradigm as visitors.ClauseVisitor but implements its own traversal.
     """
 
     __traverse_options__ = {'column_collections':False, 'entry':True}
@@ -163,7 +164,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
         if stack:
             self.stack.append(stack)
         try:
-            return self.traverse_single(obj, **kwargs)
+            meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+            if meth:
+                return meth(obj, **kwargs)
         finally:
             if stack:
                 self.stack.pop(-1)
index 49dbc143a8f35a3334291f229aa53e16c9b3bc98..67c1b727a991aa99168d7321a43d3aaf508ea4a5 100644 (file)
@@ -775,7 +775,9 @@ func = _FunctionGenerator()
 # TODO: use UnaryExpression for this instead ?
 modifier = _FunctionGenerator(group=False)
 
-
+def _clone(element):
+    return element._clone()
+    
 def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
 
@@ -908,7 +910,7 @@ class ClauseElement(object):
 
         return self is other
 
-    def _copy_internals(self):
+    def _copy_internals(self, clone=_clone):
         """Reassign internal elements to be clones of themselves.
 
         Called during a copy-and-traverse operation on newly
@@ -1580,8 +1582,7 @@ class FromClause(Selectable):
 
         An example would be an Alias of a Table is derived from that Table.
         """
-
-        return False
+        return fromclause is self
 
     def replace_selectable(self, old, alias):
       """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
@@ -1874,8 +1875,8 @@ class _TextClause(ClauseElement):
 
     columns = property(lambda s:[])
 
-    def _copy_internals(self):
-        self.bindparams = dict([(b.key, b._clone()) for b in self.bindparams.values()])
+    def _copy_internals(self, clone=_clone):
+        self.bindparams = dict([(b.key, clone(b)) for b in self.bindparams.values()])
 
     def get_children(self, **kwargs):
         return self.bindparams.values()
@@ -1933,8 +1934,8 @@ class ClauseList(ClauseElement):
         else:
             self.clauses.append(_literal_as_text(clause))
 
-    def _copy_internals(self):
-        self.clauses = [clause._clone() for clause in self.clauses]
+    def _copy_internals(self, clone=_clone):
+        self.clauses = [clone(clause) for clause in self.clauses]
 
     def get_children(self, **kwargs):
         return self.clauses
@@ -1989,8 +1990,8 @@ class _CalculatedClause(ColumnElement):
 
     key = property(lambda self:self.name or "_calc_")
 
-    def _copy_internals(self):
-        self.clause_expr = self.clause_expr._clone()
+    def _copy_internals(self, clone=_clone):
+        self.clause_expr = clone(self.clause_expr)
 
     def clauses(self):
         if isinstance(self.clause_expr, _Grouping):
@@ -2038,8 +2039,8 @@ class _Function(_CalculatedClause, FromClause):
     key = property(lambda self:self.name)
     columns = property(lambda self:[self])
 
-    def _copy_internals(self):
-        _CalculatedClause._copy_internals(self)
+    def _copy_internals(self, clone=_clone):
+        _CalculatedClause._copy_internals(self, clone=clone)
         self._clone_from_clause()
 
     def get_children(self, **kwargs):
@@ -2059,9 +2060,9 @@ class _Cast(ColumnElement):
         self.typeclause = _TypeClause(self.type)
         self._distance = 0
 
-    def _copy_internals(self):
-        self.clause = self.clause._clone()
-        self.typeclause = self.typeclause._clone()
+    def _copy_internals(self, clone=_clone):
+        self.clause = clone(self.clause)
+        self.typeclause = clone(self.typeclause)
 
     def get_children(self, **kwargs):
         return self.clause, self.typeclause
@@ -2092,8 +2093,8 @@ class _UnaryExpression(ColumnElement):
     def _get_from_objects(self, **modifiers):
         return self.element._get_from_objects(**modifiers)
 
-    def _copy_internals(self):
-        self.element = self.element._clone()
+    def _copy_internals(self, clone=_clone):
+        self.element = clone(self.element)
 
     def get_children(self, **kwargs):
         return self.element,
@@ -2134,9 +2135,9 @@ class _BinaryExpression(ColumnElement):
     def _get_from_objects(self, **modifiers):
         return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
 
-    def _copy_internals(self):
-        self.left = self.left._clone()
-        self.right = self.right._clone()
+    def _copy_internals(self, clone=_clone):
+        self.left = clone(self.left)
+        self.right = clone(self.right)
 
     def get_children(self, **kwargs):
         return self.left, self.right
@@ -2265,11 +2266,11 @@ class Join(FromClause):
             self._foreign_keys.add(f)
         return column
 
-    def _copy_internals(self):
+    def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
-        self.left = self.left._clone()
-        self.right = self.right._clone()
-        self.onclause = self.onclause._clone()
+        self.left = clone(self.left)
+        self.right = clone(self.right)
+        self.onclause = clone(self.onclause)
         self.__folded_equivalents = None
         self._init_primary_key()
 
@@ -2414,15 +2415,7 @@ class Alias(FromClause):
             self.oid_column = None
 
     def is_derived_from(self, fromclause):
-        x = self.selectable
-        while True:
-            if x is fromclause:
-                return True
-            if isinstance(x, Alias):
-                x = x.selectable
-            else:
-                break
-        return False
+        return self.selectable.is_derived_from(fromclause)
 
     def supports_execution(self):
         return self.original.supports_execution()
@@ -2437,13 +2430,12 @@ class Alias(FromClause):
         #return self.selectable._exportable_columns()
         return self.selectable.columns
 
-    def _copy_internals(self):
-        self._clone_from_clause()
-        self.selectable = self.selectable._clone()
-        baseselectable = self.selectable
-        while isinstance(baseselectable, Alias):
-            baseselectable = baseselectable.selectable
-        self.original = baseselectable
+    def _clone(self):
+        # Alias is immutable
+        return self
+
+    def _copy_internals(self, clone=_clone):
+        pass
 
     def get_children(self, **kwargs):
         for c in self.c:
@@ -2469,8 +2461,8 @@ class _ColumnElementAdapter(ColumnElement):
     key = property(lambda s: s.elem.key)
     _label = property(lambda s: s.elem._label)
 
-    def _copy_internals(self):
-        self.elem = self.elem._clone()
+    def _copy_internals(self, clone=_clone):
+        self.elem = clone(self.elem)
 
     def get_children(self, **kwargs):
         return self.elem,
@@ -2503,8 +2495,8 @@ class _FromGrouping(FromClause):
     def _hide_froms(self, **modifiers):
         return self.elem._hide_froms(**modifiers)
 
-    def _copy_internals(self):
-        self.elem = self.elem._clone()
+    def _copy_internals(self, clone=_clone):
+        self.elem = clone(self.elem)
 
     def _get_from_objects(self, **modifiers):
         return self.elem._get_from_objects(**modifiers)
@@ -2538,8 +2530,8 @@ class _Label(ColumnElement):
     def expression_element(self):
         return self.obj
 
-    def _copy_internals(self):
-        self.obj = self.obj._clone()
+    def _copy_internals(self, clone=_clone):
+        self.obj = clone(self.obj)
 
     def get_children(self, **kwargs):
         return self.obj,
@@ -2935,13 +2927,13 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         col.orig_set = colset
         return col
 
-    def _copy_internals(self):
+    def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
         self._col_map = {}
-        self.selects = [s._clone() for s in self.selects]
+        self.selects = [clone(s) for s in self.selects]
         for attr in ('_order_by_clause', '_group_by_clause'):
             if getattr(self, attr) is not None:
-                setattr(self, attr, getattr(self, attr)._clone())
+                setattr(self, attr, clone(getattr(self, attr)))
 
     def get_children(self, column_collections=True, **kwargs):
         return (column_collections and list(self.c) or []) + \
@@ -3091,13 +3083,19 @@ class Select(_SelectBaseMixin, FromClause):
 
     inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""")
 
-    def _copy_internals(self):
+    def is_derived_from(self, fromclause):
+        for f in self.locate_all_froms():
+            if f.is_derived_from(fromclause):
+                return True
+        return False
+
+    def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
-        self._raw_columns = [c._clone() for c in self._raw_columns]
-        self._recorrelate_froms([(f, f._clone()) for f in self._froms])
+        self._raw_columns = [clone(c) for c in self._raw_columns]
+        self._recorrelate_froms([(f, clone(f)) for f in self._froms])
         for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
             if getattr(self, attr) is not None:
-                setattr(self, attr, getattr(self, attr)._clone())
+                setattr(self, attr, clone(getattr(self, attr)))
 
     def get_children(self, column_collections=True, **kwargs):
         """return child elements as per the ClauseElement specification."""
@@ -3394,7 +3392,7 @@ class Insert(_UpdateBase):
         else:
             return ()
 
-    def _copy_internals(self):
+    def _copy_internals(self, clone=_clone):
         self.parameters = self.parameters.copy()
 
     def values(self, v):
@@ -3423,8 +3421,8 @@ class Update(_UpdateBase):
         else:
             return ()
 
-    def _copy_internals(self):
-        self._whereclause = self._whereclause._clone()
+    def _copy_internals(self, clone=_clone):
+        self._whereclause = clone(self._whereclause)
         self.parameters = self.parameters.copy()
 
     def values(self, v):
@@ -3449,8 +3447,8 @@ class Delete(_UpdateBase):
         else:
             return ()
 
-    def _copy_internals(self):
-        self._whereclause = self._whereclause._clone()
+    def _copy_internals(self, clone=_clone):
+        self._whereclause = clone(self._whereclause)
 
 class _IdentifiedClause(ClauseElement):
     def __init__(self, ident):
index 8876f42baa61592120b438d9e0fca5c0c09e9dc8..eed06cfc32b7bba60247af6b095b1c5302c92934 100644 (file)
@@ -3,7 +3,6 @@ from sqlalchemy.sql import expression, visitors
 
 """Utility functions that build upon SQL and Schema constructs."""
 
-
 class TableCollection(object):
     def __init__(self, tables=None):
         self.tables = tables or []
@@ -110,87 +109,86 @@ class ColumnsInClause(visitors.ClauseVisitor):
         if self.selectable.c.get(column.key) is column:
             self.result = True
 
-class AbstractClauseProcessor(visitors.NoColumnVisitor):
-    """Traverse a clause and attempt to convert the contents of container elements
-    to a converted element.
-
-    The conversion operation is defined by subclasses.
+class AbstractClauseProcessor(object):
+    """Traverse and copy a ClauseElement, replacing selected elements based on rules.
+    
+    This class implements its own visit-and-copy strategy but maintains the
+    same public interface as visitors.ClauseVisitor.
     """
-
+    
+    __traverse_options__ = {'column_collections':False}
+    
     def convert_element(self, elem):
         """Define the *conversion* method for this ``AbstractClauseProcessor``."""
 
         raise NotImplementedError()
 
-    def copy_and_process(self, list_):
-        """Copy the container elements in the given list to a new list and
-        process the new list.
-        """
-
+    def chain(self, visitor):
+        # chaining AbstractClauseProcessor and other ClauseVisitor
+        # objects separately.  All the ACP objects are chained on 
+        # their convert_element() method whereas regular visitors
+        # chain on their visit_XXX methods.
+        if isinstance(visitor, AbstractClauseProcessor):
+            attr = '_next_acp'
+        else:
+            attr = '_next'
+        
+        tail = self
+        while getattr(tail, attr, None) is not None:
+            tail = getattr(tail, attr)
+        setattr(tail, attr, visitor)
+        return self
+
+    def copy_and_process(self, list_, stop_on=None):
+        """Copy the given list to a new list, with each element traversed individually."""
+        
         list_ = list(list_)
-        self.process_list(list_)
+        stop_on = util.Set()
+        for i in range(0, len(list_)):
+            list_[i] = self.traverse(list_[i], stop_on=stop_on)
         return list_
 
-    def process_list(self, list_):
-        """Process all elements of the given list in-place."""
-
-        for i in range(0, len(list_)):
-            elem = self.convert_element(list_[i])
-            if elem is not None:
-                list_[i] = elem
-            else:
-                list_[i] = self.traverse(list_[i], clone=True)
-    
-    def visit_grouping(self, grouping):
-        elem = self.convert_element(grouping.elem)
-        if elem is not None:
-            grouping.elem = elem
+    def _convert_element(self, elem, stop_on):
+        v = self
+        while v is not None:
+            newelem = v.convert_element(elem)
+            if newelem:
+                stop_on.add(newelem)
+                return newelem
+            v = getattr(v, '_next_acp', None)
+        return elem._clone()
+        
+    def traverse(self, elem, clone=True, stop_on=None, _clone_toplevel=True):
+        if not clone:
+            raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
             
-    def visit_clauselist(self, clist):
-        for i in range(0, len(clist.clauses)):
-            n = self.convert_element(clist.clauses[i])
-            if n is not None:
-                clist.clauses[i] = n
-    
-    def visit_unary(self, unary):
-        elem = self.convert_element(unary.element)
-        if elem is not None:
-            unary.element = elem
+        if stop_on is None:
+            stop_on = util.Set()
             
-    def visit_binary(self, binary):
-        elem = self.convert_element(binary.left)
-        if elem is not None:
-            binary.left = elem
-        elem = self.convert_element(binary.right)
-        if elem is not None:
-            binary.right = elem
-    
-    def visit_join(self, join):
-        elem = self.convert_element(join.left)
-        if elem is not None:
-            join.left = elem
-        elem = self.convert_element(join.right)
-        if elem is not None:
-            join.right = elem
-        join._init_primary_key()
+        if elem in stop_on:
+            return elem
+        
+        if _clone_toplevel:
+            elem = self._convert_element(elem, stop_on)
+            if elem in stop_on:
+                return elem
             
-    def visit_select(self, select):
-        fr = util.OrderedSet()
-        for elem in select._froms:
-            n = self.convert_element(elem)
-            if n is not None:
-                fr.add((elem, n))
-        select._recorrelate_froms(fr)
-
-        col = []
-        for elem in select._raw_columns:
-            n = self.convert_element(elem)
-            if n is None:
-                col.append(elem)
-            else:
-                col.append(n)
-        select._raw_columns = col
-    
+        def clone(element):
+            return self._convert_element(element, stop_on)
+        elem._copy_internals(clone=clone)
+        
+        v = getattr(self, '_next', None)
+        while v is not None:
+            meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
+            if meth:
+                meth(elem)
+            v = getattr(v, '_next', None)
+        
+        for e in elem.get_children(**self.__traverse_options__):
+            if e not in stop_on:
+                self.traverse(e, stop_on=stop_on, _clone_toplevel=False)
+        return elem
+        
 class ClauseAdapter(AbstractClauseProcessor):
     """Given a clause (like as in a WHERE criterion), locate columns
     which are embedded within a given selectable, and changes those
@@ -243,9 +241,6 @@ class ClauseAdapter(AbstractClauseProcessor):
                 newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False)
                 if newcol:
                     return newcol
-        #if newcol is None:
-        #    self.traverse(col)
-        #    return col
         return newcol
 
 
index 98e4de6c3341ebcb4aff969598241702dd28940d..bf15c2b7eec69e2ecd6c985a9a12b1f0dae06232 100644 (file)
@@ -1,8 +1,7 @@
 class ClauseVisitor(object):
-    """A class that knows how to traverse and visit
-    ``ClauseElements``.
+    """Traverses and visits ``ClauseElement`` structures.
     
-    Calls visit_XXX() methods dynamically generated for each particualr
+    Calls visit_XXX() methods dynamically generated for each particular
     ``ClauseElement`` subclass encountered.  Traversal of a
     hierarchy of ``ClauseElements`` is achieved via the
     ``traverse()`` method, which is passed the lead
@@ -40,7 +39,7 @@ class ClauseVisitor(object):
                 traversal.insert(0, t)
                 for c in t.get_children(**self.__traverse_options__):
                     stack.append(c)
-        
+
     def traverse(self, obj, stop_on=None, clone=False):
         if clone:
             obj = obj._clone()
@@ -75,13 +74,10 @@ class ClauseVisitor(object):
         return self
 
 class NoColumnVisitor(ClauseVisitor):
-    """a ClauseVisitor that will not traverse the exported Column 
-    collections on Table, Alias, Select, and CompoundSelect objects
-    (i.e. their 'columns' or 'c' attribute).
+    """ClauseVisitor with 'column_collections' set to False; will not
+    traverse the front-facing Column collections on Table, Alias, Select, 
+    and CompoundSelect objects.
     
-    this is useful because most traversals don't need those columns, or
-    in the case of DefaultCompiler it traverses them explicitly; so
-    skipping their traversal here greatly cuts down on method call overhead.
     """
     
     __traverse_options__ = {'column_collections':False}
index 470b45cb8924f25a5eb99de40b20ce17c7e6d343..29e17db778cbc566343f6862d6509d87dc22afde 100644 (file)
@@ -24,7 +24,7 @@ class CompileTest(AssertMixin):
         t1.update().compile()
 
     # TODO: this is alittle high
-    @profiling.profiled('ctest_select', call_range=(190, 210), always=True)        
+    @profiling.profiled('ctest_select', call_range=(170, 200), always=True)        
     def test_select(self):
         s = select([t1], t1.c.c2==t2.c.c1)
         s.compile()
index 6eb313b429c199d5c7910f2482ccd63f7253858a..833aec46c6c26bc3c93adc7378cd8d42d21137a6 100644 (file)
@@ -126,7 +126,7 @@ class ZooMarkTest(testing.AssertMixin):
             tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8)
     
     @testing.supported('postgres')
-    @profiling.profiled('properties', call_range=(3030, 3430), always=True)
+    @profiling.profiled('properties', call_range=(2900, 3330), always=True)
     def test_3_properties(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -149,7 +149,7 @@ class ZooMarkTest(testing.AssertMixin):
             ticks = fullobject(Animal.select(Animal.c.Species=='Tick'))
     
     @testing.supported('postgres')
-    @profiling.profiled('expressions', call_range=(11350, 13200), always=True)
+    @profiling.profiled('expressions', call_range=(10350, 12200), always=True)
     def test_4_expressions(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
@@ -203,7 +203,7 @@ class ZooMarkTest(testing.AssertMixin):
             assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1
     
     @testing.supported('postgres')
-    @profiling.profiled('aggregates', call_range=(1000, 1270), always=True)
+    @profiling.profiled('aggregates', call_range=(960, 1170), always=True)
     def test_5_aggregates(self):
         Animal = metadata.tables['Animal']
         Zoo = metadata.tables['Zoo']
@@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin):
             legs.sort()
     
     @testing.supported('postgres')
-    @profiling.profiled('editing', call_range=(1280, 1390), always=True)
+    @profiling.profiled('editing', call_range=(1200, 1290), always=True)
     def test_6_editing(self):
         Zoo = metadata.tables['Zoo']
         
@@ -274,7 +274,7 @@ class ZooMarkTest(testing.AssertMixin):
             assert SDZ['Founded'] == datetime.date(1935, 9, 13)
     
     @testing.supported('postgres')
-    @profiling.profiled('multiview', call_range=(2820, 3155), always=True)
+    @profiling.profiled('multiview', call_range=(2720, 3055), always=True)
     def test_7_multiview(self):
         Zoo = metadata.tables['Zoo']
         Animal = metadata.tables['Animal']
index 437f0874ef875b05f6c4a5237ef224ac00b05bbf..7892732d6a4dd4c81fa363f91e3497355c359fbc 100644 (file)
@@ -4,6 +4,8 @@ from sqlalchemy.sql import table, column, ClauseElement
 from testlib import *
 from sqlalchemy.sql.visitors import *
 from sqlalchemy import util
+from sqlalchemy.sql import util as sql_util
+
 
 class TraversalTest(AssertMixin):
     """test ClauseVisitor's traversal, particularly its ability to copy and modify
@@ -133,7 +135,8 @@ class TraversalTest(AssertMixin):
         s3 = vis2.traverse(struct, clone=True)
         assert struct != s3
         assert struct3 == s3
-
+    
+        
 class ClauseTest(SQLCompileTest):
     """test copy-in-place behavior of various ClauseElements."""
     
@@ -230,7 +233,6 @@ class ClauseTest(SQLCompileTest):
         self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
 
     def test_clause_adapter(self):
-        from sqlalchemy.sql import util as sql_util
         
         t1alias = t1.alias('t1alias')
         
@@ -257,7 +259,47 @@ class ClauseTest(SQLCompileTest):
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+
+    def test_joins(self):
+        """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after
+        replacing."""
+        
+        metadata = MetaData()
+        a = Table('a', metadata,
+            Column('id', Integer, primary_key=True))
+        b = Table('b', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('aid', Integer, ForeignKey('a.id')),
+            )
+        c = Table('c', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('bid', Integer, ForeignKey('b.id')),
+            )
+
+        d = Table('d', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('aid', Integer, ForeignKey('a.id')),
+            )
+
+        j1 = a.outerjoin(b)
+        j2 = select([j1], use_labels=True)
+
+        j3 = c.join(j2, j2.c.b_id==c.c.bid)
+
+        j4 = j3.outerjoin(d)
+        self.assert_compile(j4,  "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) "
+                                 "ON b_id = c.bid"
+                                 " LEFT OUTER JOIN d ON a_id = d.aid")
+        j5 = j3.alias('foo')
+        j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0]
         
+        # this statement takes c join(a join b), wraps it inside an aliased "select * from c join(a join b) AS foo".
+        # the outermost right side "left outer join d" stays the same, except "d" joins against foo.a_id instead
+        # of plain "a_id"
+        self.assert_compile(j6, "(SELECT c.id AS c_id, c.bid AS c_bid, a_id AS a_id, b_id AS b_id, b_aid AS b_aid FROM "
+                                "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) "
+                                "ON b_id = c.bid) AS foo"
+                                " LEFT OUTER JOIN d ON foo.a_id = d.aid")
         
         
 class SelectTest(SQLCompileTest):
index 83203ad8e2ab859a7dde0e08153841858851bfa7..72f5f35d040a1e8b932bb850732db83d8b07d901 100755 (executable)
@@ -230,6 +230,39 @@ class PrimaryKeyTest(AssertMixin):
         assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :b_x", str(j)
         assert list(j.primary_key) == [a.c.id, b.c.x]
 
+class DerivedTest(AssertMixin):
+    def test_table(self):
+        meta = MetaData()
+        t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+        t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+        
+        assert t1.is_derived_from(t1)
+        assert not t2.is_derived_from(t1)
+        
+    def test_alias(self):    
+        meta = MetaData()
+        t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+        t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+        
+        assert t1.alias().is_derived_from(t1)
+        assert not t2.alias().is_derived_from(t1)
+        assert not t1.is_derived_from(t1.alias())
+        assert not t1.is_derived_from(t2.alias())
+    
+    def test_select(self):
+        meta = MetaData()
+        t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+        t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+
+        assert t1.select().is_derived_from(t1)
+        assert not t2.select().is_derived_from(t1)
+
+        assert select([t1, t2]).is_derived_from(t1)
+
+        assert t1.select().alias('foo').is_derived_from(t1)
+        assert select([t1, t2]).alias('foo').is_derived_from(t1)
+        assert not t2.select().alias('foo').is_derived_from(t1)
+
 if __name__ == "__main__":
     testbase.main()