]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merge of generative_sql branch
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 24 Jun 2007 19:58:41 +0000 (19:58 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 24 Jun 2007 19:58:41 +0000 (19:58 +0000)
- copy_container() removed.  ClauseVisitor.traverse() now features "clone"
flag which allows traversal with copy-and-modify-in-place behavior
- select() objects copyable now [ticket:52] [ticket:569]
- improved support for custom column_property() attributes which
  feature correlated subqueries...work better with eager loading now.
- accept_visitor()  methods removed.  ClauseVisitor now genererates method
names based on class names, or an optional __visit_name__ attribute.  calls
regular visit_XXX methods as they exist, can optionally call an additional
"pre-descent" enter_XXX method to allow stack-based operations on traversals
- select() and union()'s now have "generative" behavior.  methods like
order_by() and group_by() return a *new* instance - the original instance
is left unchanged.  non-generative methods remain as well.
- the internals of select/union vastly simplified - all decision making
regarding "is subquery" and "correlation" pushed to SQL generation phase.
select() elements are now *never* mutated by their enclosing containers
or by any dialect's compilation process

26 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/sql_util.py
lib/sqlalchemy/util.py
test/orm/generative.py
test/orm/inheritance/poly_linked_list.py
test/orm/mapper.py
test/perf/masseagerload.py
test/sql/alltests.py
test/sql/generative.py [new file with mode: 0644]
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index febda7d016a266737fc3b2d53806fb38a01b0bee..b118856feee0ab8ee516ccd6d45d63c93e9cb59c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,8 @@
       auto-construction of joins which cross the same paths but
       are querying divergent criteria.  ClauseElements at the front
       of filter_by() are removed (use filter()).
+    - improved support for custom column_property() attributes which
+      feature correlated subqueries...work better with eager loading now.
     - along with recent speedups to ResultProxy, total number of
       function calls significantly reduced for large loads.
       test/perf/masseagerload.py reports 0.4 as having the fewest number
     - added undefer_group() MapperOption, sets a set of "deferred"
       columns joined by a "group" to load as "undeferred".
 - sql
+  - significant architectural overhaul to SQL elements (ClauseElement).  
+    all elements share  a common "mutability" framework which allows a 
+    consistent approach to in-place modifications of elements as well as 
+    generative behavior.  improves stability of the ORM which makes 
+    heavy usage of mutations to SQL expressions.  
+  - select() and union()'s now have "generative" behavior.  methods like
+    order_by() and group_by() return a *new* instance - the original instance
+    is left unchanged.  non-generative methods remain as well.  
+  - the internals of select/union vastly simplified - all decision making
+    regarding "is subquery" and "correlation" pushed to SQL generation phase.
+    select() elements are now *never* mutated by their enclosing containers
+    or by any dialect's compilation process [ticket:52] [ticket:569]
   - result sets from CRUD operations close their underlying cursor immediately.
     will also autoclose the connection if defined for the operation; this 
     allows more efficient usage of connections for successive CRUD operations
index c489d7929ae618a555ee39f1b62547f92afee7de..e8610f864403a48d7532d1847250d944570a9d36 100644 (file)
@@ -66,13 +66,13 @@ class ANSIDialect(default.DefaultDialect):
         """
         return ANSIIdentifierPreparer(self)
 
-class ANSICompiler(sql.Compiled):
+class ANSICompiler(engine.Compiled):
     """Default implementation of Compiled.
 
     Compiles ClauseElements into ANSI-compliant SQL strings.
     """
 
-    __traverse_options__ = {'column_collections':False}
+    __traverse_options__ = {'column_collections':False, 'entry':True}
 
     def __init__(self, dialect, statement, parameters=None, **kwargs):
         """Construct a new ``ANSICompiler`` object.
@@ -92,7 +92,7 @@ class ANSICompiler(sql.Compiled):
           correspond to the keys present in the parameters.
         """
         
-        sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+        super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
 
         # if we are insert/update.  set to true when we visit an INSERT or UPDATE
         self.isinsert = self.isupdate = False
@@ -158,7 +158,14 @@ class ANSICompiler(sql.Compiled):
 
         # an ANSIIdentifierPreparer that formats the quoting of identifiers
         self.preparer = dialect.identifier_preparer
-
+        
+        # a dictionary containing attributes about all select()
+        # elements located within the clause, regarding which are subqueries, which are
+        # selected from, and which elements should be correlated to an enclosing select.
+        # used mostly to determine the list of FROM elements for each select statement, as well
+        # as some dialect-specific rules regarding subqueries.
+        self.correlate_state = {}
+        
         # for UPDATE and INSERT statements, a set of columns whos values are being set
         # from a SQL expression (i.e., not one of the bind parameter values).  if present,
         # default-value logic in the Dialect knows not to fire off column defaults
@@ -193,7 +200,10 @@ class ANSICompiler(sql.Compiled):
 
     def get_str(self, obj):
         return self.strings[obj]
-
+    
+    def is_subquery(self, select):
+        return self.correlate_state[select].get('is_subquery', False)
+        
     def get_whereclause(self, obj):
         return self.wheres.get(obj, None)
 
@@ -343,7 +353,7 @@ class ANSICompiler(sql.Compiled):
 
     def visit_compound_select(self, cs):
         text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
-        group_by = self.get_str(cs.group_by_clause)
+        group_by = self.get_str(cs._group_by_clause)
         if group_by:
             text += " GROUP BY " + group_by
         text += self.order_by_clause(cs)            
@@ -424,40 +434,68 @@ class ANSICompiler(sql.Compiled):
         self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
         self.strings[alias] = self.get_str(alias.original)
 
+    def enter_select(self, select):
+        select.calculate_correlations(self.correlate_state)
+        self.select_stack.append(select)
+    
+    def enter_update(self, update):
+        update.calculate_correlations(self.correlate_state)
+
+    def enter_delete(self, delete):
+        delete.calculate_correlations(self.correlate_state)
+    
+    def label_select_column(self, select, column):
+        """convert a column from a select's "columns" clause.
+        
+        given a select() and a column element from its inner_columns collection, return a
+        Label object if this column should be labeled in the columns clause.  Otherwise,
+        return None and the column will be used as-is.
+        
+        The calling method will traverse the returned label to acquire its string
+        representation.
+        """
+        
+        # SQLite doesnt like selecting from a subquery where the column
+        # names look like table.colname. so if column is in a "selected from"
+        # subquery, label it synoymously with its column name
+        if \
+            self.correlate_state[select].get('is_selected_from', False) and \
+            isinstance(column, sql._ColumnClause) and \
+            not column.is_literal and \
+            column.table is not None and \
+            not isinstance(column.table, sql.Select):
+            return column.label(column.name)
+        else:
+            return None
+            
     def visit_select(self, select):
         # the actual list of columns to print in the SELECT column list.
         inner_columns = util.OrderedDict()
-
-        self.select_stack.append(select)
-        for c in select._raw_columns:
-            if hasattr(c, '_selectable'):
-                s = c._selectable()
+        
+        froms = select.get_display_froms(self.correlate_state)
+        for f in froms:
+            if f not in self.strings:
+                self.traverse(f)
+                
+        for co in select.inner_columns:
+            if select.use_labels:
+                labelname = co._label
+                if labelname is not None:
+                    l = co.label(labelname)
+                    self.traverse(l)
+                    inner_columns[labelname] = l
+                else:
+                    self.traverse(co)
+                    inner_columns[self.get_str(co)] = co
             else:
-                self.traverse(c)
-                inner_columns[self.get_str(c)] = c
-                continue
-            for co in s.columns:
-                if select.use_labels:
-                    labelname = co._label
-                    if labelname is not None:
-                        l = co.label(labelname)
-                        self.traverse(l)
-                        inner_columns[labelname] = l
-                    else:
-                        self.traverse(co)
-                        inner_columns[self.get_str(co)] = co
-                # TODO: figure this out, a ColumnClause with a select as a parent
-                # is different from any other kind of parent
-                elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select):
-                    # SQLite doesnt like selecting from a subquery where the column
-                    # names look like table.colname, so add a label synonomous with
-                    # the column name
-                    l = co.label(co.name)
+                l = self.label_select_column(select, co)
+                if l is not None:
                     self.traverse(l)
                     inner_columns[self.get_str(l.obj)] = l
                 else:
                     self.traverse(co)
                     inner_columns[self.get_str(co)] = co
+                    
         self.select_stack.pop(-1)
 
         collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
@@ -466,29 +504,10 @@ class ANSICompiler(sql.Compiled):
         text += self.visit_select_precolumns(select)
         text += collist
 
-        whereclause = select.whereclause
-
-        froms = []
-        for f in select.froms:
-
-            if self.parameters is not None:
-                # TODO: whack this feature in 0.4
-                # look at our own parameters, see if they
-                # are all present in the form of BindParamClauses.  if
-                # not, then append to the above whereclause column conditions
-                # matching those keys
-                for c in f.columns:
-                    if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
-                        value = self.parameters[c.key]
-                    else:
-                        continue
-                    clause = c==value
-                    if whereclause is not None:
-                        whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
-                    else:
-                        whereclause = clause
-                        self.traverse(whereclause)
+        whereclause = select._whereclause
 
+        from_strings = []
+        for f in froms:
             # special thingy used by oracle to redefine a join
             w = self.get_whereclause(f)
             if w is not None:
@@ -500,11 +519,11 @@ class ANSICompiler(sql.Compiled):
 
             t = self.get_from_text(f)
             if t is not None:
-                froms.append(t)
+                from_strings.append(t)
 
         if len(froms):
             text += " \nFROM "
-            text += string.join(froms, ', ')
+            text += string.join(from_strings, ', ')
         else:
             text += self.default_from()
 
@@ -513,12 +532,12 @@ class ANSICompiler(sql.Compiled):
             if t:
                 text += " \nWHERE " + t
 
-        group_by = self.get_str(select.group_by_clause)
+        group_by = self.get_str(select._group_by_clause)
         if group_by:
             text += " GROUP BY " + group_by
 
-        if select.having is not None:
-            t = self.get_str(select.having)
+        if select._having is not None:
+            t = self.get_str(select._having)
             if t:
                 text += " \nHAVING " + t
 
@@ -532,7 +551,7 @@ class ANSICompiler(sql.Compiled):
     def visit_select_precolumns(self, select):
         """Called when building a ``SELECT`` statement, position is just before column list."""
 
-        return select.distinct and "DISTINCT " or ""
+        return select._distinct and "DISTINCT " or ""
 
     def visit_select_postclauses(self, select):
         """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
@@ -540,10 +559,10 @@ class ANSICompiler(sql.Compiled):
         Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
         """
 
-        return (select.limit or select.offset) and self.limit_clause(select) or ""
+        return (select._limit or select._offset) and self.limit_clause(select) or ""
 
     def order_by_clause(self, select):
-        order_by = self.get_str(select.order_by_clause)
+        order_by = self.get_str(select._order_by_clause)
         if order_by:
             return " ORDER BY " + order_by
         else:
@@ -557,12 +576,12 @@ class ANSICompiler(sql.Compiled):
 
     def limit_clause(self, select):
         text = ""
-        if select.limit is not None:
-            text +=  " \n LIMIT " + str(select.limit)
-        if select.offset is not None:
-            if select.limit is None:
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
                 text += " \n LIMIT -1"
-            text += " OFFSET " + str(select.offset)
+            text += " OFFSET " + str(select._offset)
         return text
 
     def visit_table(self, table):
@@ -696,8 +715,8 @@ class ANSICompiler(sql.Compiled):
 
         text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
 
-        if update_stmt.whereclause:
-            text += " WHERE " + self.get_str(update_stmt.whereclause)
+        if update_stmt._whereclause:
+            text += " WHERE " + self.get_str(update_stmt._whereclause)
 
         self.strings[update_stmt] = text
 
@@ -755,13 +774,14 @@ class ANSICompiler(sql.Compiled):
                 if sql._is_literal(value):
                     value = sql.bindparam(c.key, value, type=c.type, unique=True)
                 values.append((c, value))
+
         return values
 
     def visit_delete(self, delete_stmt):
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
-        if delete_stmt.whereclause:
-            text += " WHERE " + self.get_str(delete_stmt.whereclause)
+        if delete_stmt._whereclause:
+            text += " WHERE " + self.get_str(delete_stmt._whereclause)
 
         self.strings[delete_stmt] = text
 
@@ -795,7 +815,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_metadata(self, metadata):
         collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
         for table in collection:
-            table.accept_visitor(self)
+            self.traverse_single(table)
         if self.dialect.supports_alter():
             for alterable in self.find_alterables(collection):
                 self.add_foreignkey(alterable)
@@ -803,7 +823,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_table(self, table):
         for column in table.columns:
             if column.default is not None:
-                column.default.accept_visitor(self)
+                self.traverse_single(column.default)
             #if column.onupdate is not None:
             #    column.onupdate.accept_visitor(visitor)
 
@@ -820,20 +840,20 @@ class ANSISchemaGenerator(ANSISchemaBase):
             if column.primary_key:
                 first_pk = True
             for constraint in column.constraints:
-                constraint.accept_visitor(self)
+                self.traverse_single(constraint)
 
         # On some DB order is significant: visit PK first, then the
         # other constraints (engine.ReflectionTest.testbasic failed on FB2)
         if len(table.primary_key):
-            table.primary_key.accept_visitor(self)
+            self.traverse_single(table.primary_key)
         for constraint in [c for c in table.constraints if c is not table.primary_key]:
-            constraint.accept_visitor(self)
+            self.traverse_single(constraint)
 
         self.append("\n)%s\n\n" % self.post_create_table(table))
         self.execute()
         if hasattr(table, 'indexes'):
             for index in table.indexes:
-                index.accept_visitor(self)
+                self.traverse_single(index)
 
     def post_create_table(self, table):
         return ''
@@ -929,7 +949,7 @@ class ANSISchemaDropper(ANSISchemaBase):
             for alterable in self.find_alterables(collection):
                 self.drop_foreignkey(alterable)
         for table in collection:
-            table.accept_visitor(self)
+            self.traverse_single(table)
 
     def visit_index(self, index):
         self.append("\nDROP INDEX " + index.name)
@@ -942,7 +962,7 @@ class ANSISchemaDropper(ANSISchemaBase):
     def visit_table(self, table):
         for column in table.columns:
             if column.default is not None:
-                column.default.accept_visitor(self)
+                self.traverse_single(column.default)
 
         self.append("\nDROP TABLE " + self.preparer.format_table(table))
         self.execute()
index a02781c846b195fdeb3007cd1af71c49ee42f5d1..64f5842384191c5516a77f330ce2a3a20ca1abba 100644 (file)
@@ -324,11 +324,11 @@ class FBCompiler(ansisql.ANSICompiler):
         """
 
         result = ""
-        if select.limit:
-            result += " FIRST %d "  % select.limit
-        if select.offset:
-            result +=" SKIP %d "  %  select.offset
-        if select.distinct:
+        if select._limit:
+            result += " FIRST %d "  % select._limit
+        if select._offset:
+            result +=" SKIP %d "  %  select._offset
+        if select._distinct:
             result += " DISTINCT "
         return result
 
index 2fb508280c5cd39deccc49da1f9d5a0987a6a70e..99bc3896c9898e4b79af4c479c61a18e3824337c 100644 (file)
@@ -373,19 +373,19 @@ class InfoCompiler(ansisql.ANSICompiler):
         return " from systables where tabname = 'systables' "
     
     def visit_select_precolumns( self , select ):
-        s = select.distinct and "DISTINCT " or ""
+        s = select._distinct and "DISTINCT " or ""
         # only has limit
-        if select.limit:
-            off = select.offset or 0
-            s += " FIRST %s " % ( select.limit + off )
+        if select._limit:
+            off = select._offset or 0
+            s += " FIRST %s " % ( select._limit + off )
         else:
             s += ""
         return s
     
     def visit_select(self, select):
-        if select.offset:
-            self.offset = select.offset
-            self.limit  = select.limit or 0
+        if select._offset:
+            self.offset = select._offset
+            self.limit  = select._limit or 0
         # the column in order by clause must in select too
         
         def __label( c ):
index 2b6808eaca4bada54e0df74e3e6d0c4e6ad4b615..8b81884bb9a8ca7890c0ba8739beee35b52991c9 100644 (file)
@@ -25,7 +25,7 @@
 
 * Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
 
-* ``select.limit`` implemented as ``SELECT TOP n``
+* ``select._limit`` implemented as ``SELECT TOP n``
 
 
 Known issues / TODO:
@@ -756,10 +756,10 @@ class MSSQLCompiler(ansisql.ANSICompiler):
 
     def visit_select_precolumns(self, select):
         """ MS-SQL puts TOP, it's version of LIMIT here """
-        s = select.distinct and "DISTINCT " or ""
-        if select.limit:
-            s += "TOP %s " % (select.limit,)
-        if select.offset:
+        s = select._distinct and "DISTINCT " or ""
+        if select._limit:
+            s += "TOP %s " % (select._limit,)
+        if select._offset:
             raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
         return s
 
@@ -803,13 +803,11 @@ class MSSQLCompiler(ansisql.ANSICompiler):
             binary.left, binary.right = binary.right, binary.left
         super(MSSQLCompiler, self).visit_binary(binary)
 
-    def visit_select(self, select):        
-        # label function calls, so they return a name in cursor.description        
-        for i,c in enumerate(select._raw_columns):
-            if isinstance(c, sql._Function):
-                select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])        
-
-        super(MSSQLCompiler, self).visit_select(select)
+    def label_select_column(self, select, column):
+        if isinstance(column, sql._Function):
+            return co.label(co.name + "_" + hex(random.randint(0, 65535))[2:])        
+        else:
+            return super(MSSQLCompiler, self).label_select_column(select, column)
 
     function_rewrites =  {'current_date': 'getdate',
                           'length':     'len',
@@ -823,10 +821,10 @@ class MSSQLCompiler(ansisql.ANSICompiler):
         return ''
 
     def order_by_clause(self, select):
-        order_by = self.get_str(select.order_by_clause)
+        order_by = self.get_str(select._order_by_clause)
 
         # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
-        if order_by and (not select.is_subquery or select.limit):
+        if order_by and (not self.is_subquery(select) or select._limit):
             return " ORDER BY " + order_by
         else:
             return ""
index e45536a756fe898b81e6aee41dbe38ee90efb9fe..66bb306b56f7ffb518df7e35dc4d43c4bfe0a044 100644 (file)
@@ -1200,13 +1200,13 @@ class MySQLCompiler(ansisql.ANSICompiler):
 
     def limit_clause(self, select):
         text = ""
-        if select.limit is not None:
-            text +=  " \n LIMIT " + str(select.limit)
-        if select.offset is not None:
-            if select.limit is None:
-                # striaght from the MySQL docs, I kid you not
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
+                # straight from the MySQL docs, I kid you not
                 text += " \n LIMIT 18446744073709551615"
-            text += " OFFSET " + str(select.offset)
+            text += " OFFSET " + str(select._offset)
         return text
 
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
index 4210a949747afcc8633ee461bc008a6ed455f976..b88bea663fca2c6ee6b63cbb2657f2cdee955113 100644 (file)
@@ -433,11 +433,6 @@ class OracleCompiler(ansisql.ANSICompiler):
     the use_ansi flag is False.
     """
 
-    def __init__(self, *args, **kwargs):
-        super(OracleCompiler, self).__init__(*args, **kwargs)
-        # we have to modify SELECT objects a little bit, so store state here
-        self._select_state = {}
-        
     def default_from(self):
         """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
 
@@ -472,7 +467,7 @@ class OracleCompiler(ansisql.ANSICompiler):
 
             self._outertable = None
 
-        self.wheres[join].accept_visitor(self)
+        self.traverse_single(self.wheres[join])
 
     def visit_insert_sequence(self, column, sequence, parameters):
         """This is the `sequence` equivalent to ``ANSICompiler``'s
@@ -508,74 +503,35 @@ class OracleCompiler(ansisql.ANSICompiler):
 
     def _TODO_visit_compound_select(self, select):
         """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
-
-        if getattr(select, '_oracle_visit', False):
-            # cancel out the compiled order_by on the select
-            if hasattr(select, "order_by_clause"):
-                self.strings[select.order_by_clause] = ""
-            ansisql.ANSICompiler.visit_compound_select(self, select)
-            return
-
-        if select.limit is not None or select.offset is not None:
-            select._oracle_visit = True
-            # to use ROW_NUMBER(), an ORDER BY is required.
-            orderby = self.strings[select.order_by_clause]
-            if not orderby:
-                orderby = select.oid_column
-                self.traverse(orderby)
-                orderby = self.strings[orderby]
-            class SelectVisitor(sql.NoColumnVisitor):
-                def visit_select(self, select):
-                    select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
-            SelectVisitor().traverse(select)
-            limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
-            if select.offset is not None:
-                limitselect.append_whereclause("ora_rn>%d" % select.offset)
-                if select.limit is not None:
-                    limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
-            else:
-                limitselect.append_whereclause("ora_rn<=%d" % select.limit)
-            self.traverse(limitselect)
-            self.strings[select] = self.strings[limitselect]
-            self.froms[select] = self.froms[limitselect]
-        else:
-            ansisql.ANSICompiler.visit_compound_select(self, select)
+        pass
 
     def visit_select(self, select):
         """Look for ``LIMIT`` and OFFSET in a select statement, and if
         so tries to wrap it in a subquery with ``row_number()`` criterion.
         """
 
-        # TODO: put a real copy-container on Select and copy, or somehow make this
-        # not modify the Select statement
-        if self._select_state.get((select, 'visit'), False):
-            # cancel out the compiled order_by on the select
-            if hasattr(select, "order_by_clause"):
-                self.strings[select.order_by_clause] = ""
-            ansisql.ANSICompiler.visit_select(self, select)
-            return
-
-        if select.limit is not None or select.offset is not None:
-            self._select_state[(select, 'visit')] = True
+        if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
             # to use ROW_NUMBER(), an ORDER BY is required.
-            orderby = self.strings[select.order_by_clause]
+            orderby = self.strings[select._order_by_clause]
             if not orderby:
                 orderby = select.oid_column
                 self.traverse(orderby)
                 orderby = self.strings[orderby]
-            if not hasattr(select, '_oracle_visit'):
-                select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
-                select._oracle_visit = True
+                
+            oldselect = select
+            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+            select._oracle_visit = True
+                
             limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
-            if select.offset is not None:
-                limitselect.append_whereclause("ora_rn>%d" % select.offset)
-                if select.limit is not None:
-                    limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
+            if select._offset is not None:
+                limitselect.append_whereclause("ora_rn>%d" % select._offset)
+                if select._limit is not None:
+                    limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
             else:
-                limitselect.append_whereclause("ora_rn<=%d" % select.limit)
+                limitselect.append_whereclause("ora_rn<=%d" % select._limit)
             self.traverse(limitselect)
-            self.strings[select] = self.strings[limitselect]
-            self.froms[select] = self.froms[limitselect]
+            self.strings[oldselect] = self.strings[limitselect]
+            self.froms[oldselect] = self.froms[limitselect]
         else:
             ansisql.ANSICompiler.visit_select(self, select)
 
index b48a709d8c14c82f770355688d91cb7c68488cb8..ea92cf7fd2d825c41fc1a3673244836fdb487eca 100644 (file)
@@ -423,16 +423,16 @@ class PGCompiler(ansisql.ANSICompiler):
         return text
 
     def visit_select_precolumns(self, select):
-        if select.distinct:
-            if type(select.distinct) == bool:
+        if select._distinct:
+            if type(select._distinct) == bool:
                 return "DISTINCT "
-            if type(select.distinct) == list:
+            if type(select._distinct) == list:
                 dist_set = "DISTINCT ON ("
-                for col in select.distinct:
+                for col in select._distinct:
                     dist_set += self.strings[col] + ", "
                     dist_set = dist_set[:-2] + ") "
                 return dist_set
-            return "DISTINCT ON (" + str(select.distinct) + ") "
+            return "DISTINCT ON (" + str(select._distinct) + ") "
         else:
             return ""
 
index 0bd7cf6aee4bb5cccd9d49827529f2f678dca688..e3282e028ae43fd845bb43ce112584fe2eb9aac0 100644 (file)
@@ -327,12 +327,12 @@ class SQLiteCompiler(ansisql.ANSICompiler):
 
     def limit_clause(self, select):
         text = ""
-        if select.limit is not None:
-            text +=  " \n LIMIT " + str(select.limit)
-        if select.offset is not None:
-            if select.limit is None:
+        if select._limit is not None:
+            text +=  " \n LIMIT " + str(select._limit)
+        if select._offset is not None:
+            if select._limit is None:
                 text += " \n LIMIT -1"
-            text += " OFFSET " + str(select.offset)
+            text += " OFFSET " + str(select._offset)
         else:
             text += " OFFSET 0"
         return text
index b1e1ee5cc9683bb5b158fe9d94ce912af3531024..de4d6b2aeb16cb6a16853eca8676f82f102605a3 100644 (file)
@@ -364,8 +364,90 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
+class Compiled(sql.ClauseVisitor):
+    """Represent a compiled SQL expression.
+
+    The ``__str__`` method of the ``Compiled`` object should produce
+    the actual text of the statement.  ``Compiled`` objects are
+    specific to their underlying database dialect, and also may
+    or may not be specific to the columns referenced within a
+    particular set of bind parameters.  In no case should the
+    ``Compiled`` object be dependent on the actual values of those
+    bind parameters, even though it may reference those values as
+    defaults.
+    """
+
+    def __init__(self, dialect, statement, parameters, engine=None):
+        """Construct a new ``Compiled`` object.
+
+        statement
+          ``ClauseElement`` to be compiled.
+
+        parameters
+          Optional dictionary indicating a set of bind parameters
+          specified with this ``Compiled`` object.  These parameters
+          are the *default* values corresponding to the
+          ``ClauseElement``'s ``_BindParamClauses`` when the
+          ``Compiled`` is executed.  In the case of an ``INSERT`` or
+          ``UPDATE`` statement, these parameters will also result in
+          the creation of new ``_BindParamClause`` objects for each
+          key and will also affect the generated column list in an
+          ``INSERT`` statement and the ``SET`` clauses of an
+          ``UPDATE`` statement.  The keys of the parameter dictionary
+          can either be the string names of columns or
+          ``_ColumnClause`` objects.
+
+        engine
+          Optional Engine to compile this statement against.
+        """
+        self.dialect = dialect
+        self.statement = statement
+        self.parameters = parameters
+        self.engine = engine
+        self.can_execute = statement.supports_execution()
+
+    def compile(self):
+        self.traverse(self.statement)
+        self.after_compile()
+
+    def __str__(self):
+        """Return the string text of the generated SQL statement."""
+
+        raise NotImplementedError()
+
+    def get_params(self, **params):
+        """Deprecated.  use construct_params().  (supports unicode names)
+        """
+
+        return self.construct_params(params)
+
+    def construct_params(self, params):
+        """Return the bind params for this compiled object.
+
+        Will start with the default parameters specified when this
+        ``Compiled`` object was first constructed, and will override
+        those values with those sent via `**params`, which are
+        key/value pairs.  Each key should match one of the
+        ``_BindParamClause`` objects compiled into this object; either
+        the `key` or `shortname` property of the ``_BindParamClause``.
+        """
+        raise NotImplementedError()
+
+    def execute(self, *multiparams, **params):
+        """Execute this compiled object."""
+
+        e = self.engine
+        if e is None:
+            raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
+        return e.execute_compiled(self, *multiparams, **params)
+
+    def scalar(self, *multiparams, **params):
+        """Execute this compiled object and return the result's scalar value."""
+
+        return self.execute(*multiparams, **params).scalar()
+
 
-class Connectable(sql.Executor):
+class Connectable(object):
     """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
 
     def contextual_connect(self):
@@ -522,7 +604,7 @@ class Connection(Connectable):
             raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
 
     def execute_default(self, default, **kwargs):
-        return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
+        return self.__engine.dialect.defaultrunner(self).traverse_single(default)
 
     def execute_text(self, statement, *multiparams, **params):
         if len(multiparams) == 0:
@@ -729,7 +811,7 @@ class Engine(Connectable):
         else:
             conn = connection
         try:
-            element.accept_visitor(visitorcallable(conn, **kwargs))
+            visitorcallable(conn, **kwargs).traverse(element)
         finally:
             if connection is None:
                 conn.close()
@@ -1248,13 +1330,13 @@ class DefaultRunner(schema.SchemaVisitor):
         
     def get_column_default(self, column):
         if column.default is not None:
-            return column.default.accept_visitor(self)
+            return self.traverse_single(column.default)
         else:
             return None
 
     def get_column_onupdate(self, column):
         if column.onupdate is not None:
-            return column.onupdate.accept_visitor(self)
+            return self.traverse_single(column.onupdate)
         else:
             return None
 
index a9b93bc56414aa5dbf8b80214e5839ca33b202d6..15be667090c4aff6aeaabb79cc05a6ab30646b0a 100644 (file)
@@ -425,7 +425,7 @@ def _selectable_name(selectable):
     if isinstance(selectable, sql.Alias):
         return _selectable_name(selectable.selectable)
     elif isinstance(selectable, sql.Select):
-        return ''.join([_selectable_name(s) for s in selectable.froms])
+        return ''.join([_selectable_name(s) for s in selectable.get_display_froms()])
     elif isinstance(selectable, schema.Table):
         return selectable.name.capitalize()
     else:
index cb12611306ec3ca47d2e44235eeeb676dc9be58b..8b0878688d4d47efe7cec4d286412fa5664fd78e 100644 (file)
@@ -1546,7 +1546,7 @@ class Mapper(object):
         return obj
 
     def _deferred_inheritance_condition(self, needs_tables):
-        cond = self.inherit_condition.copy_container()
+        cond = self.inherit_condition
 
         param_names = []
         def visit_binary(binary):
@@ -1560,7 +1560,7 @@ class Mapper(object):
             elif rightcol not in needs_tables:
                 binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
                 param_names.append(rightcol)
-        mapperutil.BinaryVisitor(visit_binary).traverse(cond)
+        cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
         return cond, param_names
 
     def translate_row(self, tomapper, row):
index b5b8f830699bec0c4ceb675c6180ae23c1accd6f..79fa101d25b6c6d1a27b273e7c2fa6d074ee9015 100644 (file)
@@ -399,15 +399,13 @@ class PropertyLoader(StrategizedProperty):
         # if the target mapper loads polymorphically, adapt the clauses to the target's selectable
         if self.loads_polymorphic:
             if self.secondaryjoin:
-                self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container()
-                sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin)
-                self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
+                self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
+                self.polymorphic_primaryjoin = self.primaryjoin
             else:
-                self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
                 if self.direction is sync.ONETOMANY:
-                    sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+                    self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
                 elif self.direction is sync.MANYTOONE:
-                    sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+                    self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
                 self.polymorphic_secondaryjoin = None
             # load "polymorphic" versions of the columns present in "remote_side" - this is
             # important for lazy-clause generation which goes off the polymorphic target selectable
@@ -422,8 +420,8 @@ class PropertyLoader(StrategizedProperty):
                 else:
                     raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable "  + str(self.mapper.select_table))
         else:
-            self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
-            self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None
+            self.polymorphic_primaryjoin = self.primaryjoin
+            self.polymorphic_secondaryjoin = self.secondaryjoin
 
     def _post_init(self):
         if logging.is_info_enabled(self.logger):
@@ -466,17 +464,13 @@ class PropertyLoader(StrategizedProperty):
             return self._parent_join_cache[(parent, primary, secondary)]
         except KeyError:
             parent_equivalents = parent._get_equivalent_columns()
-            primaryjoin = self.polymorphic_primaryjoin.copy_container()
-            if self.secondaryjoin is not None:
-                secondaryjoin = self.polymorphic_secondaryjoin.copy_container()
-            else:
-                secondaryjoin = None
+            secondaryjoin = self.polymorphic_secondaryjoin
             if self.direction is sync.ONETOMANY:
-                sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+                primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
             elif self.direction is sync.MANYTOONE:
-                sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+                primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
             elif self.secondaryjoin:
-                sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+                primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
 
             if secondaryjoin is not None:
                 if secondary and not primary:
index 37f3232486fd63c21cfe42d1fe7fac964c8e9e92..0fb8939c931724bfba9887515506cdb02c4432ab 100644 (file)
@@ -331,15 +331,15 @@ class Query(object):
                 else:
                     if prop.secondary:
                         if create_aliases:
-                            join = prop.get_join(mapper, primary=True, secondary=False).copy_container()
+                            join = prop.get_join(mapper, primary=True, secondary=False)
                             secondary_alias = prop.secondary.alias()
                             if alias is not None:
-                                sql_util.ClauseAdapter(alias).traverse(join)
+                                join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
                             sql_util.ClauseAdapter(secondary_alias).traverse(join)
                             clause = clause.join(secondary_alias, join)
                             alias = prop.select_table.alias()
-                            join = prop.get_join(mapper, primary=False).copy_container()
-                            sql_util.ClauseAdapter(secondary_alias).traverse(join)
+                            join = prop.get_join(mapper, primary=False)
+                            join = sql_util.ClauseAdapter(secondary_alias).traverse(join, clone=True)
                             sql_util.ClauseAdapter(alias).traverse(join)
                             clause = clause.join(alias, join)
                         else:
@@ -347,11 +347,11 @@ class Query(object):
                             clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False))
                     else:
                         if create_aliases:
-                            join = prop.get_join(mapper).copy_container()
+                            join = prop.get_join(mapper)
                             if alias is not None:
-                                sql_util.ClauseAdapter(alias).traverse(join)
+                                join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
                             alias = prop.select_table.alias()
-                            sql_util.ClauseAdapter(alias).traverse(join)
+                            join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
                             clause = clause.join(alias, join)
                         else:
                             clause = clause.join(prop.select_table, prop.get_join(mapper))
@@ -401,7 +401,7 @@ class Query(object):
         For performance, only use subselect if `order_by` attribute is set.
         """
 
-        ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj}
+        ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj}
 
         if self._order_by is not False:
             s1 = sql.select([col], self._criterion, **ops).alias('u')
@@ -781,12 +781,8 @@ class Query(object):
         # from there
         context = QueryContext(self)
         order_by = context.order_by
-        group_by = context.group_by
         from_obj = context.from_obj
         lockmode = context.lockmode
-        distinct = context.distinct
-        limit = context.limit
-        offset = context.offset
         if order_by is False:
             order_by = self.mapper.order_by
         if order_by is False:
@@ -821,20 +817,20 @@ class Query(object):
             else:
                 cf = []
 
-            s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args())
+            s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args())
             if order_by:
-                s2.order_by(*util.to_list(order_by))
+                s2 = s2.order_by(*util.to_list(order_by))
             s3 = s2.alias('tbl_row_count')
-            crit = s3.primary_key==self.table.primary_key
+            crit = s3.primary_key==self.primary_key_columns
             statement = sql.select([], crit, use_labels=True, for_update=for_update)
             # now for the order by, convert the columns to their corresponding columns
             # in the "rowcount" query, and tack that new order by onto the "rowcount" query
             if order_by:
-                statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
+                statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
         else:
             statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
             if order_by:
-                statement.order_by(*util.to_list(order_by))
+                statement.append_order_by(*util.to_list(order_by))
                 
             # for a DISTINCT query, you need the columns explicitly specified in order
             # to use it in "order_by".  ensure they are in the column criterion (particularly oid).
@@ -1101,7 +1097,7 @@ class QueryContext(OperationContext):
         ``QueryContext`` that can be applied to a ``sql.Select``
         statement.
         """
-        return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by}
+        return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None}
 
     def accept_option(self, opt):
         """Accept a ``MapperOption`` which will process (modify) the
index 8de2d00e5f3411c8eb253b3d1ec942cc72a2ffa5..b4a66a6dc3b6da8408d18c3385950ed319a75f23 100644 (file)
@@ -290,7 +290,7 @@ class LazyLoader(AbstractRelationLoader):
                 # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper",
                 # probably via the query's own "mapper" property, and also use one of two "lazy" clauses,
                 # one against the "union" the other not
-                for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]:
+                for primary_key in self.select_mapper.primary_key: 
                     bind = self.lazyreverse[primary_key]
                     ident.append(params[bind.key])
                 return q.get(ident)
@@ -303,6 +303,15 @@ class LazyLoader(AbstractRelationLoader):
                 q = q.options(*options)
             q = q.filter(self.lazywhere).params(**params)
 
+            result = q.all()
+            if self.uselist:
+                return result
+            else:
+                if len(result):
+                    return result[0]
+                else:
+                    return None
+            
             if self.uselist:
                 return q.all()
             else:
@@ -378,16 +387,15 @@ class LazyLoader(AbstractRelationLoader):
                         sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True))
                 reverse[leftcol] = binds[col]
 
-        lazywhere = primaryjoin.copy_container()
+        lazywhere = primaryjoin
         li = mapperutil.BinaryVisitor(visit_binary)
         
         if not secondaryjoin or not reverse_direction:
-            li.traverse(lazywhere)
+            lazywhere = li.traverse(lazywhere, clone=True)
         
         if secondaryjoin is not None:
-            secondaryjoin = secondaryjoin.copy_container()
             if reverse_direction:
-                li.traverse(secondaryjoin)
+                secondaryjoin = li.traverse(secondaryjoin, clone=True)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
         return (lazywhere, binds, reverse)
     _create_lazy_clause = classmethod(_create_lazy_clause)
@@ -461,18 +469,18 @@ class EagerLoader(AbstractRelationLoader):
                 else:
                     aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
                         chain(sql_util.ClauseAdapter(self.eagersecondary))
-                self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container()
-                aliasizer.traverse(self.eagersecondaryjoin)
-                self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
-                aliasizer.traverse(self.eagerprimary)
+                self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin
+                self.eagersecondaryjoin = aliasizer.traverse(self.eagersecondaryjoin, clone=True)
+                self.eagerprimary = eagerloader.polymorphic_primaryjoin
+                self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True)
             else:
-                self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
+                self.eagerprimary = eagerloader.polymorphic_primaryjoin
                 if parentclauses is not None: 
                     aliasizer = sql_util.ClauseAdapter(self.eagertarget)
                     aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side))
                 else:
                     aliasizer = sql_util.ClauseAdapter(self.eagertarget)
-                aliasizer.traverse(self.eagerprimary)
+                self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True)
 
             if eagerloader.order_by:
                 self.eager_order_by = sql_util.ClauseAdapter(self.eagertarget).copy_and_process(util.to_list(eagerloader.order_by))
@@ -492,8 +500,15 @@ class EagerLoader(AbstractRelationLoader):
             if column in self.extra_cols:
                 return self.extra_cols[column]
             
-            aliased_column = column.copy_container()
-            sql_util.ClauseAdapter(self.eagertarget).traverse(aliased_column)
+            aliased_column = column
+            # for column-level subqueries, swap out its selectable with our
+            # eager version as appropriate, and manually build the 
+            # "correlation" list of the subquery.  
+            class ModifySubquery(sql.ClauseVisitor):
+                def visit_select(s, select):
+                    select._should_correlate = False
+                    select.append_correlation(self.eagertarget)
+            aliased_column = sql_util.ClauseAdapter(self.eagertarget).chain(ModifySubquery()).traverse(aliased_column, clone=True)
             alias = self._aliashash(column.name)
             aliased_column = aliased_column.label(alias)
             self._row_decorator.map[column] = alias
@@ -561,7 +576,7 @@ class EagerLoader(AbstractRelationLoader):
             # this will locate the selectable inside of any containers it may be a part of (such
             # as a join).  if its inside of a join, we want to outer join on that join, not the 
             # selectable.
-            for fromclause in statement.froms:
+            for fromclause in statement.get_display_froms():
                 if fromclause is localparent.mapped_table:
                     towrap = fromclause
                     break
@@ -571,7 +586,7 @@ class EagerLoader(AbstractRelationLoader):
                         break
             else:
                 raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table))
-        
+
         try:
             clauses = self.clauses[parentclauses]
         except KeyError:
@@ -584,16 +599,17 @@ class EagerLoader(AbstractRelationLoader):
         if self.secondaryjoin is not None:
             statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin)
             if self.order_by is False and self.secondary.default_order_by() is not None:
-                statement.order_by(*clauses.eagersecondary.default_order_by())
+                statement.append_order_by(*clauses.eagersecondary.default_order_by())
         else:
             statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary)
             if self.order_by is False and clauses.eagertarget.default_order_by() is not None:
-                statement.order_by(*clauses.eagertarget.default_order_by())
+                statement.append_order_by(*clauses.eagertarget.default_order_by())
 
         if clauses.eager_order_by:
-            statement.order_by(*util.to_list(clauses.eager_order_by))
-                
+            statement.append_order_by(*util.to_list(clauses.eager_order_by))
+        
         statement.append_from(statement._outerjoin)
+
         for value in self.select_mapper.props.values():
             value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
 
index 5b2d229c4b57f461e0442f2927ba0a92d78ad934..713064d9cc1eb9daa37a5ef0f3dd800e1f94dad8 100644 (file)
@@ -28,6 +28,8 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', '
 class SchemaItem(object):
     """Base class for items that define a database schema."""
 
+    __metaclass__ = sql._FigureVisitName
+
     def _init_items(self, *args):
         """Initialize the list of child items for this SchemaItem."""
 
@@ -128,7 +130,7 @@ def _get_table_key(name, schema):
     else:
         return schema + "." + name
 
-class _TableSingleton(type):
+class _TableSingleton(sql._FigureVisitName):
     """A metaclass used by the ``Table`` object to provide singleton behavior."""
 
     def __call__(self, name, metadata, *args, **kwargs):
@@ -721,11 +723,6 @@ class ForeignKey(SchemaItem):
 
     column = property(lambda s: s._init_column())
 
-    def accept_visitor(self, visitor):
-        """Call the `visit_foreign_key` method on the given visitor."""
-
-        visitor.visit_foreign_key(self)
-
     def _get_parent(self):
         return self.parent
 
@@ -777,9 +774,6 @@ class PassiveDefault(DefaultGenerator):
         super(PassiveDefault, self).__init__(**kwargs)
         self.arg = arg
 
-    def accept_visitor(self, visitor):
-        return visitor.visit_passive_default(self)
-
     def __repr__(self):
         return "PassiveDefault(%s)" % repr(self.arg)
 
@@ -794,13 +788,12 @@ class ColumnDefault(DefaultGenerator):
         super(ColumnDefault, self).__init__(**kwargs)
         self.arg = arg
 
-    def accept_visitor(self, visitor):
-        """Call the visit_column_default method on the given visitor."""
-
+    def _visit_name(self):
         if self.for_update:
-            return visitor.visit_column_onupdate(self)
+            return "column_onupdate"
         else:
-            return visitor.visit_column_default(self)
+            return "column_default"
+    __visit_name__ = property(_visit_name)
 
     def __repr__(self):
         return "ColumnDefault(%s)" % repr(self.arg)
@@ -834,10 +827,6 @@ class Sequence(DefaultGenerator):
     def drop(self, connectable=None, checkfirst=True):
        self.get_engine(connectable=connectable).drop(self, checkfirst=checkfirst)
 
-    def accept_visitor(self, visitor):
-        """Call the visit_seauence method on the given visitor."""
-
-        return visitor.visit_sequence(self)
 
 class Constraint(SchemaItem):
     """Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint.
@@ -876,11 +865,12 @@ class CheckConstraint(Constraint):
         super(CheckConstraint, self).__init__(name)
         self.sqltext = sqltext
 
-    def accept_visitor(self, visitor):
+    def _visit_name(self):
         if isinstance(self.parent, Table):
-            visitor.visit_check_constraint(self)
+            return "check_constraint"
         else:
-            visitor.visit_column_check_constraint(self)
+            return "column_check_constraint"
+    __visit_name__ = property(_visit_name)
 
     def _set_parent(self, parent):
         self.parent = parent
@@ -909,9 +899,6 @@ class ForeignKeyConstraint(Constraint):
         for (c, r) in zip(self.__colnames, self.__refcolnames):
             self.append_element(c,r)
 
-    def accept_visitor(self, visitor):
-        visitor.visit_foreign_key_constraint(self)
-
     def append_element(self, col, refcol):
         fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
         fk._set_parent(self.table.c[col])
@@ -935,9 +922,6 @@ class PrimaryKeyConstraint(Constraint):
         for c in self.__colnames:
             self.append_column(table.c[c])
 
-    def accept_visitor(self, visitor):
-        visitor.visit_primary_key_constraint(self)
-
     def add(self, col):
         self.append_column(col)
 
@@ -969,9 +953,6 @@ class UniqueConstraint(Constraint):
     def append_column(self, col):
         self.columns.add(col)
 
-    def accept_visitor(self, visitor):
-        visitor.visit_unique_constraint(self)
-
     def copy(self):
         return UniqueConstraint(name=self.name, *self.__colnames)
 
@@ -1048,9 +1029,6 @@ class Index(SchemaItem):
         else:
             self.get_engine().drop(self)
 
-    def accept_visitor(self, visitor):
-        visitor.visit_index(self)
-
     def __str__(self):
         return repr(self)
 
@@ -1063,6 +1041,8 @@ class Index(SchemaItem):
 class MetaData(SchemaItem):
     """Represent a collection of Tables and their associated schema constructs."""
 
+    __visit_name__ = 'metadata'
+    
     def __init__(self, url=None, engine=None, **kwargs):
         """create a new MetaData object.
         
@@ -1174,9 +1154,6 @@ class MetaData(SchemaItem):
             connectable = self.get_engine()
         connectable.drop(self, checkfirst=checkfirst, tables=tables)
 
-    def accept_visitor(self, visitor):
-        visitor.visit_metadata(self)
-
     def _derived_metadata(self):
         return self
 
@@ -1186,6 +1163,8 @@ class BoundMetaData(MetaData):
     
     """
 
+    __visit_name__ = 'metadata'
+
     def __init__(self, engine_or_url, **kwargs):
         from sqlalchemy.engine.url import URL
         if isinstance(engine_or_url, basestring) or isinstance(engine_or_url, URL):
@@ -1200,6 +1179,8 @@ multiple ``Engine`` implementations on a dynamically alterable,
 thread-local basis.
     """
 
+    __visit_name__ = 'metadata'
+
     def __init__(self, threadlocal=True, **kwargs):
         if threadlocal:
             self.context = util.ThreadLocal()
@@ -1245,61 +1226,3 @@ class SchemaVisitor(sql.ClauseVisitor):
     """Define the visiting for ``SchemaItem`` objects."""
 
     __traverse_options__ = {'schema_visitor':True}
-
-    def visit_schema(self, schema):
-        """Visit a generic ``SchemaItem``."""
-        pass
-
-    def visit_table(self, table):
-        """Visit a ``Table``."""
-        pass
-
-    def visit_column(self, column):
-        """Visit a ``Column``."""
-        pass
-
-    def visit_foreign_key(self, join):
-        """Visit a ``ForeignKey``."""
-        pass
-
-    def visit_index(self, index):
-        """Visit an ``Index``."""
-        pass
-
-    def visit_passive_default(self, default):
-        """Visit a passive default."""
-        pass
-
-    def visit_column_default(self, default):
-        """Visit a ``ColumnDefault``."""
-        pass
-
-    def visit_column_onupdate(self, onupdate):
-        """Visit a ``ColumnDefault`` with the `for_update` flag set."""
-        pass
-
-    def visit_sequence(self, sequence):
-        """Visit a ``Sequence``."""
-        pass
-
-    def visit_primary_key_constraint(self, constraint):
-        """Visit a ``PrimaryKeyConstraint``."""
-        pass
-
-    def visit_foreign_key_constraint(self, constraint):
-        """Visit a ``ForeignKeyConstraint``."""
-        pass
-
-    def visit_unique_constraint(self, constraint):
-        """Visit a ``UniqueConstraint``."""
-        pass
-
-    def visit_check_constraint(self, constraint):
-        """Visit a ``CheckConstraint``."""
-        pass
-
-    def visit_column_check_constraint(self, constraint):
-        """Visit a ``CheckConstraint`` on a ``Column``."""
-        pass
-
-
index 31aa4788ac5f009beefae0d379bd6c4d76186d0a..afeb7dd6687948711815b62d7f66f5a159982635 100644 (file)
@@ -28,11 +28,10 @@ from sqlalchemy import util, exceptions, logging
 from sqlalchemy import types as sqltypes
 import string, re, random, sets
 
-
 __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
            'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
-           'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
-           'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
+           'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 
+           'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
            'between_', 'bindparam', 'case', 'cast', 'column', 'delete',
            'desc', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
            'insert', 'intersect', 'intersect_all', 'join', 'literal',
@@ -126,7 +125,7 @@ def join(left, right, onclause=None, **kwargs):
 
     return Join(left, right, onclause, **kwargs)
 
-def select(columns=None, whereclause = None, from_obj = [], **kwargs):
+def select(columns=None, whereclause=None, from_obj=[], **kwargs):
     """Returns a ``SELECT`` clause element.
 
     Similar functionality is also available via the ``select()`` method on any
@@ -237,7 +236,7 @@ def select(columns=None, whereclause = None, from_obj = [], **kwargs):
       
     """
 
-    return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
+    return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
 
 def subquery(alias, *args, **kwargs):
     """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
@@ -253,7 +252,7 @@ def subquery(alias, *args, **kwargs):
     return Select(*args, **kwargs).alias(alias)
 
 def insert(table, values = None, **kwargs):
-    """Return an [sqlalchemy.sql#_Insert] clause element.
+    """Return an [sqlalchemy.sql#Insert] clause element.
 
     Similar functionality is available via the ``insert()`` 
     method on [sqlalchemy.schema#Table].
@@ -286,10 +285,10 @@ def insert(table, values = None, **kwargs):
     against the ``INSERT`` statement.
     """
 
-    return _Insert(table, values, **kwargs)
+    return Insert(table, values, **kwargs)
 
 def update(table, whereclause = None, values = None, **kwargs):
-    """Return an [sqlalchemy.sql#_Update] clause element.
+    """Return an [sqlalchemy.sql#Update] clause element.
 
     Similar functionality is available via the ``update()`` 
     method on [sqlalchemy.schema#Table].
@@ -326,10 +325,10 @@ def update(table, whereclause = None, values = None, **kwargs):
     against the ``UPDATE`` statement.
     """
 
-    return _Update(table, whereclause, values, **kwargs)
+    return Update(table, whereclause, values, **kwargs)
 
 def delete(table, whereclause = None, **kwargs):
-    """Return a [sqlalchemy.sql#_Delete] clause element.
+    """Return a [sqlalchemy.sql#Delete] clause element.
 
     Similar functionality is available via the ``delete()`` 
     method on [sqlalchemy.schema#Table].
@@ -343,7 +342,7 @@ def delete(table, whereclause = None, **kwargs):
 
     """
 
-    return _Delete(table, whereclause, **kwargs)
+    return Delete(table, whereclause, **kwargs)
 
 def and_(*clauses):
     """Join a list of clauses together using the ``AND`` operator.
@@ -384,7 +383,7 @@ def between(ctest, cleft, cright):
     provides similar functionality.
     """
 
-    return _BinaryExpression(ctest, and_(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type)), 'BETWEEN')
+    return _BinaryExpression(ctest, and_(_literal_as_binds(cleft, type=ctest.type), _literal_as_binds(cright, type=ctest.type)), 'BETWEEN')
 
 def between_(*args, **kwargs):
     """synonym for [sqlalchemy.sql#between()] (deprecated)."""
@@ -757,13 +756,13 @@ def _compound_select(keyword, *selects, **kwargs):
 def _is_literal(element):
     return not isinstance(element, ClauseElement)
 
-def _literals_as_text(element):
+def _literal_as_text(element):
     if _is_literal(element):
         return _TextClause(unicode(element))
     else:
         return element
 
-def _literals_as_binds(element, name='literal', type=None):
+def _literal_as_binds(element, name='literal', type=None):
     if _is_literal(element):
         if element is None:
             return null()
@@ -860,22 +859,46 @@ class ClauseVisitor(object):
     these options can indicate modifications to the set of 
     elements returned, such as to not return column collections
     (column_collections=False) or to return Schema-level items
-    (schema_visitor=True)."""
+    (schema_visitor=True).
+    
+    """
     __traverse_options__ = {}
-    def traverse(self, obj, stop_on=None):
-        stack = [obj]
-        traversal = []
-        while len(stack) > 0:
-            t = stack.pop()
-            if stop_on is None or t not in stop_on:
-                traversal.insert(0, t)
-                for c in t.get_children(**self.__traverse_options__):
-                    stack.append(c)
-        for target in traversal:
-            v = self
-            while v is not None:
-                target.accept_visitor(v)
-                v = getattr(v, '_next', None)
+    
+    def traverse_single(self, obj):
+        meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+        if meth:
+            return meth(obj)
+            
+    def traverse(self, obj, stop_on=None, clone=False):
+        if clone:
+            obj = obj._clone()
+
+        # entry flag indicates to also call a before-descent "enter_XXXX" method
+        entry = self.__traverse_options__.get('entry', False)
+
+        v = self
+        visitors = []
+        while v is not None:
+            visitors.append(v)
+            v = getattr(v, '_next', None)
+
+        def _trav(obj):
+            if stop_on is not None and obj in stop_on:
+                return
+            if entry:
+                for v in visitors:
+                    meth = getattr(v, "enter_%s" % obj.__visit_name__, None)
+                    if meth:
+                        meth(obj)
+
+            for c in obj.get_children(clone=clone, **self.__traverse_options__):
+                _trav(c)
+
+            for v in visitors:
+                meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+                if meth:
+                    meth(obj)
+        _trav(obj)
         return obj
         
     def chain(self, visitor):
@@ -887,78 +910,6 @@ class ClauseVisitor(object):
             tail = tail._next
         tail._next = visitor
         return self
-        
-    def visit_column(self, column):
-        pass
-    def visit_table(self, table):
-        pass
-    def visit_fromclause(self, fromclause):
-        pass
-    def visit_bindparam(self, bindparam):
-        pass
-    def visit_textclause(self, textclause):
-        pass
-    def visit_compound(self, compound):
-        pass
-    def visit_compound_select(self, compound):
-        pass
-    def visit_binary(self, binary):
-        pass
-    def visit_unary(self, unary):
-        pass
-    def visit_alias(self, alias):
-        pass
-    def visit_select(self, select):
-        pass
-    def visit_join(self, join):
-        pass
-    def visit_null(self, null):
-        pass
-    def visit_clauselist(self, list):
-        pass
-    def visit_calculatedclause(self, calcclause):
-        pass
-    def visit_grouping(self, gr):
-        pass
-    def visit_function(self, func):
-        pass
-    def visit_cast(self, cast):
-        pass
-    def visit_label(self, label):
-        pass
-    def visit_typeclause(self, typeclause):
-        pass
-
-class LoggingClauseVisitor(ClauseVisitor):
-    """extends ClauseVisitor to include debug logging of all traversal.
-    
-    To install this visitor, set logging.DEBUG for 
-    'sqlalchemy.sql.ClauseVisitor' **before** you import the 
-    sqlalchemy.sql module.
-    """
-
-    def traverse(self, obj, stop_on=None):
-        stack = [(obj, "")]
-        traversal = []
-        while len(stack) > 0:
-            (t, indent) = stack.pop()
-            if stop_on is None or t not in stop_on:
-                traversal.insert(0, (t, indent))
-                for c in t.get_children(**self.__traverse_options__):
-                    stack.append((c, indent + "    "))
-        
-        for (target, indent) in traversal:
-            self.logger.debug(indent + repr(target))
-            v = self
-            while v is not None:
-                target.accept_visitor(v)
-                v = getattr(v, '_next', None)
-        return obj
-
-LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
-
-if logging.is_debug_enabled(LoggingClauseVisitor.logger):
-    ClauseVisitor=LoggingClauseVisitor
 
 class NoColumnVisitor(ClauseVisitor):
     """a ClauseVisitor that will not traverse the exported Column 
@@ -971,109 +922,31 @@ class NoColumnVisitor(ClauseVisitor):
     """
     
     __traverse_options__ = {'column_collections':False}
-    
-class Executor(object):
-    """Interface representing a "thing that can produce Compiled objects 
-    and execute them"."""
 
-    def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
-        """Execute a Compiled object."""
-
-        raise NotImplementedError()
-
-    def compiler(self, statement, parameters, **kwargs):
-        """Return a Compiled object for the given statement and parameters."""
-
-        raise NotImplementedError()
-
-class Compiled(ClauseVisitor):
-    """Represent a compiled SQL expression.
-
-    The ``__str__`` method of the ``Compiled`` object should produce
-    the actual text of the statement.  ``Compiled`` objects are
-    specific to their underlying database dialect, and also may
-    or may not be specific to the columns referenced within a
-    particular set of bind parameters.  In no case should the
-    ``Compiled`` object be dependent on the actual values of those
-    bind parameters, even though it may reference those values as
-    defaults.
-    """
-
-    def __init__(self, dialect, statement, parameters, engine=None):
-        """Construct a new ``Compiled`` object.
-
-        statement
-          ``ClauseElement`` to be compiled.
-
-        parameters
-          Optional dictionary indicating a set of bind parameters
-          specified with this ``Compiled`` object.  These parameters
-          are the *default* values corresponding to the
-          ``ClauseElement``'s ``_BindParamClauses`` when the
-          ``Compiled`` is executed.  In the case of an ``INSERT`` or
-          ``UPDATE`` statement, these parameters will also result in
-          the creation of new ``_BindParamClause`` objects for each
-          key and will also affect the generated column list in an
-          ``INSERT`` statement and the ``SET`` clauses of an
-          ``UPDATE`` statement.  The keys of the parameter dictionary
-          can either be the string names of columns or
-          ``_ColumnClause`` objects.
-
-        engine
-          Optional Engine to compile this statement against.
-        """
-        self.dialect = dialect
-        self.statement = statement
-        self.parameters = parameters
-        self.engine = engine
-        self.can_execute = statement.supports_execution()
-
-    def compile(self):
-        self.traverse(self.statement)
-        self.after_compile()
-
-    def __str__(self):
-        """Return the string text of the generated SQL statement."""
-
-        raise NotImplementedError()
 
-    def get_params(self, **params):
-        """Deprecated.  use construct_params().  (supports unicode names)
-        """
-
-        return self.construct_params(params)
-
-    def construct_params(self, params):
-        """Return the bind params for this compiled object.
-
-        Will start with the default parameters specified when this
-        ``Compiled`` object was first constructed, and will override
-        those values with those sent via `**params`, which are
-        key/value pairs.  Each key should match one of the
-        ``_BindParamClause`` objects compiled into this object; either
-        the `key` or `shortname` property of the ``_BindParamClause``.
-        """
-        raise NotImplementedError()
+class _FigureVisitName(type):
+    def __init__(cls, clsname, bases, dict):
+        if not '__visit_name__' in cls.__dict__:
+            m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+            x = m.group(1)
+            x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+            cls.__visit_name__ = x.lower()
+        super(_FigureVisitName, cls).__init__(clsname, bases, dict)
         
-    def execute(self, *multiparams, **params):
-        """Execute this compiled object."""
-
-        e = self.engine
-        if e is None:
-            raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
-        return e.execute_compiled(self, *multiparams, **params)
-
-    def scalar(self, *multiparams, **params):
-        """Execute this compiled object and return the result's scalar value."""
-
-        return self.execute(*multiparams, **params).scalar()
-
 class ClauseElement(object):
     """Base class for elements of a programmatically constructed SQL
     expression.
     """
+    __metaclass__ = _FigureVisitName
+    
+    def _clone(self):
+        # shallow copy.  mutator operations always create
+        # clones of container objects.
+        c = self.__class__.__new__(self.__class__)
+        c.__dict__ = self.__dict__.copy()
+        return c
 
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         """Return objects represented in this ``ClauseElement`` that
         should be added to the ``FROM`` list of a query, when this
         ``ClauseElement`` is placed in the column clause of a
@@ -1082,7 +955,7 @@ class ClauseElement(object):
 
         raise NotImplementedError(repr(self))
 
-    def _hide_froms(self):
+    def _hide_froms(self, **modifiers):
         """Return a list of ``FROM`` clause elements which this
         ``ClauseElement`` replaces.
         """
@@ -1098,18 +971,16 @@ class ClauseElement(object):
 
         return self is other
 
-    def accept_visitor(self, visitor):
-        """Accept a ``ClauseVisitor`` and call the appropriate
-        ``visit_xxx`` method.
-        """
-
-        raise NotImplementedError(repr(self))
-    
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
         """return immediate child elements of this ``ClauseElement``.
         
         this is used for visit traversal.
         
+        clone indicates child items should be _cloned(), replacing
+        the elements contained by this element, and the cloned
+        copy returned.  this allows modifying traversals
+        to take place.
+        
         \**kwargs may contain flags that change the collection
         that is returned, for example to return a subset of items
         in order to cut down on larger traversals, or to return 
@@ -1127,18 +998,6 @@ class ClauseElement(object):
 
         return False
 
-    def copy_container(self):
-        """Return a copy of this ``ClauseElement``, if this
-        ``ClauseElement`` contains other ``ClauseElements``.
-
-        If this ``ClauseElement`` is not a container, it should return
-        self.  This is used to create copies of expression trees that
-        still reference the same *leaf nodes*.  The new structure can
-        then be restructured without affecting the original.
-        """
-
-        return self
-
     def _find_engine(self):
         """Default strategy for locating an engine within the clause element.
 
@@ -1429,9 +1288,6 @@ class Selectable(ClauseElement):
     def _selectable(self):
         return self
 
-    def accept_visitor(self, visitor):
-        raise NotImplementedError(repr(self))
-
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
@@ -1589,19 +1445,18 @@ class FromClause(Selectable):
     clause of a ``SELECT`` statement.
     """
 
+    __visit_name__ = 'fromclause'
+    
     def __init__(self, name=None):
         self.name = name
 
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         # this could also be [self], at the moment it doesnt matter to the Select object
         return []
 
     def default_order_by(self):
         return [self.oid_column]
 
-    def accept_visitor(self, visitor):
-        visitor.visit_fromclause(self)
-
     def count(self, whereclause=None, **params):
         if len(self.primary_key):
             col = list(self.primary_key)[0]
@@ -1643,6 +1498,13 @@ class FromClause(Selectable):
         FindCols().traverse(self)
         return ret
 
+    def is_derived_from(self, fromclause):
+        """return True if this FromClause is 'derived' from the given FromClause.
+        
+        An example would be an Alias of a Table is derived from that Table."""
+        
+        return False
+        
     def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
         """Given a ``ColumnElement``, return the exported
         ``ColumnElement`` object from this ``Selectable`` which
@@ -1701,6 +1563,15 @@ class FromClause(Selectable):
             self._export_columns()
             return getattr(self, name)
 
+    def _clone_from_clause(self):
+        # delete all the "generated" collections of columns for a newly cloned FromClause,
+        # so that they will be re-derived from the item.
+        # this is because FromClause subclasses, when cloned, need to reestablish new "proxied" 
+        # columns that are linked to the new item
+        for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'):
+            if hasattr(self, attr):
+                delattr(self, attr)
+
     columns = property(lambda s:s._get_exported_attribute('_columns'))
     c = property(lambda s:s._get_exported_attribute('_columns'))
     primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
@@ -1731,7 +1602,7 @@ class FromClause(Selectable):
         self._primary_key = ColumnCollection()
         self._foreign_keys = util.Set()
         self._orig_cols = {}
-        for co in self._adjusted_exportable_columns():
+        for co in self._flatten_exportable_columns():
             cp = self._proxy_column(co)
             for ci in cp.orig_set:
                 # note that some ambiguity is raised here, whereby a selectable might have more than 
@@ -1741,13 +1612,13 @@ class FromClause(Selectable):
             for ci in self.oid_column.orig_set:
                 self._orig_cols[ci] = self.oid_column
     
-    def _adjusted_exportable_columns(self):
+    def _flatten_exportable_columns(self):
         """return the list of ColumnElements represented within this FromClause's _exportable_columns"""
         export = self._exportable_columns()
         for column in export:
-            try:
+            if hasattr(column, '_selectable'):
                 s = column._selectable()
-            except AttributeError:
+            else:
                 continue
             for co in s.columns:
                 yield co
@@ -1764,6 +1635,8 @@ class _BindParamClause(ClauseElement, _CompareMixin):
     Public constructor is the ``bindparam()`` function.
     """
 
+    __visit_name__ = 'bindparam'
+    
     def __init__(self, key, value, shortname=None, type=None, unique=False):
         """Construct a _BindParamClause.
 
@@ -1805,15 +1678,9 @@ class _BindParamClause(ClauseElement, _CompareMixin):
         self.unique = unique
         self.type = sqltypes.to_instance(type)
 
-    def accept_visitor(self, visitor):
-        visitor.visit_bindparam(self)
-
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         return []
 
-    def copy_container(self):
-        return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique)
-
     def typeprocess(self, value, dialect):
         return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
 
@@ -1836,13 +1703,12 @@ class _TypeClause(ClauseElement):
     Used by the ``Case`` statement.
     """
 
+    __visit_name__ = 'typeclause'
+    
     def __init__(self, type):
         self.type = type
 
-    def accept_visitor(self, visitor):
-        visitor.visit_typeclause(self)
-
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         return []
 
 class _TextClause(ClauseElement):
@@ -1851,6 +1717,8 @@ class _TextClause(ClauseElement):
     Public constructor is the ``text()`` function.
     """
 
+    __visit_name__ = 'textclause'
+    
     def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
         self._engine = engine
         self.bindparams = {}
@@ -1879,13 +1747,13 @@ class _TextClause(ClauseElement):
 
     columns = property(lambda s:[])
 
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.bindparams = [b._clone() for b in self.bindparams]
+            
         return self.bindparams.values()
 
-    def accept_visitor(self, visitor):
-        visitor.visit_textclause(self)
-
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         return []
 
     def supports_execution(self):
@@ -1900,10 +1768,7 @@ class _Null(ColumnElement):
     def __init__(self):
         self.type = sqltypes.NULLTYPE
 
-    def accept_visitor(self, visitor):
-        visitor.visit_null(self)
-
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         return []
 
 class ClauseList(ClauseElement):
@@ -1911,14 +1776,16 @@ class ClauseList(ClauseElement):
 
     By default, is comma-separated, such as a column listing.
     """
-
+    __visit_name__ = 'clauselist'
+    
     def __init__(self, *clauses, **kwargs):
         self.clauses = []
         self.operator = kwargs.pop('operator', ',')
         self.group = kwargs.pop('group', True)
         self.group_contents = kwargs.pop('group_contents', True)
         for c in clauses:
-            if c is None: continue
+            if c is None: 
+                continue
             self.append(c)
 
     def __iter__(self):
@@ -1926,10 +1793,6 @@ class ClauseList(ClauseElement):
     def __len__(self):
         return len(self.clauses)
         
-    def copy_container(self):
-        clauses = [clause.copy_container() for clause in self.clauses]
-        return ClauseList(operator=self.operator, *clauses)
-
     def self_group(self, against=None):
         if self.group:
             return _Grouping(self)
@@ -1940,20 +1803,20 @@ class ClauseList(ClauseElement):
         # TODO: not sure if i like the 'group_contents' flag.  need to define the difference between
         # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists.  flatten() method ?
         if self.group_contents:
-            self.clauses.append(_literals_as_text(clause).self_group(against=self.operator))
+            self.clauses.append(_literal_as_text(clause).self_group(against=self.operator))
         else:
-            self.clauses.append(_literals_as_text(clause))
+            self.clauses.append(_literal_as_text(clause))
 
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.clauses = [clause._clone() for clause in self.clauses]
+            
         return self.clauses
 
-    def accept_visitor(self, visitor):
-        visitor.visit_clauselist(self)
-
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         f = []
         for c in self.clauses:
-            f += c._get_from_objects()
+            f += c._get_from_objects(**modifiers)
         return f
 
     def self_group(self, against=None):
@@ -1984,7 +1847,8 @@ class _CalculatedClause(ColumnElement):
     Extends ``ColumnElement`` to provide column-level comparison
     operators.
     """
-
+    __visit_name__ = 'calculatedclause'
+    
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type', None))
@@ -1998,17 +1862,13 @@ class _CalculatedClause(ColumnElement):
             
     key = property(lambda self:self.name or "_calc_")
 
-    def copy_container(self):
-        clauses = [clause.copy_container() for clause in self.clauses]
-        return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
-
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.clause_expr = self.clause_expr._clone()
         return self.clause_expr,
         
-    def accept_visitor(self, visitor):
-        visitor.visit_calculatedclause(self)
-    def _get_from_objects(self):
-        return self.clauses._get_from_objects()
+    def _get_from_objects(self, **modifiers):
+        return self.clauses._get_from_objects(**modifiers)
 
     def _bind_param(self, obj):
         return _BindParamClause(self.name, obj, type=self.type, unique=True)
@@ -2043,18 +1903,16 @@ class _Function(_CalculatedClause, FromClause):
 
     key = property(lambda self:self.name)
 
-
-    def append(self, clause):
-        self.clauses.append(_literals_as_binds(clause, self.name))
-
-    def copy_container(self):
-        clauses = [clause.copy_container() for clause in self.clauses]
-        return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self._clone_from_clause()
+        return _CalculatedClause.get_children(self, clone=clone, **kwargs)
         
-    def accept_visitor(self, visitor):
-        visitor.visit_function(self)
+    def append(self, clause):
+        self.clauses.append(_literal_as_binds(clause, self.name))
 
 class _Cast(ColumnElement):
+
     def __init__(self, clause, totype, **kwargs):
         if not hasattr(clause, 'label'):
             clause = literal(clause)
@@ -2062,13 +1920,15 @@ class _Cast(ColumnElement):
         self.clause = clause
         self.typeclause = _TypeClause(self.type)
 
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.clause = self.clause._clone()
+            self.typeclause = self.typeclause._clone()
+            
         return self.clause, self.typeclause
-    def accept_visitor(self, visitor):
-        visitor.visit_cast(self)
 
-    def _get_from_objects(self):
-        return self.clause._get_from_objects()
+    def _get_from_objects(self, **modifiers):
+        return self.clause._get_from_objects(**modifiers)
 
     def _make_proxy(self, selectable, name=None):
         if name is not None:
@@ -2085,22 +1945,18 @@ class _UnaryExpression(ColumnElement):
         self.operator = operator
         self.modifier = modifier
         
-        self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier)
+        self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
         self.type = sqltypes.to_instance(type)
         self.negate = negate
         
-    def copy_container(self):
-        return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate)
-
-    def _get_from_objects(self):
-        return self.element._get_from_objects()
+    def _get_from_objects(self, **modifiers):
+        return self.element._get_from_objects(**modifiers)
 
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.element = self.element._clone()
         return self.element,
 
-    def accept_visitor(self, visitor):
-        visitor.visit_unary(self)
-
     def compare(self, other):
         """Compare this ``_UnaryClause`` against the given ``ClauseElement``."""
 
@@ -2109,6 +1965,7 @@ class _UnaryExpression(ColumnElement):
             self.modifier == other.modifier and 
             self.element.compare(other.element)
         )
+
     def _negate(self):
         if self.negate is not None:
             return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
@@ -2120,24 +1977,22 @@ class _BinaryExpression(ColumnElement):
     """Represent an expression that is ``LEFT <operator> RIGHT``."""
     
     def __init__(self, left, right, operator, type=None, negate=None):
-        self.left = _literals_as_text(left).self_group(against=operator)
-        self.right = _literals_as_text(right).self_group(against=operator)
+        self.left = _literal_as_text(left).self_group(against=operator)
+        self.right = _literal_as_text(right).self_group(against=operator)
         self.operator = operator
         self.type = sqltypes.to_instance(type)
         self.negate = negate
 
-    def copy_container(self):
-        return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
-
-    def _get_from_objects(self):
-        return self.left._get_from_objects() + self.right._get_from_objects()
+    def _get_from_objects(self, **modifiers):
+        return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
 
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.left = self.left._clone()
+            self.right = self.right._clone()
+            
         return self.left, self.right
 
-    def accept_visitor(self, visitor):
-        visitor.visit_binary(self)
-
     def compare(self, other):
         """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
 
@@ -2159,13 +2014,15 @@ class _BinaryExpression(ColumnElement):
             return super(_BinaryExpression, self)._negate()
 
 class _Exists(_UnaryExpression):
+    __visit_name__ = _UnaryExpression.__visit_name__
+    
     def __init__(self, *args, **kwargs):
         kwargs['correlate'] = True
         s = select(*args, **kwargs).self_group()
         _UnaryExpression.__init__(self, s, operator="EXISTS")
 
-    def _hide_froms(self):
-        return self._get_from_objects()
+    def _hide_froms(self, **modifiers):
+        return self._get_from_objects(**modifiers)
 
 class Join(FromClause):
     """represent a ``JOIN`` construct between two ``FromClause``
@@ -2192,7 +2049,7 @@ class Join(FromClause):
 
     def _init_primary_key(self):
         pkcol = util.OrderedSet()
-        for col in self._adjusted_exportable_columns():
+        for col in self._flatten_exportable_columns():
             if col.primary_key:
                 pkcol.add(col)
         for col in list(pkcol):
@@ -2213,6 +2070,16 @@ class Join(FromClause):
             self._foreign_keys.add(f)
         return column
 
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self._clone_from_clause()
+            self.left = self.left._clone()
+            self.right = self.right._clone()
+            self.onclause = self.onclause._clone()
+            self.__folded_equivalents = None
+            self._init_primary_key()
+        return self.left, self.right, self.onclause
+
     def _match_primaries(self, primary, secondary):
         crit = []
         constraints = util.Set()
@@ -2300,12 +2167,6 @@ class Join(FromClause):
             
         return select(collist, whereclause, from_obj=[self], **kwargs)
 
-    def get_children(self, **kwargs):
-        return self.left, self.right, self.onclause
-
-    def accept_visitor(self, visitor):
-        visitor.visit_join(self)
-
     engine = property(lambda s:s.left.engine or s.right.engine)
 
     def alias(self, name=None):
@@ -2316,11 +2177,11 @@ class Join(FromClause):
 
         return self.select(use_labels=True, correlate=False).alias(name)
 
-    def _hide_froms(self):
-        return self.left._get_from_objects() + self.right._get_from_objects()
+    def _hide_froms(self, **modifiers):
+        return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
 
-    def _get_from_objects(self):
-        return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
+    def _get_from_objects(self, **modifiers):
+        return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
 
 class Alias(FromClause):
     """represent an alias, as typically applied to any 
@@ -2351,6 +2212,14 @@ class Alias(FromClause):
         self.encodedname = alias.encode('ascii', 'backslashreplace')
         self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
 
+    def is_derived_from(self, fromclause):
+        x = self.selectable
+        while isinstance(x, Alias):
+            x = x.selectable
+            if x is fromclause:
+                return True
+        return False
+
     def supports_execution(self):
         return self.original.supports_execution()
 
@@ -2367,14 +2236,18 @@ class Alias(FromClause):
         #return self.selectable._exportable_columns()
         return self.selectable.columns
 
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self._clone_from_clause()
+            self.selectable = self.selectable._clone()
+            baseselectable = self.selectable
+            while isinstance(baseselectable, Alias):
+                baseselectable = baseselectable.selectable
+            self.original = baseselectable
         for c in self.c:
             yield c
         yield self.selectable
         
-    def accept_visitor(self, visitor):
-        visitor.visit_alias(self)
-
     def _get_from_objects(self):
         return [self]
 
@@ -2392,17 +2265,16 @@ class _Grouping(ColumnElement):
     _label = property(lambda s: s.elem._label)
     orig_set = property(lambda s:s.elem.orig_set)
     
-    def copy_container(self):
-        return _Grouping(self.elem.copy_container())
-        
-    def accept_visitor(self, visitor):
-        visitor.visit_grouping(self)
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.elem = self.elem._clone()
         return self.elem,
-    def _hide_froms(self):
-        return self.elem._hide_froms()
-    def _get_from_objects(self):
-        return self.elem._get_from_objects()
+        
+    def _hide_froms(self, **modifiers):
+        return self.elem._hide_froms(**modifiers)
+        
+    def _get_from_objects(self, **modifiers):
+        return self.elem._get_from_objects(**modifiers)
         
 class _Label(ColumnElement):
     """represent a label, as typically applied to any column-level element
@@ -2429,17 +2301,16 @@ class _Label(ColumnElement):
     def _compare_self(self):
         return self.obj
     
-    def get_children(self, **kwargs):
+    def get_children(self, clone=False, **kwargs):
+        if clone:
+            self.obj = self.obj._clone()
         return self.obj,
 
-    def accept_visitor(self, visitor):
-        visitor.visit_label(self)
+    def _get_from_objects(self, **modifiers):
+        return self.obj._get_from_objects(**modifiers)
 
-    def _get_from_objects(self):
-        return self.obj._get_from_objects()
-
-    def _hide_froms(self):
-        return self.obj._hide_froms()
+    def _hide_froms(self, **modifiers):
+        return self.obj._hide_froms(**modifiers)
         
     def _make_proxy(self, selectable, name = None):
         if isinstance(self.obj, Selectable):
@@ -2489,7 +2360,11 @@ class _ColumnClause(ColumnElement):
         self.__label = None
         self.case_sensitive = case_sensitive
         self.is_literal = is_literal
-
+    
+    def _clone(self):
+        # ColumnClause is immutable
+        return self
+        
     def _get_label(self):
         """Generate a 'label' for this column.
         
@@ -2527,10 +2402,7 @@ class _ColumnClause(ColumnElement):
         else:
             return super(_ColumnClause, self).label(name)
             
-    def accept_visitor(self, visitor):
-        visitor.visit_column(self)
-
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         if self.table is not None:
             return [self.table]
         else:
@@ -2575,6 +2447,10 @@ class TableClause(FromClause):
             self.append_column(c)
         self._oid_column = _ColumnClause('oid', self, _is_oid=True)
 
+    def _clone(self):
+        # TableClause is immutable
+        return self
+
     def named_with_column(self):
         return True
 
@@ -2603,9 +2479,6 @@ class TableClause(FromClause):
         else:
             return []
             
-    def accept_visitor(self, visitor):
-        visitor.visit_table(self)
-
     def _exportable_columns(self):
         raise NotImplementedError()
 
@@ -2640,67 +2513,95 @@ class TableClause(FromClause):
     def delete(self, whereclause = None):
         return delete(self, whereclause)
 
-    def _get_from_objects(self):
+    def _get_from_objects(self, **modifiers):
         return [self]
 
 class _SelectBaseMixin(object):
     """Base class for ``Select`` and ``CompoundSelects``."""
 
+    def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, connectable=None, scalar=False, engine=None):
+        self.use_labels = use_labels
+        self.for_update = for_update
+        self._limit = limit
+        self._offset = offset
+        self._engine = connectable or engine
+        self.is_scalar = scalar
+        if self.is_scalar:
+            # allow corresponding_column to return None
+            self.orig_set = util.Set()
+        
+        self.append_order_by(*util.to_list(order_by, []))
+        self.append_group_by(*util.to_list(group_by, []))
+        
     def supports_execution(self):
         return True
 
+    def _generate(self):
+        s = self._clone()
+        s._clone_from_clause()
+        return s
+    
+    def limit(self, limit):
+        s = self._generate()
+        s._limit = limit
+        return s
+    
+    def offset(self, offset):
+        s = self._generate()
+        s._offset = offset
+        return s
+    
     def order_by(self, *clauses):
-        if len(clauses) == 1 and clauses[0] is None:
-            self.order_by_clause = ClauseList()
-        elif getattr(self, 'order_by_clause', None):
-            self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses)))
-        else:
-            self.order_by_clause = ClauseList(*clauses)
+        s = self._generate()
+        s.append_order_by(*clauses)
+        return s
 
     def group_by(self, *clauses):
-        if len(clauses) == 1 and clauses[0] is None:
-            self.group_by_clause = ClauseList()
-        elif getattr(self, 'group_by_clause', None):
-            self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
+        s = self._generate()
+        s.append_group_by(*clauses)
+        return s
+
+    def append_order_by(self, *clauses):
+        if clauses == [None]:
+            self._order_by_clause = ClauseList()
         else:
-            self.group_by_clause = ClauseList(*clauses)
+            if getattr(self, '_order_by_clause', None):
+                clauses = list(self._order_by_clause) + list(clauses)
+            self._order_by_clause = ClauseList(*clauses)
 
+    def append_group_by(self, *clauses):
+        if clauses == [None]:
+            self._group_by_clause = ClauseList()
+        else:
+            if getattr(self, '_group_by_clause', None):
+                clauses = list(self._group_by_clause) + list(clauses)
+            self._group_by_clause = ClauseList(*clauses)
+            
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
-    def _get_from_objects(self):
-        if self.is_where or self.is_scalar:
+    def _get_from_objects(self, is_where=False, **modifiers):
+        if is_where or self.is_scalar:
             return []
         else:
             return [self]
 
 class CompoundSelect(_SelectBaseMixin, FromClause):
     def __init__(self, keyword, *selects, **kwargs):
-        _SelectBaseMixin.__init__(self)
+        self._should_correlate = kwargs.pop('correlate', False)
         self.keyword = keyword
-        self.use_labels = kwargs.pop('use_labels', False)
-        self.should_correlate = kwargs.pop('correlate', False)
-        self.for_update = kwargs.pop('for_update', False)
-        self.nowait = kwargs.pop('nowait', False)
-        self.limit = kwargs.pop('limit', None)
-        self.offset = kwargs.pop('offset', None)
-        self.is_compound = True
-        self.is_where = False
-        self.is_scalar = False
-        self.is_subquery = False
-
-        self.selects = selects
+        self.selects = []
 
         # some DBs do not like ORDER BY in the inner queries of a UNION, etc.
         for s in selects:
-            s.order_by(None)
+            if len(s._order_by_clause):
+                s = s.order_by(None)
+            self.selects.append(s)
 
-        self.group_by(*kwargs.pop('group_by', [None]))
-        self.order_by(*kwargs.pop('order_by', [None]))
-        if len(kwargs):
-            raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
         self._col_map = {}
 
+        _SelectBaseMixin.__init__(self, **kwargs)
+
     name = property(lambda s:s.keyword + " statement")
 
     def self_group(self, against=None):
@@ -2728,12 +2629,18 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         col.orig_set = colset
         return col
 
-    def get_children(self, column_collections=True, **kwargs):
-        return (column_collections and list(self.c) or []) + \
-            [self.order_by_clause, self.group_by_clause] + list(self.selects)
-    def accept_visitor(self, visitor):
-        visitor.visit_compound_select(self)
+    def get_children(self, clone=False, column_collections=True, **kwargs):
+        if clone:
+            self._clone_from_clause()
+            self._col_map = {}
+            self.selects = [s._clone() for s in self.selects]
+            for attr in ('_order_by_clause', '_group_by_clause'):
+                if getattr(self, attr) is not None:
+                    setattr(self, attr, getattr(self, attr)._clone())
 
+        return (column_collections and list(self.c) or []) + \
+            [self._order_by_clause, self._group_by_clause] + list(self.selects)
+            
     def _find_engine(self):
         for s in self.selects:
             e = s._find_engine()
@@ -2748,127 +2655,212 @@ class Select(_SelectBaseMixin, FromClause):
     
     """
 
-    def __init__(self, columns=None, whereclause=None, from_obj=[],
-                 order_by=None, group_by=None, having=None,
-                 use_labels=False, distinct=False, for_update=False,
-                 engine=None, limit=None, offset=None, scalar=False,
-                 correlate=True):
+    def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, **kwargs):
         """construct a Select object.
         
         The public constructor for Select is the [sqlalchemy.sql#select()] function; 
         see that function for argument descriptions.
         """
-        _SelectBaseMixin.__init__(self)
-        self.__froms = util.OrderedSet()
-        self.__hide_froms = util.Set([self])
-        self.use_labels = use_labels
-        self.whereclause = None
-        self.having = None
-        self._engine = engine
-        self.limit = limit
-        self.offset = offset
-        self.for_update = for_update
-        self.is_compound = False
         
-        # indicates that this select statement should not expand its columns
-        # into the column clause of an enclosing select, and should instead
-        # act like a single scalar column
-        self.is_scalar = scalar
-        if scalar:
-            # allow corresponding_column to return None
-            self.orig_set = util.Set()
-            
-        # indicates if this select statement, as a subquery, should automatically correlate
-        # its FROM clause to that of an enclosing select, update, or delete statement.
-        # note that the "correlate" method can be used to explicitly add a value to be correlated.
-        self.should_correlate = correlate
-
-        # indicates if this select statement is a subquery inside another query
-        self.is_subquery = False
-
-        # indicates if this select statement is in the from clause of another query
-        self.is_selected_from = False
+        self._should_correlate = correlate
+        self._distinct = distinct
 
-        # indicates if this select statement is a subquery as a criterion
-        # inside of a WHERE clause
-        self.is_where = False
-
-        self.distinct = distinct
         self._raw_columns = []
-        self.__correlated = {}
-        self.__correlator = Select._CorrelatedVisitor(self, False)
-        self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
-        self.__fromvisitor = Select._FromVisitor(self)
-
-        
-        self.order_by_clause = self.group_by_clause = None
+        self.__correlate = util.Set()
+        self._froms = util.OrderedSet()
+        self._whereclause = None
+        self._having = None
         
         if columns is not None:
             for c in columns:
                 self.append_column(c)
 
-        if order_by:
-            order_by = util.to_list(order_by)
-        if group_by:
-            group_by = util.to_list(group_by)
-        self.order_by(*(order_by or [None]))
-        self.group_by(*(group_by or [None]))
-        for c in self.order_by_clause:
-            self.__correlator.traverse(c)
-        for c in self.group_by_clause:
-            self.__correlator.traverse(c)
-
-        for f in from_obj:
-            self.append_from(f)
-
-        # whereclauses must be appended after the columns/FROM, since it affects
-        # the correlation of subqueries.  see test/sql/select.py SelectTest.testwheresubquery
+        if from_obj is not None:
+            for f in from_obj:
+                self.append_from(f)
+
         if whereclause is not None:
             self.append_whereclause(whereclause)
+            
         if having is not None:
             self.append_having(having)
 
+        _SelectBaseMixin.__init__(self, **kwargs)
 
-    class _CorrelatedVisitor(NoColumnVisitor):
-        """Visit a clause, locate any ``Select`` clauses, and tell
-        them that they should correlate their ``FROM`` list to that of
-        their parent.
-        """
-
-        def __init__(self, select, is_where):
-            NoColumnVisitor.__init__(self)
-            self.select = select
-            self.is_where = is_where
-
-        def visit_compound_select(self, cs):
-            self.visit_select(cs)
-
-        def visit_column(self, c):
-            pass
+    def get_display_froms(self, correlation_state=None):
+        froms = util.Set()
+        hide_froms = util.Set()
+        
+        for col in self._raw_columns:
+            for f in col._hide_froms():
+                hide_froms.add(f)
+            for f in col._get_from_objects():
+                froms.add(f)
 
-        def visit_table(self, c):
-            pass
+        if self._whereclause is not None:
+            for f in self._whereclause._get_from_objects(is_where=True):
+                froms.add(f)
+        
+        for elem in self._froms:
+            froms.add(elem)
+            for f in elem._get_from_objects():
+                froms.add(f)
 
-        def visit_select(self, select):
-            if select is self.select:
-                return
-            select.is_where = self.is_where
-            select.is_subquery = True
-            if not select.should_correlate:
-                return
-            [select.correlate(x) for x in self.select._Select__froms]
+        for elem in froms:
+            for f in elem._hide_froms():
+                hide_froms.add(f)
 
-    class _FromVisitor(NoColumnVisitor):
-        def __init__(self, select):
-            NoColumnVisitor.__init__(self)
-            self.select = select
+        froms = froms.difference(hide_froms)
+        
+        if len(froms) > 1:
+            corr = self.__correlate
+            if correlation_state is not None:
+                corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+            return froms.difference(corr)
+        else:
+            return froms
+    
+    def locate_all_froms(self):
+        froms = util.Set()
+        for col in self._raw_columns:
+            for f in col._get_from_objects():
+                froms.add(f)
+
+        if self._whereclause is not None:
+            for f in self._whereclause._get_from_objects(is_where=True):
+                froms.add(f)
+        
+        for elem in self._froms:
+            froms.add(elem)
+            for f in elem._get_from_objects():
+                froms.add(f)
+        return froms
+        
+    def calculate_correlations(self, correlation_state):
+        if self not in correlation_state:
+            correlation_state[self] = {}
+
+        display_froms = self.get_display_froms(correlation_state)
+        
+        class CorrelatedVisitor(NoColumnVisitor):
+            def __init__(self, is_where=False, is_column=False, is_from=False):
+                self.is_where = is_where
+                self.is_column = is_column
+                self.is_from = is_from
+                
+            def visit_compound_select(self, cs):
+                self.visit_select(cs)
+
+            def visit_select(s, select):
+                if select not in correlation_state:
+                    correlation_state[select] = {}
+                    
+                if select is self:
+                    return
+                    
+                select_state = correlation_state[select]
+                if s.is_from:
+                    select_state['is_selected_from'] = True
+                if s.is_where:
+                    select_state['is_where'] = True
+                select_state['is_subquery'] = True
+
+                if select._should_correlate:
+                    corr = select_state.setdefault('correlate', util.Set())
+                    # not crazy about this part.  need to be clearer on what elements in the
+                    # subquery correspond to elements in the enclosing query.
+                    for f in display_froms:
+                        corr.add(f)
+                        for f2 in f._get_from_objects():
+                            corr.add(f2)
+        
+        col_vis = CorrelatedVisitor(is_column=True)
+        where_vis = CorrelatedVisitor(is_where=True)
+        from_vis = CorrelatedVisitor(is_from=True)
+    
+        for col in self._raw_columns:
+            col_vis.traverse(col)
+            for f in col._get_from_objects():
+                if f is not self:
+                    from_vis.traverse(f)
+
+        for col in list(self._order_by_clause) + list(self._group_by_clause):
+            col_vis.traverse(col)
+            
+        if self._whereclause is not None:
+            where_vis.traverse(self._whereclause)
+            for f in self._whereclause._get_from_objects(is_where=True):
+                if f is not self:
+                    from_vis.traverse(f)
+                
+        for elem in self._froms:
+            from_vis.traverse(elem)
+
+    def _get_inner_columns(self):
+        for c in self._raw_columns:
+            if hasattr(c, '_selectable'):
+                for co in c._selectable().columns:
+                    yield co
+            else:
+                yield c
+            
+    inner_columns = property(_get_inner_columns)
+    
+    def get_children(self, clone=False, column_collections=True, **kwargs):
+        if clone:
+            self._clone_from_clause()
+            self._raw_columns = [c._clone() for c in self._raw_columns]
+            self._recorrelate_froms([f._clone() for f in self._froms])
+            for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
+                if getattr(self, attr) is not None:
+                    setattr(self, attr, getattr(self, attr)._clone())
+        
+        return (column_collections and list(self.columns) or []) + \
+            list(self._froms) + \
+            [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+
+    def _recorrelate_froms(self, froms):
+        newcorrelate = util.Set()
+        for f in froms:
+            if f in self.__correlate:
+                newcorrelate.add(cl)
+                self.__correlate.remove(f)
+        self.__correlate = self.__correlate.union(newcorrelate)
+        self._froms = froms
+        
+    def column(self, column):
+        s = self._generate()
+        s.append_column(column)
+        return s
+    
+    def where(self, whereclause):
+        s = self._generate()
+        s.append_whereclause(whereclause)
+        return s
+    
+    def having(self, having):
+        s = self._generate()
+        s.append_having(having)
+        return s
+    
+    def distinct(self):
+        s = self._generate()
+        s.distinct = True
+        return s
+
+    def select_from(self, fromclause):
+        s = self._generate()
+        s.append_from(fromclause)
+        return s
+    
+    def correlate_to(self, fromclause):
+        s = self._generate()
+        s.append_correlation(fromclause)
+        return s
+    
+    def append_correlation(self, fromclause):
+        self.__correlate.add(fromclause)
             
-        def visit_select(self, select):
-            if select is self.select:
-                return
-            select.is_selected_from = True
-            select.is_subquery = True
-
     def append_column(self, column):
         if _is_literal(column):
             column = literal_column(str(column))
@@ -2878,22 +2870,26 @@ class Select(_SelectBaseMixin, FromClause):
 
         self._raw_columns.append(column)
 
-        if self.is_scalar and not hasattr(self, 'type'):
-            self.type = column.type
-        
-        # if the column is a Select statement itself,
-        # accept visitor
-        self.__correlator.traverse(column)
+    def append_whereclause(self, whereclause):
+        if self._whereclause  is not None:
+            self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
+        else:
+            self._whereclause = _literal_as_text(whereclause)
+            
+    def append_having(self, having):
+        if self._having is not None:
+            self._having = and_(self._having, _literal_as_text(having))
+        else:
+            self._having = _literal_as_text(having)
 
-        # visit the FROM objects of the column looking for more Selects
-        for f in column._get_from_objects():
-            if f is not self:
-                self.__correlator.traverse(f)
-        self._process_froms(column, False)
+    def append_from(self, fromclause):
+        if _is_literal(fromclause):
+            fromclause = FromClause(fromclause)
+        self._froms.add(fromclause)
 
     def _make_proxy(self, selectable, name):
         if self.is_scalar:
-            return self._raw_columns[0]._make_proxy(selectable, name)
+            return list(self.inner_columns)[0]._make_proxy(selectable, name)
         else:
             raise exceptions.InvalidRequestError("Not a scalar select statement")
 
@@ -2903,6 +2899,13 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             return label(name, self)
 
+    def _get_type(self):
+        if self.is_scalar:
+            return list(self.inner_columns)[0].type
+        else:
+            return None
+    type = property(_get_type)
+
     def _exportable_columns(self):
         return [c for c in self._raw_columns if isinstance(c, Selectable)]
         
@@ -2912,51 +2915,11 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             return column._make_proxy(self)
 
-    def _process_froms(self, elem, asfrom):
-        for f in elem._get_from_objects():
-            self.__fromvisitor.traverse(f)
-            self.__froms.add(f)
-        if asfrom:
-            self.__froms.add(elem)
-        for f in elem._hide_froms():
-            self.__hide_froms.add(f)
-
     def self_group(self, against=None):
         return _Grouping(self)
-    
-    def append_whereclause(self, whereclause):
-        self._append_condition('whereclause', whereclause)
-
-    def append_having(self, having):
-        self._append_condition('having', having)
-
-    def _append_condition(self, attribute, condition):
-        if type(condition) == str:
-            condition = _TextClause(condition)
-        self.__wherecorrelator.traverse(condition)
-        self._process_froms(condition, False)
-        if getattr(self, attribute) is not None:
-            setattr(self, attribute, and_(getattr(self, attribute), condition))
-        else:
-            setattr(self, attribute, condition)
-
-    def correlate(self, from_obj):
-        """Given a ``FROM`` object, correlate this ``SELECT`` statement to it.
-
-        This basically means the given from object will not come out
-        in this select statement's ``FROM`` clause when printed.
-        """
-
-        self.__correlated[from_obj] = from_obj
-
-    def append_from(self, fromclause):
-        if type(fromclause) == str:
-            fromclause = FromClause(fromclause)
-        self.__correlator.traverse(fromclause)
-        self._process_froms(fromclause, True)
 
     def _locate_oid_column(self):
-        for f in self.__froms:
+        for f in self.locate_all_froms():
             if f is self:
                 # we might be in our own _froms list if a column with us as the parent is attached,
                 # which includes textual columns.
@@ -2967,25 +2930,6 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             return None
 
-    def _calc_froms(self):
-        f = self.__froms.difference(self.__hide_froms)
-        if (len(f) > 1):
-            return f.difference(self.__correlated)
-        else:
-            return f
-
-    froms = property(_calc_froms,
-                     doc="""A collection containing all elements
-                     of the ``FROM`` clause.""")
-    
-    def get_children(self, column_collections=True, **kwargs):
-        return (column_collections and list(self.columns) or []) + \
-            list(self.froms) + \
-            [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None]
-
-    def accept_visitor(self, visitor):
-        visitor.visit_select(self)
-
     def union(self, other, **kwargs):
         return union(self, other, **kwargs)
 
@@ -2999,7 +2943,7 @@ class Select(_SelectBaseMixin, FromClause):
 
         if self._engine is not None:
             return self._engine
-        for f in self.__froms:
+        for f in self._froms:
             if f is self:
                 continue
             e = f.engine
@@ -3024,20 +2968,24 @@ class _UpdateBase(ClauseElement):
     def supports_execution(self):
         return True
 
-    class _SelectCorrelator(NoColumnVisitor):
-        def __init__(self, table):
-            NoColumnVisitor.__init__(self)
-            self.table = table
-            
-        def visit_select(self, select):
-            if select.should_correlate:
-                select.correlate(self.table)
-    
-    def _process_whereclause(self, whereclause):
-        if whereclause is not None:
-            _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
-        return whereclause
-        
+    def calculate_correlations(self, correlate_state):
+        class SelectCorrelator(NoColumnVisitor):
+            def visit_select(s, select):
+                if select._should_correlate:
+                    select_state = correlate_state.setdefault(select, {})
+                    corr = select_state.setdefault('correlate', util.Set())
+                    corr.add(self.table)
+                    
+        vis = SelectCorrelator()
+        
+        if self._whereclause is not None:
+            vis.traverse(self._whereclause)
+        
+        if getattr(self, 'parameters', None) is not None:
+            for key, value in self.parameters.items():
+                if isinstance(value, ClauseElement):
+                    vis.traverse(value)
+                
     def _process_colparams(self, parameters):
         """Receive the *values* of an ``INSERT`` or ``UPDATE``
         statement and construct appropriate bind parameters.
@@ -3054,11 +3002,10 @@ class _UpdateBase(ClauseElement):
                 i +=1
             parameters = pp
 
-        correlator = _UpdateBase._SelectCorrelator(self.table)
         for key in parameters.keys():
             value = parameters[key]
             if isinstance(value, ClauseElement):
-                correlator.traverse(value)
+                pass
             elif _is_literal(value):
                 if _is_literal(key):
                     col = self.table.c[key]
@@ -3073,7 +3020,7 @@ class _UpdateBase(ClauseElement):
     def _find_engine(self):
         return self.table.engine
 
-class _Insert(_UpdateBase):
+class Insert(_UpdateBase):
     def __init__(self, table, values=None):
         self.table = table
         self.select = None
@@ -3084,32 +3031,26 @@ class _Insert(_UpdateBase):
             return self.select,
         else:
             return ()
-    def accept_visitor(self, visitor):
-        visitor.visit_insert(self)
 
-class _Update(_UpdateBase):
+class Update(_UpdateBase):
     def __init__(self, table, whereclause, values=None):
         self.table = table
-        self.whereclause = self._process_whereclause(whereclause)
+        self._whereclause = whereclause
         self.parameters = self._process_colparams(values)
 
     def get_children(self, **kwargs):
-        if self.whereclause is not None:
-            return self.whereclause,
+        if self._whereclause is not None:
+            return self._whereclause,
         else:
             return ()
-    def accept_visitor(self, visitor):
-        visitor.visit_update(self)
 
-class _Delete(_UpdateBase):
+class Delete(_UpdateBase):
     def __init__(self, table, whereclause):
         self.table = table
-        self.whereclause = self._process_whereclause(whereclause)
+        self._whereclause = whereclause
 
     def get_children(self, **kwargs):
-        if self.whereclause is not None:
-            return self.whereclause,
+        if self._whereclause is not None:
+            return self._whereclause,
         else:
             return ()
-    def accept_visitor(self, visitor):
-        visitor.visit_delete(self)
index 7a67402318d395fef9c81eaa0ab60c134b3be891..36d127c98cbcb799548d504abf378b46957a5aa3 100644 (file)
@@ -125,7 +125,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
         process the new list.
         """
 
-        list_ = [o.copy_container() for o in list_]
+        list_ = list(list_)
         self.process_list(list_)
         return list_
 
@@ -137,7 +137,7 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
             if elem is not None:
                 list_[i] = elem
             else:
-                self.traverse(list_[i])
+                list_[i] = self.traverse(list_[i], clone=True)
     
     def visit_grouping(self, grouping):
         elem = self.convert_element(grouping.elem)
@@ -162,8 +162,25 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
         elem = self.convert_element(binary.right)
         if elem is not None:
             binary.right = elem
-
-    # TODO: visit_select().  
+    
+    def visit_select(self, select):
+        fr = util.OrderedSet()
+        for elem in select._froms:
+            n = self.convert_element(elem)
+            if n is None:
+                fr.add(elem)
+            else:
+                fr.add(n)
+        select._recorrelate_froms(fr)
+
+        col = []
+        for elem in select._raw_columns:
+            n = self.convert_element(elem)
+            if n is None:
+                col.append(elem)
+            else:
+                col.append(n)
+        select._raw_columns = col
     
 class ClauseAdapter(AbstractClauseProcessor):
     """Given a clause (like as in a WHERE criterion), locate columns
@@ -200,6 +217,9 @@ class ClauseAdapter(AbstractClauseProcessor):
         self.equivalents = equivalents
 
     def convert_element(self, col):
+        if isinstance(col, sql.FromClause):
+            if self.selectable.is_derived_from(col):
+                return self.selectable
         if not isinstance(col, sql.ColumnElement):
             return None
         if self.include is not None:
index 38f06584fc67ddd3c06c3fd0b533b85f7f761ade..a0088f1366247f39c08f211e1686a209f916d006 100644 (file)
@@ -55,9 +55,9 @@ else:
                 self[key] = value = self.creator(key)
                 return value
 
-def to_list(x):
+def to_list(x, default=None):
     if x is None:
-        return None
+        return default
     if not isinstance(x, list) and not isinstance(x, tuple):
         return [x]
     else:
index ad07b5b21a1eb90f94aa3e4041aab24637c987de..a83b81758ade9ae0892c5303756854a9f033004d 100644 (file)
@@ -177,7 +177,7 @@ class RelationsTest(AssertMixin):
         })
         session = create_session()
         query = session.query(tables.User)
-        x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
         print x.compile()
         self.assert_result(list(x), tables.User, *tables.user_result[1:3])
     def test_outerjointo_count(self):
@@ -189,7 +189,7 @@ class RelationsTest(AssertMixin):
         })
         session = create_session()
         query = session.query(tables.User)
-        x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+        x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
         assert x==2
     def test_from(self):
         mapper(tables.User, tables.users, properties={
index a9482f28c919c9a1bc6d19f854f084bff46ddce1..7858689b1dc2a7bc07e5d6f62f2cc4765b6e96b8 100644 (file)
@@ -29,14 +29,15 @@ class PolymorphicCircularTest(testbase.ORMTest):
             Column('data', String(30))
             )
             
-        join = polymorphic_union(
-            {
-            'table3' : table1.join(table3),
-            'table2' : table1.join(table2),
-            'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
-            }, None, 'pjoin')
-
-        # still with us so far ?
+        #join = polymorphic_union(
+        #    {
+        #    'table3' : table1.join(table3),
+        #    'table2' : table1.join(table2),
+        #    'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
+        #    }, None, 'pjoin')
+        
+        join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin')
+        #join = None
         
         class Table1(object):
             def __init__(self, name, data=None):
@@ -62,10 +63,10 @@ class PolymorphicCircularTest(testbase.ORMTest):
                 return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data)))
                 
         try:
-            # this is how the mapping used to work.  insure that this raises an error now
+            # this is how the mapping used to work.  ensure that this raises an error now
             table1_mapper = mapper(Table1, table1,
                                    select_table=join,
-                                   polymorphic_on=join.c.type,
+                                   polymorphic_on=table1.c.type,
                                    polymorphic_identity='table1',
                                    properties={
                                     'next': relation(Table1, 
@@ -86,8 +87,8 @@ class PolymorphicCircularTest(testbase.ORMTest):
         # exception now.  since eager loading would never work for that relation anyway, its better that the user
         # gets an exception instead of it silently not eager loading.
         table1_mapper = mapper(Table1, table1,
-                               select_table=join,
-                               polymorphic_on=join.c.type,
+                               #select_table=join,
+                               polymorphic_on=table1.c.type,
                                polymorphic_identity='table1',
                                properties={
                                'next': relation(Table1, 
@@ -104,7 +105,10 @@ class PolymorphicCircularTest(testbase.ORMTest):
                                polymorphic_identity='table2')
 
         table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3')
-
+        
+        table1_mapper.compile()
+        assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key
+        
     def testone(self):
         self.do_testlist([Table1, Table2, Table1, Table2])
 
index b061416ae814bb02d090234e3f5c70acc1375789..ab53e36c452a45dddf435210a07051388b04700f 100644 (file)
@@ -328,6 +328,10 @@ class MapperTest(MapperSuperTest):
             'concat': column_property(f),
             'count': column_property(select([func.count(addresses.c.address_id)], users.c.user_id==addresses.c.user_id, scalar=True).label('count'))
         })
+
+        mapper(Address, addresses, properties={
+            'user':relation(User, lazy=False)
+        })    
         
         sess = create_session()
         l = sess.query(User).select()
@@ -336,24 +340,19 @@ class MapperTest(MapperSuperTest):
         assert l[0].concat == l[0].user_id * 2 == 14
         assert l[1].concat == l[1].user_id * 2 == 16
         
-        ### eager loads, not really working across all DBs, no column aliasing in place so
-        # results still wont be good for larger situations
-        clear_mappers()
-        mapper(Address, addresses, properties={
-            'user':relation(User, lazy=False)
-        })    
-        
-        mapper(User, users, properties={
-            'concat': column_property(f),
-        })
-
-        for x in range(0, 2):
-            sess.clear()
-            l = sess.query(Address).select()
-            for a in l:
-                print "User", a.user.user_id, a.user.user_name, a.user.concat
-            assert l[0].user.concat == l[0].user.user_id * 2 == 14
-            assert l[1].user.concat == l[1].user.user_id * 2 == 16
+        for option in (None, eagerload('user')):
+            for x in range(0, 2):
+                sess.clear()
+                l = sess.query(Address)
+                if option:
+                    l = l.options(option)
+                l = l.all()
+                for a in l:
+                    print "User", a.user.user_id, a.user.user_name, a.user.concat, a.user.count
+                assert l[0].user.concat == l[0].user.user_id * 2 == 14
+                assert l[1].user.concat == l[1].user.user_id * 2 == 16
+                assert l[0].user.count == 1
+                assert l[1].user.count == 3
             
         
     @testbase.unsupported('firebird') 
@@ -1114,6 +1113,7 @@ class EagerTest(MapperSuperTest):
         """test eager loading of a mapper which is against a select"""
         
         s = select([orders], orders.c.isopen==1).alias('openorders')
+        print "SELECT:", id(s), str(s)
         mapper(Order, s, properties={
             'user':relation(User, lazy=False)
         })
index dc3416089f6c3bb00903966f69f9bd0d1ffd6f45..2d87b391e6dbd59cfa6d947555698e1a4084207b 100644 (file)
@@ -1,6 +1,7 @@
 from testbase import PersistTest, AssertMixin
 import unittest, sys, os
 from sqlalchemy import *
+from sqlalchemy.orm import *
 from testbase import Table, Column
 import StringIO
 import testbase
index 7be1a3ffb6fda3b802c1c3c5fbf77ac32784897c..ebb3fe34c612b8417dbef3a0d67705e4aa9ecabc 100644 (file)
@@ -7,6 +7,8 @@ def suite():
         'sql.testtypes',
         'sql.constraints',
 
+        'sql.generative',
+        
         # SQL syntax
         'sql.select',
         'sql.selectable',
diff --git a/test/sql/generative.py b/test/sql/generative.py
new file mode 100644 (file)
index 0000000..befa5f9
--- /dev/null
@@ -0,0 +1,212 @@
+import testbase
+from sqlalchemy import *
+
+class TraversalTest(testbase.AssertMixin):
+    """test ClauseVisitor's traversal, particularly its ability to copy and modify
+    a ClauseElement in place."""
+    
+    def setUpAll(self):
+        global A, B
+        
+        # establish two ficticious ClauseElements.
+        # define deep equality semantics as well as deep identity semantics.
+        class A(ClauseElement):
+            def __init__(self, expr):
+                self.expr = expr
+
+            def accept_visitor(self, visitor):
+                visitor.visit_a(self)
+
+            def is_other(self, other):
+                return other is self
+            
+            def __eq__(self, other):
+                return other.expr == self.expr
+            
+            def __ne__(self, other):
+                return other.expr != self.expr
+                
+            def __str__(self):
+                return "A(%s)" % repr(self.expr)
+                
+        class B(ClauseElement):
+            def __init__(self, *items):
+                self.items = items
+
+            def is_other(self, other):
+                if other is not self:
+                    return False
+                for i1, i2 in zip(self.items, other.items):
+                    if i1 is not i2:
+                        return False
+                return True
+
+            def __eq__(self, other):
+                for i1, i2 in zip(self.items, other.items):
+                    if i1 != i2:
+                        return False
+                return True
+            
+            def __ne__(self, other):
+                for i1, i2 in zip(self.items, other.items):
+                    if i1 != i2:
+                        return True
+                return False
+                
+            def get_children(self, clone=False, **kwargs):
+                if clone:
+                    self.items = [i._clone() for i in self.items]
+                return self.items
+            
+            def accept_visitor(self, visitor):
+                visitor.visit_b(self)
+                
+            def __str__(self):
+                return "B(%s)" % repr([str(i) for i in self.items])
+    
+    def test_test_classes(self):
+        a1 = A("expr1")
+        struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+        struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+        struct3 = B(a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+        assert a1.is_other(a1)
+        assert struct.is_other(struct)
+        assert struct == struct2
+        assert struct != struct3
+        assert not struct.is_other(struct2)
+        assert not struct.is_other(struct3)
+        
+    def test_clone(self):    
+        struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+        
+        class Vis(ClauseVisitor):
+            def visit_a(self, a):
+                pass
+            def visit_b(self, b):
+                pass
+                
+        vis = Vis()
+        s2 = vis.traverse(struct, clone=True)
+        assert struct == s2
+        assert not struct.is_other(s2)
+
+    def test_no_clone(self):    
+        struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+        class Vis(ClauseVisitor):
+            def visit_a(self, a):
+                pass
+            def visit_b(self, b):
+                pass
+
+        vis = Vis()
+        s2 = vis.traverse(struct, clone=False)
+        assert struct == s2
+        assert struct.is_other(s2)
+        
+    def test_change_in_place(self):
+        struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+        struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
+        struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+        class Vis(ClauseVisitor):
+            def visit_a(self, a):
+                if a.expr == "expr2":
+                    a.expr = "expr2modified"
+            def visit_b(self, b):
+                pass
+
+        vis = Vis()
+        s2 = vis.traverse(struct, clone=True)
+        assert struct != s2
+        assert struct is not s2
+        assert struct2 == s2
+
+        class Vis2(ClauseVisitor):
+            def visit_a(self, a):
+                if a.expr == "expr2b":
+                    a.expr = "expr2bmodified"
+            def visit_b(self, b):
+                pass
+
+        vis2 = Vis2()
+        s3 = vis2.traverse(struct, clone=True)
+        assert struct != s3
+        assert struct3 == s3
+
+class ClauseTest(testbase.AssertMixin):
+    def setUpAll(self):
+        global t1, t2
+        t1 = table("table1", 
+            column("col1"),
+            column("col2"),
+            column("col3"),
+            )
+        t2 = table("table2", 
+            column("col1"),
+            column("col2"),
+            column("col3"),
+            )
+            
+    def test_binary(self):
+        clause = t1.c.col2 == t2.c.col2
+        assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+    
+    def test_join(self):
+        clause = t1.join(t2, t1.c.col2==t2.c.col2)
+        c1 = str(clause)
+        assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+    
+        class Vis(ClauseVisitor):
+            def visit_binary(self, binary):
+                binary.right = t2.c.col3
+                
+        clause2 = Vis().traverse(clause, clone=True)
+        assert c1 == str(clause)
+        assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
+    
+    def test_select(self):
+        s = t1.select()
+        s2 = select([s])
+        s2_assert = str(s2)
+        s3_assert = str(select([t1.select()], t1.c.col2==7))
+        class Vis(ClauseVisitor):
+            def visit_select(self, select):
+                select.append_whereclause(t1.c.col2==7)
+        s3 = Vis().traverse(s2, clone=True)
+        assert str(s3) == s3_assert
+        assert str(s2) == s2_assert
+        print str(s2)
+        print str(s3)
+        Vis().traverse(s2)
+        assert str(s2) == s3_assert
+
+        print "------------------"
+        
+        s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9)))
+        class Vis(ClauseVisitor):
+            def visit_select(self, select):
+                select.append_whereclause(t1.c.col3==9)
+        s4 = Vis().traverse(s3, clone=True)
+        print str(s3)
+        print str(s4)
+        assert str(s4) == s4_assert
+        assert str(s3) == s3_assert
+        
+        print "------------------"
+        s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9)))
+        class Vis(ClauseVisitor):
+            def visit_binary(self, binary):
+                if binary.left is t1.c.col3:
+                    binary.left = t1.c.col1
+                    binary.right = bindparam("table1_col1")
+        s5 = Vis().traverse(s4, clone=True)
+        print str(s4)
+        print str(s5)
+        assert str(s5) == s5_assert
+        assert str(s4) == s4_assert
+        
+        
+if __name__ == '__main__':
+    testbase.main()        
\ No newline at end of file
index 157c62300052135e1424cd8b1c8f558f6d1f11e7..ad2fd13e3c490bf68e86fee63f787404e9e17654 100644 (file)
@@ -143,6 +143,11 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={})
         
     def testwheresubquery(self):
+        s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s')
+        self.runtest(
+            select([users, s.c.street], from_obj=[s]),
+            """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
+
         # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet.
         #self.runtest(
         #    table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), ""
@@ -223,6 +228,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
                          order_by = ['dist', places.c.nm]
                          )
         self.runtest(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm")
+
+        a1 = table2.alias('t2alias')
+        s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True)
+        j1 = table1.join(table2, table1.c.myid==table2.c.otherid)
+        s2 = select([table1, s1], from_obj=[j1])
+        self.runtest(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
     
     def testlabelcomparison(self):
         x = func.lala(table1.c.myid).label('foo')
@@ -410,7 +421,7 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
         s.append_column("column2")
         s.append_whereclause("column1=12")
         s.append_whereclause("column2=19")
-        s.order_by("column1")
+        s = s.order_by("column1")
         s.append_from("table1")
         self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1")
 
@@ -850,16 +861,6 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
         )
         
     
-    def testlateargs(self):
-        """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments
-        are sent"""
-        
-        self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'})
-
-        self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'})
-
-        self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'})
-        
     def testcast(self):
         tbl = table('casttest',
                     column('id', Integer),
@@ -969,7 +970,18 @@ class CRUDTest(SQLTest):
         
     def testdelete(self):
         self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
-        
+    
+    def testcorrelateddelete(self):
+        # test a non-correlated WHERE clause
+        s = select([table2.c.othername], table2.c.otherid == 7)
+        u = delete(table1, table1.c.name==s)
+        self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)")
+
+        # test one that is actually correlated...
+        s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
+        u = table1.delete(table1.c.name==s)
+        self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
+            
 class SchemaTest(SQLTest):
     def testselect(self):
         # these tests will fail with the MS-SQL compiler since it will alias schema-qualified tables