]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- for hackers, refactored the "visitor" system of ClauseElement and
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 11 Mar 2007 20:52:02 +0000 (20:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 11 Mar 2007 20:52:02 +0000 (20:52 +0000)
SchemaItem so that the traversal of items is controlled by the
ClauseVisitor itself, using the method visitor.traverse(item).
accept_visitor() methods can still be called directly but will
not do any traversal of child items.  ClauseElement/SchemaItem now
have a configurable get_children() method to return the collection
of child elements for each parent object. This allows the full
traversal of items to be clear and unambiguous (as well as loggable),
with an easy method of limiting a traversal (just pass flags which
are picked up by appropriate get_children() methods). [ticket:501]
- accept_schema_visitor() methods removed, replaced with
get_children(schema_visitor=True)
- various docstring/changelog cleanup/reformatting

15 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/sql_util.py
test/engine/reflection.py
test/sql/constraints.py

diff --git a/CHANGES b/CHANGES
index d8d4f4e5645b70caf0d2f687bb5195bc29144432..18b24cf54087817500e2d65f33f7da6c0eef04d0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,31 +6,53 @@
      with conflicting names, specify "unique=True" - this option is
      still used internally for all the auto-genererated (value-based) 
      bind parameters.    
+
     - exists() becomes useable as a standalone selectable, not just in a 
-    WHERE clause
+    WHERE clause, i.e. exists([columns], criterion).select()
+
     - correlated subqueries work inside of ORDER BY, GROUP BY
-    - fixed function execution with explicit connections, when you dont 
-    explicitly say "select()" off the function, i.e. 
+
+    - fixed function execution with explicit connections, i.e. 
     conn.execute(func.dosomething())
+
     - use_labels flag on select() wont auto-create labels for literal text
       column elements, since we can make no assumptions about the text. to
       create labels for literal columns, you can say "somecol AS somelabel",
       or use literal_column("somecol").label("somelabel")
+
     - quoting wont occur for literal columns when they are "proxied" into the
-    column collection for their selectable (is_literal flag is propigated)
-    - added "fold_equivalents" argument to Join.select(), which removes
+    column collection for their selectable (is_literal flag is propigated).
+    literal columns are specified via literal_column("somestring").
+
+    - added "fold_equivalents" boolean argument to Join.select(), which removes
     'duplicate' columns from the resulting column clause that are known to be 
     equivalent based on the join condition.  this is of great usage when 
     constructing subqueries of joins which Postgres complains about if 
     duplicate column names are present.
+
     - fixed use_alter flag on ForeignKeyConstraint [ticket:503]
+
     - fixed usage of 2.4-only "reversed" in topological.py [ticket:506]
+
+    - for hackers, refactored the "visitor" system of ClauseElement and
+    SchemaItem so that the traversal of items is controlled by the 
+    ClauseVisitor itself, using the method visitor.traverse(item).
+    accept_visitor() methods can still be called directly but will
+    not do any traversal of child items.  ClauseElement/SchemaItem now 
+    have a configurable get_children() method to return the collection
+    of child elements for each parent object. This allows the full
+    traversal of items to be clear and unambiguous (as well as loggable),
+    with an easy method of limiting a traversal (just pass flags which
+    are picked up by appropriate get_children() methods). [ticket:501]
+
 - oracle:
     - got binary working for any size input !  cx_oracle works fine,
       it was my fault as BINARY was being passed and not BLOB for
       setinputsizes (also unit tests werent even setting input sizes).
+
     - auto_setinputsizes defaults to True for Oracle, fixed cases where
       it improperly propigated bad types.
+
 - orm:
     - the full featureset of the SelectResults extension has been merged
       into a new set of methods available off of Query.  These methods
       as a list of tuples.  this corresponds to the documented behavior.
       So that instances match up properly, the "uniquing" is disabled when 
       this feature is used.
+
     - Query has add_entity() and add_column() generative methods.  these
       will add the given mapper/class or ColumnElement to the query at compile
-      time, and apply them to the instances method.  the user is responsible
+      time, and apply them to the instances() method.  the user is responsible
       for constructing reasonable join conditions (otherwise you can get
       full cartesian products).  result set is the list of tuples, non-uniqued.
+
     - strings and columns can also be sent to the *args of instances() where
       those exact result columns will be part of the result tuples.
+
     - a full select() construct can be passed to query.select() (which
       worked anyway), but also query.selectfirst(), query.selectone() which
       will be used as is (i.e. no query is compiled). works similarly to
       sending the results to instances().
-    - added "refresh-expire" cascade [ticket:492]
+      
+    - added "refresh-expire" cascade [ticket:492].  allows refresh() and
+      expire() calls to propigate along relationships.
+    
     - more fixes to polymorphic relations, involving proper lazy-clause
       generation on many-to-one relationships to polymorphic mappers 
       [ticket:493]. also fixes to detection of "direction", more specific
       targeting of columns that belong to the polymorphic union vs. those
       that dont.
+
     - put an aggressive check for "flushing object A with a collection
       of B's, but you put a C in the collection" error condition - 
       **even if C is a subclass of B**, unless B's mapper loads polymorphically.
       (since its not polymorphic) which breaks in bi-directional relationships
       (i.e. C has its A, but A's backref will lazyload it as a different 
       instance of type "B") [ticket:500]
+      This check is going to bite some of you who do this without issues, 
+      so the error message will also document a flag "enable_typechecks=False" 
+      to disable this checking.  But be aware that bi-directional relationships
+      in particular become fragile without this check.
+
 - extensions:
+
     - options() method on SelectResults now implemented "generatively"
-      like the rest of the SelectResults methods [ticket:472]
+      like the rest of the SelectResults methods [ticket:472].  But
+      you're going to just use Query now anyway.
+
     - query() method is added by assignmapper.  this helps with 
       navigating to all the new generative methods on Query.
     
index 5d5c42208ccbe10d386790d1ffe2867fc04c7ea5..ebaedca542d2b97adb37fd63e52b3dc572fce589 100644 (file)
@@ -75,6 +75,8 @@ class ANSICompiler(sql.Compiled):
     Compiles ClauseElements into ANSI-compliant SQL strings.
     """
 
+    __traverse_options__ = {'column_collections':False}
+
     def __init__(self, dialect, statement, parameters=None, **kwargs):
         """Construct a new ``ANSICompiler`` object.
 
@@ -388,13 +390,13 @@ class ANSICompiler(sql.Compiled):
         self.select_stack.append(select)
         for c in select._raw_columns:
             if isinstance(c, sql.Select) and c.is_scalar:
-                c.accept_visitor(self)
+                self.traverse(c)
                 inner_columns[self.get_str(c)] = c
                 continue
             if hasattr(c, '_selectable'):
                 s = c._selectable()
             else:
-                c.accept_visitor(self)
+                self.traverse(c)
                 inner_columns[self.get_str(c)] = c
                 continue
             for co in s.columns:
@@ -402,10 +404,10 @@ class ANSICompiler(sql.Compiled):
                     labelname = co._label
                     if labelname is not None:
                         l = co.label(labelname)
-                        l.accept_visitor(self)
+                        self.traverse(l)
                         inner_columns[labelname] = l
                     else:
-                        co.accept_visitor(self)
+                        self.traverse(co)
                         inner_columns[self.get_str(co)] = co
                 # TODO: figure this out, a ColumnClause with a select as a parent
                 # is different from any other kind of parent
@@ -414,10 +416,10 @@ class ANSICompiler(sql.Compiled):
                     # names look like table.colname, so add a label synonomous with
                     # the column name
                     l = co.label(co.name)
-                    l.accept_visitor(self)
+                    self.traverse(l)
                     inner_columns[self.get_str(l.obj)] = l
                 else:
-                    co.accept_visitor(self)
+                    self.traverse(co)
                     inner_columns[self.get_str(co)] = co
         self.select_stack.pop(-1)
 
@@ -443,7 +445,7 @@ class ANSICompiler(sql.Compiled):
                     else:
                         continue
                     clause = c==value
-                    clause.accept_visitor(self)
+                    self.traverse(clause)
                     whereclause = sql.and_(clause, whereclause)
                     self.visit_compound(whereclause)
 
@@ -596,7 +598,7 @@ class ANSICompiler(sql.Compiled):
         vis = DefaultVisitor()
         for c in insert_stmt.table.c:
             if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
-                c.accept_schema_visitor(vis)
+                vis.traverse(c)
 
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt, default_params)
@@ -610,7 +612,7 @@ class ANSICompiler(sql.Compiled):
                 return self.bindparam_string(p.key)
             else:
                 self.inline_params.add(col)
-                p.accept_visitor(self)
+                self.traverse(p)
                 if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
                     return "(" + self.get_str(p) + ")"
                 else:
@@ -631,7 +633,7 @@ class ANSICompiler(sql.Compiled):
         vis = OnUpdateVisitor()
         for c in update_stmt.table.c:
             if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
-                c.accept_schema_visitor(vis)
+                vis.traverse(c)
 
         self.isupdate = True
         colparams = self._get_colparams(update_stmt, default_params)
@@ -643,7 +645,7 @@ class ANSICompiler(sql.Compiled):
                 self.binds[p.shortname] = p
                 return self.bindparam_string(p.key)
             else:
-                p.accept_visitor(self)
+                self.traverse(p)
                 self.inline_params.add(col)
                 if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
                     return "(" + self.get_str(p) + ")"
@@ -734,7 +736,7 @@ class ANSISchemaBase(engine.SchemaIterator):
         findalterables = FindAlterables()
         for table in tables:
             for c in table.constraints:
-                c.accept_schema_visitor(findalterables)
+                findalterables.traverse(c)
         return alterables
 
 class ANSISchemaGenerator(ANSISchemaBase):
@@ -752,7 +754,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_metadata(self, metadata):
         collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
         for table in collection:
-            table.accept_schema_visitor(self, traverse=False)
+            table.accept_visitor(self)
         if self.supports_alter():
             for alterable in self.find_alterables(collection):
                 self.add_foreignkey(alterable)
@@ -760,9 +762,9 @@ class ANSISchemaGenerator(ANSISchemaBase):
     def visit_table(self, table):
         for column in table.columns:
             if column.default is not None:
-                column.default.accept_schema_visitor(self, traverse=False)
+                column.default.accept_visitor(self)
             #if column.onupdate is not None:
-            #    column.onupdate.accept_schema_visitor(visitor, traverse=False)
+            #    column.onupdate.accept_visitor(visitor)
 
         self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
 
@@ -777,20 +779,20 @@ class ANSISchemaGenerator(ANSISchemaBase):
             if column.primary_key:
                 first_pk = True
             for constraint in column.constraints:
-                constraint.accept_schema_visitor(self, traverse=False)
+                constraint.accept_visitor(self)
 
         # On some DB order is significant: visit PK first, then the
         # other constraints (engine.ReflectionTest.testbasic failed on FB2)
         if len(table.primary_key):
-            table.primary_key.accept_schema_visitor(self, traverse=False)
+            table.primary_key.accept_visitor(self)
         for constraint in [c for c in table.constraints if c is not table.primary_key]:
-            constraint.accept_schema_visitor(self, traverse=False)
+            constraint.accept_visitor(self)
 
         self.append("\n)%s\n\n" % self.post_create_table(table))
         self.execute()
         if hasattr(table, 'indexes'):
             for index in table.indexes:
-                index.accept_schema_visitor(self, traverse=False)
+                index.accept_visitor(self)
 
     def post_create_table(self, table):
         return ''
@@ -890,7 +892,7 @@ class ANSISchemaDropper(ANSISchemaBase):
             for alterable in self.find_alterables(collection):
                 self.drop_foreignkey(alterable)
         for table in collection:
-            table.accept_schema_visitor(self, traverse=False)
+            table.accept_visitor(self)
 
     def supports_alter(self):
         return True
@@ -906,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase):
     def visit_table(self, table):
         for column in table.columns:
             if column.default is not None:
-                column.default.accept_schema_visitor(self, traverse=False)
+                column.default.accept_visitor(self)
 
         self.append("\nDROP TABLE " + self.preparer.format_table(table))
         self.execute()
index 254ea6013169f76420af47d3f4897f17a7540001..8c3c71f6edd39d991da54d389d46e7ce1ecd4f01 100644 (file)
@@ -657,11 +657,11 @@ class MSSQLCompiler(ansisql.ANSICompiler):
         if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
             alias = table.alias()
             self.tablealiases[table] = alias
-            alias.accept_visitor(self)
+            self.traverse(alias)
             self.froms[('alias', table)] = self.froms[table]
             for c in alias.c:
-                c.accept_visitor(self)
-            alias.oid_column.accept_visitor(self)
+                self.traverse(c)
+            self.traverse(alias.oid_column)
             self.tablealiases[alias] = self.froms[table]
             self.froms[table] = self.froms[alias]
         else:
index 966834eb251252944c79924a5c017a64591736cf..1dba60c1d7419503eff9848c14842cbc38060dc3 100644 (file)
@@ -434,7 +434,7 @@ class OracleCompiler(ansisql.ANSICompiler):
 
             # now re-visit the onclause, which will be used as a where clause
             # (the first visit occured via the Join object itself right before it called visit_join())
-            join.onclause.accept_visitor(self)
+            self.traverse(join.onclause)
 
             self._outertable = None
 
@@ -488,12 +488,12 @@ class OracleCompiler(ansisql.ANSICompiler):
             orderby = self.strings[select.order_by_clause]
             if not orderby:
                 orderby = select.oid_column
-                orderby.accept_visitor(self)
+                self.traverse(orderby)
                 orderby = self.strings[orderby]
-            class SelectVisitor(sql.ClauseVisitor):
+            class SelectVisitor(sql.NoColumnVisitor):
                 def visit_select(self, select):
                     select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
-            select.accept_visitor(SelectVisitor())
+            SelectVisitor().traverse(select)
             limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
             if select.offset is not None:
                 limitselect.append_whereclause("ora_rn>%d" % select.offset)
@@ -501,7 +501,7 @@ class OracleCompiler(ansisql.ANSICompiler):
                     limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
             else:
                 limitselect.append_whereclause("ora_rn<=%d" % select.limit)
-            limitselect.accept_visitor(self)
+            self.traverse(limitselect)
             self.strings[select] = self.strings[limitselect]
             self.froms[select] = self.froms[limitselect]
         else:
@@ -527,7 +527,7 @@ class OracleCompiler(ansisql.ANSICompiler):
             orderby = self.strings[select.order_by_clause]
             if not orderby:
                 orderby = select.oid_column
-                orderby.accept_visitor(self)
+                self.traverse(orderby)
                 orderby = self.strings[orderby]
             select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
             limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
@@ -537,7 +537,7 @@ class OracleCompiler(ansisql.ANSICompiler):
                     limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
             else:
                 limitselect.append_whereclause("ora_rn<=%d" % select.limit)
-            limitselect.accept_visitor(self)
+            self.traverse(limitselect)
             self.strings[select] = self.strings[limitselect]
             self.froms[select] = self.froms[limitselect]
         else:
index 0a53da92841c3772ebb990393a2743100dd1d426..f79167abc3a65f431a491344c08727c24e45bbfa 100644 (file)
@@ -446,7 +446,7 @@ class Connection(Connectable):
             raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
 
     def execute_default(self, default, **kwargs):
-        return default.accept_schema_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs))
+        return default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs))
 
     def execute_text(self, statement, *multiparams, **params):
         if len(multiparams) == 0:
@@ -672,7 +672,7 @@ class Engine(sql.Executor, Connectable):
         else:
             conn = connection
         try:
-            element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs), traverse=False)
+            element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs))
         finally:
             if connection is None:
                 conn.close()
@@ -1164,13 +1164,13 @@ class DefaultRunner(schema.SchemaVisitor):
 
     def get_column_default(self, column):
         if column.default is not None:
-            return column.default.accept_schema_visitor(self)
+            return column.default.accept_visitor(self)
         else:
             return None
 
     def get_column_onupdate(self, column):
         if column.onupdate is not None:
-            return column.onupdate.accept_schema_visitor(self)
+            return column.onupdate.accept_visitor(self)
         else:
             return None
 
index d28445be61bf6f2740406dd2610a11b996d4175f..74dd58a3f6d5bc185306d0b257a949a514f12a61 100644 (file)
@@ -767,7 +767,7 @@ class Mapper(object):
         vis = mapperutil.BinaryVisitor(visit_binary)
         for mapper in self.base_mapper().polymorphic_iterator():
             if mapper.inherit_condition is not None:
-                mapper.inherit_condition.accept_visitor(vis)
+                vis.traverse(mapper.inherit_condition)
         return result
 
     def add_properties(self, dict_of_properties):
index 2d10b2f9db9a1f1a148b5cad12977c078dd86254..c9a2dbe597dde558f5db95b7a20a0fadf827573c 100644 (file)
@@ -233,9 +233,9 @@ class PropertyLoader(StrategizedProperty):
         # error message in case its the "old" way.
         if self.loads_polymorphic:
             vis = sql_util.ColumnsInClause(self.mapper.select_table)
-            self.primaryjoin.accept_visitor(vis)
+            vis.traverse(self.primaryjoin)
             if self.secondaryjoin:
-                self.secondaryjoin.accept_visitor(vis)
+                vis.traverse(self.secondaryjoin)
             if vis.result:
                 raise exceptions.ArgumentError("In relationship '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4.  Construct join conditions using the base tables of the related mappers." % (str(self)))
 
@@ -251,9 +251,9 @@ class PropertyLoader(StrategizedProperty):
                     self._opposite_side.add(binary.right)
                 if binary.right in self.foreign_keys:
                     self._opposite_side.add(binary.left)
-            self.primaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary))
+            mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin)
             if self.secondaryjoin is not None:
-                self.secondaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary))
+                mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin)
         else:
             self.foreign_keys = util.Set()
             self._opposite_side = util.Set()
@@ -268,12 +268,12 @@ class PropertyLoader(StrategizedProperty):
                     if f.references(binary.left.table):
                         self.foreign_keys.add(binary.right)
                         self._opposite_side.add(binary.left)
-            self.primaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary))
+            mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin)
 
             if len(self.foreign_keys) == 0:
                 raise exceptions.ArgumentError("Cant locate any foreign key columns in primary join condition '%s' for relationship '%s'.  Specify 'foreign_keys' argument to indicate which columns in the join condition are foreign." %(str(self.primaryjoin), str(self)))
             if self.secondaryjoin is not None:
-                self.secondaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary))
+                mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin)
 
     def _determine_direction(self):
         """Determine our *direction*, i.e. do we represent one to
@@ -343,14 +343,14 @@ class PropertyLoader(StrategizedProperty):
         if self.loads_polymorphic:
             if self.secondaryjoin:
                 self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container()
-                self.polymorphic_secondaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table))
+                sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin)
                 self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
             else:
                 self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
                 if self.direction is sync.ONETOMANY:
-                    self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents))
+                    sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
                 elif self.direction is sync.MANYTOONE:
-                    self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents))
+                    sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
                 self.polymorphic_secondaryjoin = None
             # load "polymorphic" versions of the columns present in "remote_side" - this is
             # important for lazy-clause generation which goes off the polymorphic target selectable
@@ -411,11 +411,11 @@ class PropertyLoader(StrategizedProperty):
             else:
                 secondaryjoin = None
             if self.direction is sync.ONETOMANY:
-                primaryjoin.accept_visitor(sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents))
+                sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
             elif self.direction is sync.MANYTOONE:
-                primaryjoin.accept_visitor(sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents))
+                sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
             elif self.secondaryjoin:
-                primaryjoin.accept_visitor(sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents))
+                sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
 
             if secondaryjoin is not None:
                 j = primaryjoin & secondaryjoin
index a1c8b6af51aa39b147e242014f93bcc4c8a3df7e..a0b520f339c9cbe1aa406ccaa78209c78693c6e4 100644 (file)
@@ -775,7 +775,7 @@ class Query(object):
             # adapt the given WHERECLAUSE to adjust instances of this query's mapped 
             # table to be that of our select_table,
             # which may be the "polymorphic" selectable used by our mapper.
-            whereclause.accept_visitor(sql_util.ClauseAdapter(self.table))
+            sql_util.ClauseAdapter(self.table).traverse(whereclause)
 
             # if extra entities, adapt the criterion to those as well
             for m in self._entities:
@@ -783,7 +783,7 @@ class Query(object):
                     m = mapper.class_mapper(m)
                 if isinstance(m, mapper.Mapper):
                     table = m.select_table
-                    whereclause.accept_visitor(sql_util.ClauseAdapter(m.select_table))
+                    sql_util.ClauseAdapter(m.select_table).traverse(whereclause)
         
         # get/create query context.  get the ultimate compile arguments
         # from there
@@ -827,7 +827,7 @@ class Query(object):
                 order_by = util.to_list(order_by) or []
                 cf = sql_util.ColumnFinder()
                 for o in order_by:
-                    o.accept_visitor(cf)
+                    cf.traverse(o)
             else:
                 cf = []
 
index 8e19be5367e148366cd9c02dc49ed81f1b37230f..a295ed862d6e564ef57b30ea090e49192c7eb02d 100644 (file)
@@ -260,7 +260,7 @@ class LazyLoader(AbstractRelationLoader):
             class FindColumnInColumnClause(sql.ClauseVisitor):
                 def visit_column(self, c):
                     columns.append(c)
-            expr.accept_visitor(FindColumnInColumnClause())
+            FindColumnInColumnClause().traverse(expr)
             return len(columns) and columns[0] or None
         
         def col_in_collection(column, collection):
@@ -294,7 +294,7 @@ class LazyLoader(AbstractRelationLoader):
 
         lazywhere = primaryjoin.copy_container()
         li = mapperutil.BinaryVisitor(visit_binary)
-        lazywhere.accept_visitor(li)
+        li.traverse(lazywhere)
         
         if secondaryjoin is not None:
             secondaryjoin = secondaryjoin.copy_container()
@@ -363,16 +363,16 @@ class EagerLoader(AbstractRelationLoader):
                         eagerloader.secondary:self.eagersecondary
                         })
                 self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container()
-                self.eagersecondaryjoin.accept_visitor(self.aliasizer)
+                self.aliasizer.traverse(self.eagersecondaryjoin)
                 self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
-                self.eagerprimary.accept_visitor(self.aliasizer)
+                self.aliasizer.traverse(self.eagerprimary)
             else:
                 self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
                 self.aliasizer = sql_util.Aliasizer(self.target, aliases={self.target:self.eagertarget})
-                self.eagerprimary.accept_visitor(self.aliasizer)
+                self.aliasizer.traverse(self.eagerprimary)
 
             if parentclauses is not None:
-                self.eagerprimary.accept_visitor(parentclauses.aliasizer)
+                parentclauses.aliasizer.traverse(self.eagerprimary)
 
             if eagerloader.order_by:
                 self.eager_order_by = self._aliasize_orderby(eagerloader.order_by)
index 68fa9cee160df2ff08cc2ff2bc1e9a24d6e8d59e..8c70f8cf81ab6498e6fe20105ac6d8752a76f593 100644 (file)
@@ -80,8 +80,7 @@ class ClauseSynchronizer(object):
                         self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
 
         rules_added = len(self.syncrules)
-        processor = BinaryVisitor(compile_binary)
-        sqlclause.accept_visitor(processor)
+        BinaryVisitor(compile_binary).traverse(sqlclause)
         if len(self.syncrules) == rules_added:
             raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
 
index 52324e63eed53d484e4fe41e24fe7bd6b11ad2dd..78c31e9acd8173df0504ecbfec451d34d58eabbd 100644 (file)
@@ -42,7 +42,11 @@ class SchemaItem(object):
         """Associate with this SchemaItem's parent object."""
 
         raise NotImplementedError()
-
+    
+    def get_children(self, **kwargs):
+        """used to allow SchemaVisitor access"""
+        return []
+        
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
@@ -322,11 +326,14 @@ class Table(SchemaItem, sql.TableClause):
         metadata.tables[_get_table_key(self.name, self.schema)] = self
         self._metadata = metadata
 
-    def accept_schema_visitor(self, visitor, traverse=True):
-        if traverse:
-            for c in self.columns:
-                c.accept_schema_visitor(visitor, True)
-        return visitor.visit_table(self)
+    def get_children(self, column_collections=True, schema_visitor=False, **kwargs):
+        if not schema_visitor:
+            return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs)
+        else:
+            if column_collections:
+                return [c for c in self.columns]
+            else:
+                return []
 
     def exists(self, connectable=None):
         """Return True if this table exists."""
@@ -604,20 +611,12 @@ class Column(SchemaItem, sql._ColumnClause):
         return self.__originating_column._get_case_sensitive()
     case_sensitive = property(_case_sens, lambda s,v:None)
 
-    def accept_schema_visitor(self, visitor, traverse=True):
-        """Traverse the given visitor to this ``Column``'s default and foreign key object,
-        then call `visit_column` on the visitor."""
-
-        if traverse:
-            if self.default is not None:
-                self.default.accept_schema_visitor(visitor, traverse=True)
-            if self.onupdate is not None:
-                self.onupdate.accept_schema_visitor(visitor, traverse=True)
-            for f in self.foreign_keys:
-                f.accept_schema_visitor(visitor, traverse=True)
-            for constraint in self.constraints:
-                constraint.accept_schema_visitor(visitor, traverse=True)
-        visitor.visit_column(self)
+    def get_children(self, schema_visitor=False, **kwargs):
+        if schema_visitor:
+            return [x for x in (self.default, self.onupdate) if x is not None] + \
+                list(self.foreign_keys) + list(self.constraints)
+        else:
+            return sql._ColumnClause.get_children(self, **kwargs)
 
 
 class ForeignKey(SchemaItem):
@@ -715,7 +714,7 @@ class ForeignKey(SchemaItem):
 
     column = property(lambda s: s._init_column())
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         """Call the `visit_foreign_key` method on the given visitor."""
 
         visitor.visit_foreign_key(self)
@@ -771,7 +770,7 @@ class PassiveDefault(DefaultGenerator):
         super(PassiveDefault, self).__init__(**kwargs)
         self.arg = arg
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         return visitor.visit_passive_default(self)
 
     def __repr__(self):
@@ -788,7 +787,7 @@ class ColumnDefault(DefaultGenerator):
         super(ColumnDefault, self).__init__(**kwargs)
         self.arg = arg
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         """Call the visit_column_default method on the given visitor."""
 
         if self.for_update:
@@ -828,7 +827,7 @@ class Sequence(DefaultGenerator):
     def drop(self):
        self.get_engine().drop(self)
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         """Call the visit_seauence method on the given visitor."""
 
         return visitor.visit_sequence(self)
@@ -871,7 +870,7 @@ class CheckConstraint(Constraint):
         super(CheckConstraint, self).__init__(name)
         self.sqltext = sqltext
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         if isinstance(self.parent, Table):
             visitor.visit_check_constraint(self)
         else:
@@ -904,7 +903,7 @@ class ForeignKeyConstraint(Constraint):
         for (c, r) in zip(self.__colnames, self.__refcolnames):
             self.append_element(c,r)
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         visitor.visit_foreign_key_constraint(self)
 
     def append_element(self, col, refcol):
@@ -930,7 +929,7 @@ class PrimaryKeyConstraint(Constraint):
         for c in self.__colnames:
             self.append_column(table.c[c])
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         visitor.visit_primary_key_constraint(self)
 
     def add(self, col):
@@ -964,7 +963,7 @@ class UniqueConstraint(Constraint):
     def append_column(self, col):
         self.columns.add(col)
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         visitor.visit_unique_constraint(self)
 
     def copy(self):
@@ -1042,7 +1041,7 @@ class Index(SchemaItem):
         else:
             self.get_engine().drop(self)
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         visitor.visit_index(self)
 
     def __str__(self):
@@ -1118,7 +1117,7 @@ class MetaData(SchemaItem):
             connectable = self.get_engine()
         connectable.drop(self, checkfirst=checkfirst, tables=tables)
 
-    def accept_schema_visitor(self, visitor, traverse=True):
+    def accept_visitor(self, visitor):
         visitor.visit_metadata(self)
 
     def _derived_metadata(self):
@@ -1190,6 +1189,8 @@ class DynamicMetaData(MetaData):
 class SchemaVisitor(sql.ClauseVisitor):
     """Define the visiting for ``SchemaItem`` objects."""
 
+    __traverse_options__ = {'schema_visitor':True}
+
     def visit_schema(self, schema):
         """Visit a generic ``SchemaItem``."""
         pass
index 073277d576b1809ebf977d3bfa3a9ef7f92c5890..190ec29d407b06559af277228c7728bf625589cb 100644 (file)
@@ -5,7 +5,7 @@
 
 """Define the base components of SQL expression trees."""
 
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exceptions, logging
 from sqlalchemy import types as sqltypes
 import string, re, random, sets
 
@@ -485,44 +485,103 @@ class ClauseParameters(dict):
         return d
 
 class ClauseVisitor(object):
-    """Define the visiting of ``ClauseElements``."""
-
-    def visit_column(self, column):pass
-    def visit_table(self, column):pass
-    def visit_fromclause(self, fromclause):pass
-    def visit_bindparam(self, bindparam):pass
-    def visit_textclause(self, textclause):pass
-    def visit_compound(self, compound):pass
-    def visit_compound_select(self, compound):pass
-    def visit_binary(self, binary):pass
-    def visit_alias(self, alias):pass
-    def visit_select(self, select):pass
-    def visit_join(self, join):pass
-    def visit_null(self, null):pass
-    def visit_clauselist(self, list):pass
-    def visit_calculatedclause(self, calcclause):pass
-    def visit_function(self, func):pass
-    def visit_cast(self, cast):pass
-    def visit_label(self, label):pass
-    def visit_typeclause(self, typeclause):pass
-
-class VisitColumnMixin(object):
-    """a mixin that adds Column traversal to a ClauseVisitor"""
+    """A class that knows how to traverse and visit
+    ``ClauseElements``.
+    
+    Each ``ClauseElement``'s accept_visitor() method will call a
+    corresponding visit_XXXX() method here. Traversal of a
+    hierarchy of ``ClauseElements`` is achieved via the
+    ``traverse()`` method, which is passed the lead
+    ``ClauseElement``.
+    
+    By default, ``ClauseVisitor`` traverses all elements
+    fully.  Options can be specified at the class level via the 
+    ``__traverse_options__`` dictionary which will be passed
+    to the ``get_children()`` method of each ``ClauseElement``;
+    these options can indicate modifications to the set of 
+    elements returned, such as to not return column collections
+    (column_collections=False) or to return Schema-level items
+    (schema_visitor=True)."""
+    __traverse_options__ = {}
+    def traverse(self, obj):
+        for n in obj.get_children(**self.__traverse_options__):
+            self.traverse(n)
+        obj.accept_visitor(self)
+    def visit_column(self, column):
+        pass
     def visit_table(self, table):
-        for c in table.c:
-            c.accept_visitor(self)
-    def visit_select(self, select):
-        for c in select.c:
-            c.accept_visitor(self)
-    def visit_compound_select(self, select):
-        for c in select.c:
-            c.accept_visitor(self)
+        pass
+    def visit_fromclause(self, fromclause):
+        pass
+    def visit_bindparam(self, bindparam):
+        pass
+    def visit_textclause(self, textclause):
+        pass
+    def visit_compound(self, compound):
+        pass
+    def visit_compound_select(self, compound):
+        pass
+    def visit_binary(self, binary):
+        pass
     def visit_alias(self, alias):
-        for c in alias.c:
-            c.accept_visitor(self)
-        
+        pass
+    def visit_select(self, select):
+        pass
+    def visit_join(self, join):
+        pass
+    def visit_null(self, null):
+        pass
+    def visit_clauselist(self, list):
+        pass
+    def visit_calculatedclause(self, calcclause):
+        pass
+    def visit_function(self, func):
+        pass
+    def visit_cast(self, cast):
+        pass
+    def visit_label(self, label):
+        pass
+    def visit_typeclause(self, typeclause):
+        pass
+
+class LoggingClauseVisitor(ClauseVisitor):
+    """extends ClauseVisitor to include debug logging of all traversal.
+    
+    To install this visitor, set logging.DEBUG for 
+    'sqlalchemy.sql.ClauseVisitor' **before** you import the 
+    sqlalchemy.sql module.
+    """
+    
+    def traverse(self, obj):
+        indent = getattr(self, '_indent', "")
+        self.logger.debug(indent + "START " + repr(obj))
+        setattr(self, "_indent", indent + "    ")
+        for n in obj.get_children(**self.__traverse_options__):
+            self.traverse(n)
+        obj.accept_visitor(self)
+        setattr(self, "_indent", indent)
+        self.logger.debug(indent+ "END " + repr(obj))
+
+LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
+
+if logging.is_debug_enabled(LoggingClauseVisitor.logger):
+    ClauseVisitor=LoggingClauseVisitor
+
+class NoColumnVisitor(ClauseVisitor):
+    """a ClauseVisitor that will not traverse the exported Column 
+    collections on Table, Alias, Select, and CompoundSelect objects
+    (i.e. their 'columns' or 'c' attribute).
+    
+    this is useful because most traversals don't need those columns, or
+    in the case of ANSICompiler it traverses them explicitly; so
+    skipping their traversal here greatly cuts down on method call overhead.
+    """
+    
+    __traverse_options__ = {'column_collections':False}
+    
 class Executor(object):
-    """Represent a *thing that can produce Compiled objects and execute them*."""
+    """Interface representing a *thing that can produce Compiled objects 
+    and execute them*."""
 
     def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
         """Execute a Compiled object."""
@@ -539,7 +598,7 @@ class Compiled(ClauseVisitor):
 
     The ``__str__`` method of the ``Compiled`` object should produce
     the actual text of the statement.  ``Compiled`` objects are
-    specific to the database library that created them, and also may
+    specific to their underlying database dialect, and also may
     or may not be specific to the columns referenced within a
     particular set of bind parameters.  In no case should the
     ``Compiled`` object be dependent on the actual values of those
@@ -547,7 +606,7 @@ class Compiled(ClauseVisitor):
     defaults.
     """
 
-    def __init__(self, dialect, statement, parameters, engine=None):
+    def __init__(self, dialect, statement, parameters, engine=None, traversal=None):
         """Construct a new Compiled object.
 
         statement
@@ -570,7 +629,7 @@ class Compiled(ClauseVisitor):
         engine
           Optional Engine to compile this statement against.
         """
-
+        ClauseVisitor.__init__(self, traversal=traversal)
         self.dialect = dialect
         self.statement = statement
         self.parameters = parameters
@@ -578,7 +637,7 @@ class Compiled(ClauseVisitor):
         self.can_execute = statement.supports_execution()
 
     def compile(self):
-        self.statement.accept_visitor(self)
+        self.traverse(self.statement)
         self.after_compile()
 
     def __str__(self):
@@ -649,7 +708,19 @@ class ClauseElement(object):
         """
 
         raise NotImplementedError(repr(self))
-
+    
+    def get_children(self, **kwargs):
+        """return immediate child elements of this ``ClauseElement``.
+        
+        this is used for visit traversal.
+        
+        **kwargs may contain flags that change the collection
+        that is returned, for example to return a subset of items
+        in order to cut down on larger traversals, or to return 
+        child items from a different context (such as schema-level
+        collections instead of clause-level)."""
+        return []
+        
     def supports_execution(self):
         """Return True if this clause element represents a complete
         executable statement.
@@ -1058,16 +1129,38 @@ class FromClause(Selectable):
 
     def _get_all_embedded_columns(self):
         ret = []
-        class FindCols(VisitColumnMixin, ClauseVisitor):
+        class FindCols(ClauseVisitor):
             def visit_column(self, col):
                 ret.append(col)
-        self.accept_visitor(FindCols())
+        FindCols().traverse(self)
         return ret
 
     def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
-        """Given a ``ColumnElement``, return the ``ColumnElement``
-        object from this ``Selectable`` which corresponds to that
-        original ``Column`` via a proxy relationship.
+        """Given a ``ColumnElement``, return the exported
+        ``ColumnElement`` object from this ``Selectable`` which
+        corresponds to that original ``Column`` via a common
+        anscestor column.
+        
+        column
+          the target ``ColumnElement`` to be matched
+            
+        raiseerr
+          if True, raise an error if the given ``ColumnElement``
+          could not be matched. if False, non-matches will
+          return None.
+            
+        keys_ok
+          if the ``ColumnElement`` cannot be matched, attempt to
+          match based on the string "key" property of the column
+          alone. This makes the search much more liberal.
+            
+        require_embedded
+          only return corresponding columns for the given
+          ``ColumnElement``, if the given ``ColumnElement`` is
+          actually present within a sub-element of this
+          ``FromClause``.  Normally the column will match if
+          it merely shares a common anscestor with one of
+          the exported columns of this ``FromClause``.
         """
 
         if require_embedded and column not in util.Set(self._get_all_embedded_columns()):
@@ -1258,11 +1351,14 @@ class _TextClause(ClauseElement):
         if bindparams is not None:
             for b in bindparams:
                 self.bindparams[b.key] = b
-    columns = property(lambda s:[])        
-    def accept_visitor(self, visitor): 
-        for item in self.bindparams.values():
-            item.accept_visitor(visitor)
+    columns = property(lambda s:[])
+
+    def get_children(self, **kwargs):
+        return self.bindparams.values()
+
+    def accept_visitor(self, visitor):
         visitor.visit_textclause(self)
+
     def _get_from_objects(self):
         return []
     def supports_execution(self):
@@ -1296,9 +1392,9 @@ class ClauseList(ClauseElement):
         if _is_literal(clause):
             clause = _TextClause(str(clause))
         self.clauses.append(clause)
+    def get_children(self, **kwargs):
+        return self.clauses
     def accept_visitor(self, visitor):
-        for c in self.clauses:
-            c.accept_visitor(visitor)
         visitor.visit_clauselist(self)
     def _get_from_objects(self):
         f = []
@@ -1338,9 +1434,9 @@ class _CompoundClause(ClauseList):
             clause.parens = True
         ClauseList.append(self, clause)
 
+    def get_children(self, **kwargs):
+        return self.clauses
     def accept_visitor(self, visitor):
-        for c in self.clauses:
-            c.accept_visitor(visitor)
         visitor.visit_compound(self)
 
     def _get_from_objects(self):
@@ -1384,9 +1480,9 @@ class _CalculatedClause(ClauseList, ColumnElement):
         clauses = [clause.copy_container() for clause in self.clauses]
         return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
 
+    def get_children(self, **kwargs):
+        return self.clauses
     def accept_visitor(self, visitor):
-        for c in self.clauses:
-            c.accept_visitor(visitor)
         visitor.visit_calculatedclause(self)
 
     def _bind_param(self, obj):
@@ -1432,9 +1528,9 @@ class _Function(_CalculatedClause, FromClause):
         clauses = [clause.copy_container() for clause in self.clauses]
         return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
 
+    def get_children(self, **kwargs):
+        return self.clauses
     def accept_visitor(self, visitor):
-        for c in self.clauses:
-            c.accept_visitor(visitor)
         visitor.visit_function(self)
 
 class _Cast(ColumnElement):
@@ -1445,9 +1541,9 @@ class _Cast(ColumnElement):
         self.clause = clause
         self.typeclause = _TypeClause(self.type)
 
+    def get_children(self, **kwargs):
+        return self.clause, self.typeclause
     def accept_visitor(self, visitor):
-        self.clause.accept_visitor(visitor)
-        self.typeclause.accept_visitor(visitor)
         visitor.visit_cast(self)
 
     def _get_from_objects(self):
@@ -1494,9 +1590,9 @@ class _BinaryClause(ClauseElement):
         return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
     def _get_from_objects(self):
         return self.left._get_from_objects() + self.right._get_from_objects()
+    def get_children(self, **kwargs):
+        return self.left, self.right
     def accept_visitor(self, visitor):
-        self.left.accept_visitor(visitor)
-        self.right.accept_visitor(visitor)
         visitor.visit_binary(self)
     def swap(self):
         c = self.left
@@ -1589,12 +1685,12 @@ class Join(FromClause):
     def _get_folded_equivalents(self, equivs=None):
         if equivs is None:
             equivs = util.Set()
-        class LocateEquivs(ClauseVisitor):
+        class LocateEquivs(NoColumnVisitor):
             def visit_binary(self, binary):
                 if binary.operator == '=' and binary.left.name == binary.right.name:
                     equivs.add(binary.right)
                     equivs.add(binary.left)
-        self.onclause.accept_visitor(LocateEquivs())
+        LocateEquivs().traverse(self.onclause)
         collist = []
         if isinstance(self.left, Join):
             left = self.left._get_folded_equivalents(equivs)
@@ -1636,10 +1732,9 @@ class Join(FromClause):
             
         return select(collist, whereclause, from_obj=[self], **kwargs)
 
+    def get_children(self, **kwargs):
+        return self.left, self.right, self.onclause
     def accept_visitor(self, visitor):
-        self.left.accept_visitor(visitor)
-        self.right.accept_visitor(visitor)
-        self.onclause.accept_visitor(visitor)
         visitor.visit_join(self)
 
     engine = property(lambda s:s.left.engine or s.right.engine)
@@ -1692,8 +1787,11 @@ class Alias(FromClause):
         #return self.selectable._exportable_columns()
         return self.selectable.columns
 
+    def get_children(self, **kwargs):
+        for c in self.c:
+            yield c
+        yield self.selectable
     def accept_visitor(self, visitor):
-        self.selectable.accept_visitor(visitor)
         visitor.visit_alias(self)
 
     def _get_from_objects(self):
@@ -1717,9 +1815,10 @@ class _Label(ColumnElement):
     key = property(lambda s: s.name)
     _label = property(lambda s: s.name)
     orig_set = property(lambda s:s.obj.orig_set)
-
+    
+    def get_children(self, **kwargs):
+        return self.obj,
     def accept_visitor(self, visitor):
-        self.obj.accept_visitor(visitor)
         visitor.visit_label(self)
 
     def _get_from_objects(self):
@@ -1841,6 +1940,11 @@ class TableClause(FromClause):
 
     original_columns = property(_orig_columns)
 
+    def get_children(self, column_collections=True, **kwargs):
+        if column_collections:
+            return [c for c in self.c]
+        else:
+            return []
     def accept_visitor(self, visitor):
         visitor.visit_table(self)
 
@@ -1964,11 +2068,10 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         col.orig_set = colset
         return col
 
+    def get_children(self, column_collections=True, **kwargs):
+        return (column_collections and list(self.c) or []) + \
+            [self.order_by_clause, self.group_by_clause] + list(self.selects)
     def accept_visitor(self, visitor):
-        self.order_by_clause.accept_visitor(visitor)
-        self.group_by_clause.accept_visitor(visitor)
-        for s in self.selects:
-            s.accept_visitor(visitor)
         visitor.visit_compound_select(self)
 
     def _find_engine(self):
@@ -2028,9 +2131,9 @@ class Select(_SelectBaseMixin, FromClause):
         self.order_by(*(order_by or [None]))
         self.group_by(*(group_by or [None]))
         for c in self.order_by_clause:
-            c.accept_visitor(self.__correlator)
+            self.__correlator.traverse(c)
         for c in self.group_by_clause:
-            c.accept_visitor(self.__correlator)
+            self.__correlator.traverse(c)
 
         for f in from_obj:
             self.append_from(f)
@@ -2044,13 +2147,14 @@ class Select(_SelectBaseMixin, FromClause):
             self.append_having(having)
 
 
-    class _CorrelatedVisitor(ClauseVisitor):
+    class _CorrelatedVisitor(NoColumnVisitor):
         """Visit a clause, locate any ``Select`` clauses, and tell
         them that they should correlate their ``FROM`` list to that of
         their parent.
         """
 
         def __init__(self, select, is_where):
+            NoColumnVisitor.__init__(self)
             self.select = select
             self.is_where = is_where
 
@@ -2084,12 +2188,12 @@ class Select(_SelectBaseMixin, FromClause):
 
         # if the column is a Select statement itself, 
         # accept visitor
-        column.accept_visitor(self.__correlator)
+        self.__correlator.traverse(column)
         
         # visit the FROM objects of the column looking for more Selects
         for f in column._get_from_objects():
             if f is not self:
-                f.accept_visitor(self.__correlator)
+                self.__correlator.traverse(f)
         self._process_froms(column, False)
     def _make_proxy(self, selectable, name):
         if self.is_scalar:
@@ -2127,7 +2231,7 @@ class Select(_SelectBaseMixin, FromClause):
     def _append_condition(self, attribute, condition):
         if type(condition) == str:
             condition = _TextClause(condition)
-        condition.accept_visitor(self.__wherecorrelator)
+        self.__wherecorrelator.traverse(condition)
         self._process_froms(condition, False)
         if getattr(self, attribute) is not None:
             setattr(self, attribute, and_(getattr(self, attribute), condition))
@@ -2146,7 +2250,7 @@ class Select(_SelectBaseMixin, FromClause):
     def append_from(self, fromclause):
         if type(fromclause) == str:
             fromclause = FromClause(fromclause)
-        fromclause.accept_visitor(self.__correlator)
+        self.__correlator.traverse(fromclause)
         self._process_froms(fromclause, True)
 
     def _locate_oid_column(self):
@@ -2169,16 +2273,14 @@ class Select(_SelectBaseMixin, FromClause):
             return f
 
     froms = property(_calc_froms, doc="""A collection containing all elements of the FROM clause""")
+    
+    def get_children(self, column_collections=True, **kwargs):
+        return (column_collections and list(self.columns) or []) + \
+            list(self.froms) + \
+            [x for x in (self.whereclause, self.having) if x is not None] + \
+            [self.order_by_clause, self.group_by_clause]
 
     def accept_visitor(self, visitor):
-        for f in self.froms:
-            f.accept_visitor(visitor)
-        if self.whereclause is not None:
-            self.whereclause.accept_visitor(visitor)
-        if self.having is not None:
-            self.having.accept_visitor(visitor)
-        self.order_by_clause.accept_visitor(visitor)
-        self.group_by_clause.accept_visitor(visitor)
         visitor.visit_select(self)
 
     def union(self, other, **kwargs):
@@ -2259,10 +2361,12 @@ class _Insert(_UpdateBase):
         self.select = None
         self.parameters = self._process_colparams(values)
 
-    def accept_visitor(self, visitor):
+    def get_children(self, **kwargs):
         if self.select is not None:
-            self.select.accept_visitor(visitor)
-
+            return self.select,
+        else:
+            return ()
+    def accept_visitor(self, visitor):
         visitor.visit_insert(self)
 
 class _Update(_UpdateBase):
@@ -2271,9 +2375,12 @@ class _Update(_UpdateBase):
         self.whereclause = whereclause
         self.parameters = self._process_colparams(values)
 
-    def accept_visitor(self, visitor):
+    def get_children(self, **kwargs):
         if self.whereclause is not None:
-            self.whereclause.accept_visitor(visitor)
+            return self.whereclause,
+        else:
+            return ()
+    def accept_visitor(self, visitor):
         visitor.visit_update(self)
 
 class _Delete(_UpdateBase):
@@ -2281,7 +2388,10 @@ class _Delete(_UpdateBase):
         self.table = table
         self.whereclause = whereclause
 
-    def accept_visitor(self, visitor):
+    def get_children(self, **kwargs):
         if self.whereclause is not None:
-            self.whereclause.accept_visitor(visitor)
+            return self.whereclause,
+        else:
+            return ()
+    def accept_visitor(self, visitor):
         visitor.visit_delete(self)
index 70fc85702e198e28e10adb533bbfdd92719a36bd..1d185bbc5fe443e42279b58ccae49671ea88d4c7 100644 (file)
@@ -51,7 +51,7 @@ class TableCollection(object):
                     tuples.append( ( parent_table, child_table ) )
         vis = TVisitor()
         for table in self.tables:
-            table.accept_schema_visitor(vis)
+            vis.traverse(table)
         sorter = topological.QueueDependencySorter( tuples, self.tables )
         head =  sorter.sort()
         sequence = []
@@ -64,21 +64,21 @@ class TableCollection(object):
         return sequence
 
 
-class TableFinder(TableCollection, sql.ClauseVisitor):
+class TableFinder(TableCollection, sql.NoColumnVisitor):
     """Given a ``Clause``, locate all the ``Tables`` within it into a list."""
 
     def __init__(self, table, check_columns=False):
         TableCollection.__init__(self)
         self.check_columns = check_columns
         if table is not None:
-            table.accept_visitor(self)
+            self.traverse(table)
 
     def visit_table(self, table):
         self.tables.append(table)
 
     def visit_column(self, column):
         if self.check_columns:
-            column.table.accept_visitor(self)
+            self.traverse(column.table)
 
 class ColumnFinder(sql.ClauseVisitor):
     def __init__(self):
@@ -103,7 +103,7 @@ class ColumnsInClause(sql.ClauseVisitor):
         if self.selectable.c.get(column.key) is column:
             self.result = True
 
-class AbstractClauseProcessor(sql.ClauseVisitor):
+class AbstractClauseProcessor(sql.NoColumnVisitor):
     """Traverse a clause and attempt to convert the contents of container elements
     to a converted element.
 
@@ -132,7 +132,7 @@ class AbstractClauseProcessor(sql.ClauseVisitor):
             if elem is not None:
                 list_[i] = elem
             else:
-                list_[i].accept_visitor(self)
+                self.traverse(list_[i])
 
     def visit_compound(self, compound):
         self.visit_clauselist(compound)
@@ -198,7 +198,7 @@ class ClauseAdapter(AbstractClauseProcessor):
 
       s = table1.alias('foo')
 
-    calling ``condition.accept_visitor(ClauseAdapter(s))`` converts
+    calling ``ClauseAdapter(s).traverse(condition)`` converts
     condition to read::
 
       s.c.col1 == table2.c.col1
index 388ed30c80d848d7bb04e5eae9e540594baf2f8e..51a3d35c675768303a564b3bcc9866c53b1e0e1a 100644 (file)
@@ -500,8 +500,8 @@ class SchemaTest(PersistTest):
         def foo(s, p):
             buf.write(s)
         gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None)
-        table1.accept_schema_visitor(gen)
-        table2.accept_schema_visitor(gen)
+        gen.traverse(table1)
+        gen.traverse(table2)
         buf = buf.getvalue()
         print buf
         assert buf.index("CREATE TABLE someschema.table1") > -1
index 79ccee4da2be5671b09a59302d7c766ae2a12e22..231a491b5272b31571d19cdfd092462ad624d614 100644 (file)
@@ -177,7 +177,7 @@ class ConstraintTest(testbase.AssertMixin):
             capt.append(repr(parameters))
             connection.proxy(statement, parameters)
         schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection)
-        events.accept_schema_visitor(schemagen)
+        schemagen.traverse(events)
         
         assert capt[0].strip().startswith('CREATE TABLE events')