From: Mike Bayer Date: Sun, 11 Mar 2007 20:52:02 +0000 (+0000) Subject: - for hackers, refactored the "visitor" system of ClauseElement and X-Git-Tag: rel_0_3_6~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6a3c374b955299f0065356ef1de6cc0920d5382e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - for hackers, refactored the "visitor" system of ClauseElement and SchemaItem so that the traversal of items is controlled by the ClauseVisitor itself, using the method visitor.traverse(item). accept_visitor() methods can still be called directly but will not do any traversal of child items. ClauseElement/SchemaItem now have a configurable get_children() method to return the collection of child elements for each parent object. This allows the full traversal of items to be clear and unambiguous (as well as loggable), with an easy method of limiting a traversal (just pass flags which are picked up by appropriate get_children() methods). [ticket:501] - accept_schema_visitor() methods removed, replaced with get_children(schema_visitor=True) - various docstring/changelog cleanup/reformatting --- diff --git a/CHANGES b/CHANGES index d8d4f4e564..18b24cf540 100644 --- a/CHANGES +++ b/CHANGES @@ -6,31 +6,53 @@ with conflicting names, specify "unique=True" - this option is still used internally for all the auto-genererated (value-based) bind parameters. + - exists() becomes useable as a standalone selectable, not just in a - WHERE clause + WHERE clause, i.e. exists([columns], criterion).select() + - correlated subqueries work inside of ORDER BY, GROUP BY - - fixed function execution with explicit connections, when you dont - explicitly say "select()" off the function, i.e. + + - fixed function execution with explicit connections, i.e. conn.execute(func.dosomething()) + - use_labels flag on select() wont auto-create labels for literal text column elements, since we can make no assumptions about the text. to create labels for literal columns, you can say "somecol AS somelabel", or use literal_column("somecol").label("somelabel") + - quoting wont occur for literal columns when they are "proxied" into the - column collection for their selectable (is_literal flag is propigated) - - added "fold_equivalents" argument to Join.select(), which removes + column collection for their selectable (is_literal flag is propigated). + literal columns are specified via literal_column("somestring"). + + - added "fold_equivalents" boolean argument to Join.select(), which removes 'duplicate' columns from the resulting column clause that are known to be equivalent based on the join condition. this is of great usage when constructing subqueries of joins which Postgres complains about if duplicate column names are present. + - fixed use_alter flag on ForeignKeyConstraint [ticket:503] + - fixed usage of 2.4-only "reversed" in topological.py [ticket:506] + + - for hackers, refactored the "visitor" system of ClauseElement and + SchemaItem so that the traversal of items is controlled by the + ClauseVisitor itself, using the method visitor.traverse(item). + accept_visitor() methods can still be called directly but will + not do any traversal of child items. ClauseElement/SchemaItem now + have a configurable get_children() method to return the collection + of child elements for each parent object. This allows the full + traversal of items to be clear and unambiguous (as well as loggable), + with an easy method of limiting a traversal (just pass flags which + are picked up by appropriate get_children() methods). [ticket:501] + - oracle: - got binary working for any size input ! cx_oracle works fine, it was my fault as BINARY was being passed and not BLOB for setinputsizes (also unit tests werent even setting input sizes). + - auto_setinputsizes defaults to True for Oracle, fixed cases where it improperly propigated bad types. + - orm: - the full featureset of the SelectResults extension has been merged into a new set of methods available off of Query. These methods @@ -60,23 +82,30 @@ as a list of tuples. this corresponds to the documented behavior. So that instances match up properly, the "uniquing" is disabled when this feature is used. + - Query has add_entity() and add_column() generative methods. these will add the given mapper/class or ColumnElement to the query at compile - time, and apply them to the instances method. the user is responsible + time, and apply them to the instances() method. the user is responsible for constructing reasonable join conditions (otherwise you can get full cartesian products). result set is the list of tuples, non-uniqued. + - strings and columns can also be sent to the *args of instances() where those exact result columns will be part of the result tuples. + - a full select() construct can be passed to query.select() (which worked anyway), but also query.selectfirst(), query.selectone() which will be used as is (i.e. no query is compiled). works similarly to sending the results to instances(). - - added "refresh-expire" cascade [ticket:492] + + - added "refresh-expire" cascade [ticket:492]. allows refresh() and + expire() calls to propigate along relationships. + - more fixes to polymorphic relations, involving proper lazy-clause generation on many-to-one relationships to polymorphic mappers [ticket:493]. also fixes to detection of "direction", more specific targeting of columns that belong to the polymorphic union vs. those that dont. + - put an aggressive check for "flushing object A with a collection of B's, but you put a C in the collection" error condition - **even if C is a subclass of B**, unless B's mapper loads polymorphically. @@ -84,9 +113,17 @@ (since its not polymorphic) which breaks in bi-directional relationships (i.e. C has its A, but A's backref will lazyload it as a different instance of type "B") [ticket:500] + This check is going to bite some of you who do this without issues, + so the error message will also document a flag "enable_typechecks=False" + to disable this checking. But be aware that bi-directional relationships + in particular become fragile without this check. + - extensions: + - options() method on SelectResults now implemented "generatively" - like the rest of the SelectResults methods [ticket:472] + like the rest of the SelectResults methods [ticket:472]. But + you're going to just use Query now anyway. + - query() method is added by assignmapper. this helps with navigating to all the new generative methods on Query. diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 5d5c42208c..ebaedca542 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -75,6 +75,8 @@ class ANSICompiler(sql.Compiled): Compiles ClauseElements into ANSI-compliant SQL strings. """ + __traverse_options__ = {'column_collections':False} + def __init__(self, dialect, statement, parameters=None, **kwargs): """Construct a new ``ANSICompiler`` object. @@ -388,13 +390,13 @@ class ANSICompiler(sql.Compiled): self.select_stack.append(select) for c in select._raw_columns: if isinstance(c, sql.Select) and c.is_scalar: - c.accept_visitor(self) + self.traverse(c) inner_columns[self.get_str(c)] = c continue if hasattr(c, '_selectable'): s = c._selectable() else: - c.accept_visitor(self) + self.traverse(c) inner_columns[self.get_str(c)] = c continue for co in s.columns: @@ -402,10 +404,10 @@ class ANSICompiler(sql.Compiled): labelname = co._label if labelname is not None: l = co.label(labelname) - l.accept_visitor(self) + self.traverse(l) inner_columns[labelname] = l else: - co.accept_visitor(self) + self.traverse(co) inner_columns[self.get_str(co)] = co # TODO: figure this out, a ColumnClause with a select as a parent # is different from any other kind of parent @@ -414,10 +416,10 @@ class ANSICompiler(sql.Compiled): # names look like table.colname, so add a label synonomous with # the column name l = co.label(co.name) - l.accept_visitor(self) + self.traverse(l) inner_columns[self.get_str(l.obj)] = l else: - co.accept_visitor(self) + self.traverse(co) inner_columns[self.get_str(co)] = co self.select_stack.pop(-1) @@ -443,7 +445,7 @@ class ANSICompiler(sql.Compiled): else: continue clause = c==value - clause.accept_visitor(self) + self.traverse(clause) whereclause = sql.and_(clause, whereclause) self.visit_compound(whereclause) @@ -596,7 +598,7 @@ class ANSICompiler(sql.Compiled): vis = DefaultVisitor() for c in insert_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - c.accept_schema_visitor(vis) + vis.traverse(c) self.isinsert = True colparams = self._get_colparams(insert_stmt, default_params) @@ -610,7 +612,7 @@ class ANSICompiler(sql.Compiled): return self.bindparam_string(p.key) else: self.inline_params.add(col) - p.accept_visitor(self) + self.traverse(p) if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): return "(" + self.get_str(p) + ")" else: @@ -631,7 +633,7 @@ class ANSICompiler(sql.Compiled): vis = OnUpdateVisitor() for c in update_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): - c.accept_schema_visitor(vis) + vis.traverse(c) self.isupdate = True colparams = self._get_colparams(update_stmt, default_params) @@ -643,7 +645,7 @@ class ANSICompiler(sql.Compiled): self.binds[p.shortname] = p return self.bindparam_string(p.key) else: - p.accept_visitor(self) + self.traverse(p) self.inline_params.add(col) if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement): return "(" + self.get_str(p) + ")" @@ -734,7 +736,7 @@ class ANSISchemaBase(engine.SchemaIterator): findalterables = FindAlterables() for table in tables: for c in table.constraints: - c.accept_schema_visitor(findalterables) + findalterables.traverse(c) return alterables class ANSISchemaGenerator(ANSISchemaBase): @@ -752,7 +754,7 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_metadata(self, metadata): collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))] for table in collection: - table.accept_schema_visitor(self, traverse=False) + table.accept_visitor(self) if self.supports_alter(): for alterable in self.find_alterables(collection): self.add_foreignkey(alterable) @@ -760,9 +762,9 @@ class ANSISchemaGenerator(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_schema_visitor(self, traverse=False) + column.default.accept_visitor(self) #if column.onupdate is not None: - # column.onupdate.accept_schema_visitor(visitor, traverse=False) + # column.onupdate.accept_visitor(visitor) self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") @@ -777,20 +779,20 @@ class ANSISchemaGenerator(ANSISchemaBase): if column.primary_key: first_pk = True for constraint in column.constraints: - constraint.accept_schema_visitor(self, traverse=False) + constraint.accept_visitor(self) # On some DB order is significant: visit PK first, then the # other constraints (engine.ReflectionTest.testbasic failed on FB2) if len(table.primary_key): - table.primary_key.accept_schema_visitor(self, traverse=False) + table.primary_key.accept_visitor(self) for constraint in [c for c in table.constraints if c is not table.primary_key]: - constraint.accept_schema_visitor(self, traverse=False) + constraint.accept_visitor(self) self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - index.accept_schema_visitor(self, traverse=False) + index.accept_visitor(self) def post_create_table(self, table): return '' @@ -890,7 +892,7 @@ class ANSISchemaDropper(ANSISchemaBase): for alterable in self.find_alterables(collection): self.drop_foreignkey(alterable) for table in collection: - table.accept_schema_visitor(self, traverse=False) + table.accept_visitor(self) def supports_alter(self): return True @@ -906,7 +908,7 @@ class ANSISchemaDropper(ANSISchemaBase): def visit_table(self, table): for column in table.columns: if column.default is not None: - column.default.accept_schema_visitor(self, traverse=False) + column.default.accept_visitor(self) self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 254ea60131..8c3c71f6ed 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -657,11 +657,11 @@ class MSSQLCompiler(ansisql.ANSICompiler): if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table): alias = table.alias() self.tablealiases[table] = alias - alias.accept_visitor(self) + self.traverse(alias) self.froms[('alias', table)] = self.froms[table] for c in alias.c: - c.accept_visitor(self) - alias.oid_column.accept_visitor(self) + self.traverse(c) + self.traverse(alias.oid_column) self.tablealiases[alias] = self.froms[table] self.froms[table] = self.froms[alias] else: diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 966834eb25..1dba60c1d7 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -434,7 +434,7 @@ class OracleCompiler(ansisql.ANSICompiler): # now re-visit the onclause, which will be used as a where clause # (the first visit occured via the Join object itself right before it called visit_join()) - join.onclause.accept_visitor(self) + self.traverse(join.onclause) self._outertable = None @@ -488,12 +488,12 @@ class OracleCompiler(ansisql.ANSICompiler): orderby = self.strings[select.order_by_clause] if not orderby: orderby = select.oid_column - orderby.accept_visitor(self) + self.traverse(orderby) orderby = self.strings[orderby] - class SelectVisitor(sql.ClauseVisitor): + class SelectVisitor(sql.NoColumnVisitor): def visit_select(self, select): select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) - select.accept_visitor(SelectVisitor()) + SelectVisitor().traverse(select) limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) if select.offset is not None: limitselect.append_whereclause("ora_rn>%d" % select.offset) @@ -501,7 +501,7 @@ class OracleCompiler(ansisql.ANSICompiler): limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) else: limitselect.append_whereclause("ora_rn<=%d" % select.limit) - limitselect.accept_visitor(self) + self.traverse(limitselect) self.strings[select] = self.strings[limitselect] self.froms[select] = self.froms[limitselect] else: @@ -527,7 +527,7 @@ class OracleCompiler(ansisql.ANSICompiler): orderby = self.strings[select.order_by_clause] if not orderby: orderby = select.oid_column - orderby.accept_visitor(self) + self.traverse(orderby) orderby = self.strings[orderby] select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")) limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) @@ -537,7 +537,7 @@ class OracleCompiler(ansisql.ANSICompiler): limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset)) else: limitselect.append_whereclause("ora_rn<=%d" % select.limit) - limitselect.accept_visitor(self) + self.traverse(limitselect) self.strings[select] = self.strings[limitselect] self.froms[select] = self.froms[limitselect] else: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0a53da9284..f79167abc3 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -446,7 +446,7 @@ class Connection(Connectable): raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object))) def execute_default(self, default, **kwargs): - return default.accept_schema_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs)) + return default.accept_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs)) def execute_text(self, statement, *multiparams, **params): if len(multiparams) == 0: @@ -672,7 +672,7 @@ class Engine(sql.Executor, Connectable): else: conn = connection try: - element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs), traverse=False) + element.accept_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs)) finally: if connection is None: conn.close() @@ -1164,13 +1164,13 @@ class DefaultRunner(schema.SchemaVisitor): def get_column_default(self, column): if column.default is not None: - return column.default.accept_schema_visitor(self) + return column.default.accept_visitor(self) else: return None def get_column_onupdate(self, column): if column.onupdate is not None: - return column.onupdate.accept_schema_visitor(self) + return column.onupdate.accept_visitor(self) else: return None diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index d28445be61..74dd58a3f6 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -767,7 +767,7 @@ class Mapper(object): vis = mapperutil.BinaryVisitor(visit_binary) for mapper in self.base_mapper().polymorphic_iterator(): if mapper.inherit_condition is not None: - mapper.inherit_condition.accept_visitor(vis) + vis.traverse(mapper.inherit_condition) return result def add_properties(self, dict_of_properties): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 2d10b2f9db..c9a2dbe597 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -233,9 +233,9 @@ class PropertyLoader(StrategizedProperty): # error message in case its the "old" way. if self.loads_polymorphic: vis = sql_util.ColumnsInClause(self.mapper.select_table) - self.primaryjoin.accept_visitor(vis) + vis.traverse(self.primaryjoin) if self.secondaryjoin: - self.secondaryjoin.accept_visitor(vis) + vis.traverse(self.secondaryjoin) if vis.result: raise exceptions.ArgumentError("In relationship '%s', primary and secondary join conditions must not include columns from the polymorphic 'select_table' argument as of SA release 0.3.4. Construct join conditions using the base tables of the related mappers." % (str(self))) @@ -251,9 +251,9 @@ class PropertyLoader(StrategizedProperty): self._opposite_side.add(binary.right) if binary.right in self.foreign_keys: self._opposite_side.add(binary.left) - self.primaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary)) + mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin) if self.secondaryjoin is not None: - self.secondaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary)) + mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin) else: self.foreign_keys = util.Set() self._opposite_side = util.Set() @@ -268,12 +268,12 @@ class PropertyLoader(StrategizedProperty): if f.references(binary.left.table): self.foreign_keys.add(binary.right) self._opposite_side.add(binary.left) - self.primaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary)) + mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin) if len(self.foreign_keys) == 0: raise exceptions.ArgumentError("Cant locate any foreign key columns in primary join condition '%s' for relationship '%s'. Specify 'foreign_keys' argument to indicate which columns in the join condition are foreign." %(str(self.primaryjoin), str(self))) if self.secondaryjoin is not None: - self.secondaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary)) + mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin) def _determine_direction(self): """Determine our *direction*, i.e. do we represent one to @@ -343,14 +343,14 @@ class PropertyLoader(StrategizedProperty): if self.loads_polymorphic: if self.secondaryjoin: self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container() - self.polymorphic_secondaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table)) + sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin) self.polymorphic_primaryjoin = self.primaryjoin.copy_container() else: self.polymorphic_primaryjoin = self.primaryjoin.copy_container() if self.direction is sync.ONETOMANY: - self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents)) + sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) elif self.direction is sync.MANYTOONE: - self.polymorphic_primaryjoin.accept_visitor(sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents)) + sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin) self.polymorphic_secondaryjoin = None # load "polymorphic" versions of the columns present in "remote_side" - this is # important for lazy-clause generation which goes off the polymorphic target selectable @@ -411,11 +411,11 @@ class PropertyLoader(StrategizedProperty): else: secondaryjoin = None if self.direction is sync.ONETOMANY: - primaryjoin.accept_visitor(sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents)) + sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) elif self.direction is sync.MANYTOONE: - primaryjoin.accept_visitor(sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents)) + sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) elif self.secondaryjoin: - primaryjoin.accept_visitor(sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents)) + sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin) if secondaryjoin is not None: j = primaryjoin & secondaryjoin diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a1c8b6af51..a0b520f339 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -775,7 +775,7 @@ class Query(object): # adapt the given WHERECLAUSE to adjust instances of this query's mapped # table to be that of our select_table, # which may be the "polymorphic" selectable used by our mapper. - whereclause.accept_visitor(sql_util.ClauseAdapter(self.table)) + sql_util.ClauseAdapter(self.table).traverse(whereclause) # if extra entities, adapt the criterion to those as well for m in self._entities: @@ -783,7 +783,7 @@ class Query(object): m = mapper.class_mapper(m) if isinstance(m, mapper.Mapper): table = m.select_table - whereclause.accept_visitor(sql_util.ClauseAdapter(m.select_table)) + sql_util.ClauseAdapter(m.select_table).traverse(whereclause) # get/create query context. get the ultimate compile arguments # from there @@ -827,7 +827,7 @@ class Query(object): order_by = util.to_list(order_by) or [] cf = sql_util.ColumnFinder() for o in order_by: - o.accept_visitor(cf) + cf.traverse(o) else: cf = [] diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 8e19be5367..a295ed862d 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -260,7 +260,7 @@ class LazyLoader(AbstractRelationLoader): class FindColumnInColumnClause(sql.ClauseVisitor): def visit_column(self, c): columns.append(c) - expr.accept_visitor(FindColumnInColumnClause()) + FindColumnInColumnClause().traverse(expr) return len(columns) and columns[0] or None def col_in_collection(column, collection): @@ -294,7 +294,7 @@ class LazyLoader(AbstractRelationLoader): lazywhere = primaryjoin.copy_container() li = mapperutil.BinaryVisitor(visit_binary) - lazywhere.accept_visitor(li) + li.traverse(lazywhere) if secondaryjoin is not None: secondaryjoin = secondaryjoin.copy_container() @@ -363,16 +363,16 @@ class EagerLoader(AbstractRelationLoader): eagerloader.secondary:self.eagersecondary }) self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container() - self.eagersecondaryjoin.accept_visitor(self.aliasizer) + self.aliasizer.traverse(self.eagersecondaryjoin) self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() - self.eagerprimary.accept_visitor(self.aliasizer) + self.aliasizer.traverse(self.eagerprimary) else: self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container() self.aliasizer = sql_util.Aliasizer(self.target, aliases={self.target:self.eagertarget}) - self.eagerprimary.accept_visitor(self.aliasizer) + self.aliasizer.traverse(self.eagerprimary) if parentclauses is not None: - self.eagerprimary.accept_visitor(parentclauses.aliasizer) + parentclauses.aliasizer.traverse(self.eagerprimary) if eagerloader.order_by: self.eager_order_by = self._aliasize_orderby(eagerloader.order_by) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 68fa9cee16..8c70f8cf81 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -80,8 +80,7 @@ class ClauseSynchronizer(object): self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary)) rules_added = len(self.syncrules) - processor = BinaryVisitor(compile_binary) - sqlclause.accept_visitor(processor) + BinaryVisitor(compile_binary).traverse(sqlclause) if len(self.syncrules) == rules_added: raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause)) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 52324e63ee..78c31e9acd 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -42,7 +42,11 @@ class SchemaItem(object): """Associate with this SchemaItem's parent object.""" raise NotImplementedError() - + + def get_children(self, **kwargs): + """used to allow SchemaVisitor access""" + return [] + def __repr__(self): return "%s()" % self.__class__.__name__ @@ -322,11 +326,14 @@ class Table(SchemaItem, sql.TableClause): metadata.tables[_get_table_key(self.name, self.schema)] = self self._metadata = metadata - def accept_schema_visitor(self, visitor, traverse=True): - if traverse: - for c in self.columns: - c.accept_schema_visitor(visitor, True) - return visitor.visit_table(self) + def get_children(self, column_collections=True, schema_visitor=False, **kwargs): + if not schema_visitor: + return sql.TableClause.get_children(self, column_collections=column_collections, **kwargs) + else: + if column_collections: + return [c for c in self.columns] + else: + return [] def exists(self, connectable=None): """Return True if this table exists.""" @@ -604,20 +611,12 @@ class Column(SchemaItem, sql._ColumnClause): return self.__originating_column._get_case_sensitive() case_sensitive = property(_case_sens, lambda s,v:None) - def accept_schema_visitor(self, visitor, traverse=True): - """Traverse the given visitor to this ``Column``'s default and foreign key object, - then call `visit_column` on the visitor.""" - - if traverse: - if self.default is not None: - self.default.accept_schema_visitor(visitor, traverse=True) - if self.onupdate is not None: - self.onupdate.accept_schema_visitor(visitor, traverse=True) - for f in self.foreign_keys: - f.accept_schema_visitor(visitor, traverse=True) - for constraint in self.constraints: - constraint.accept_schema_visitor(visitor, traverse=True) - visitor.visit_column(self) + def get_children(self, schema_visitor=False, **kwargs): + if schema_visitor: + return [x for x in (self.default, self.onupdate) if x is not None] + \ + list(self.foreign_keys) + list(self.constraints) + else: + return sql._ColumnClause.get_children(self, **kwargs) class ForeignKey(SchemaItem): @@ -715,7 +714,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the `visit_foreign_key` method on the given visitor.""" visitor.visit_foreign_key(self) @@ -771,7 +770,7 @@ class PassiveDefault(DefaultGenerator): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): return visitor.visit_passive_default(self) def __repr__(self): @@ -788,7 +787,7 @@ class ColumnDefault(DefaultGenerator): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the visit_column_default method on the given visitor.""" if self.for_update: @@ -828,7 +827,7 @@ class Sequence(DefaultGenerator): def drop(self): self.get_engine().drop(self) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): """Call the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) @@ -871,7 +870,7 @@ class CheckConstraint(Constraint): super(CheckConstraint, self).__init__(name) self.sqltext = sqltext - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): if isinstance(self.parent, Table): visitor.visit_check_constraint(self) else: @@ -904,7 +903,7 @@ class ForeignKeyConstraint(Constraint): for (c, r) in zip(self.__colnames, self.__refcolnames): self.append_element(c,r) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_foreign_key_constraint(self) def append_element(self, col, refcol): @@ -930,7 +929,7 @@ class PrimaryKeyConstraint(Constraint): for c in self.__colnames: self.append_column(table.c[c]) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_primary_key_constraint(self) def add(self, col): @@ -964,7 +963,7 @@ class UniqueConstraint(Constraint): def append_column(self, col): self.columns.add(col) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_unique_constraint(self) def copy(self): @@ -1042,7 +1041,7 @@ class Index(SchemaItem): else: self.get_engine().drop(self) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_index(self) def __str__(self): @@ -1118,7 +1117,7 @@ class MetaData(SchemaItem): connectable = self.get_engine() connectable.drop(self, checkfirst=checkfirst, tables=tables) - def accept_schema_visitor(self, visitor, traverse=True): + def accept_visitor(self, visitor): visitor.visit_metadata(self) def _derived_metadata(self): @@ -1190,6 +1189,8 @@ class DynamicMetaData(MetaData): class SchemaVisitor(sql.ClauseVisitor): """Define the visiting for ``SchemaItem`` objects.""" + __traverse_options__ = {'schema_visitor':True} + def visit_schema(self, schema): """Visit a generic ``SchemaItem``.""" pass diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 073277d576..190ec29d40 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -5,7 +5,7 @@ """Define the base components of SQL expression trees.""" -from sqlalchemy import util, exceptions +from sqlalchemy import util, exceptions, logging from sqlalchemy import types as sqltypes import string, re, random, sets @@ -485,44 +485,103 @@ class ClauseParameters(dict): return d class ClauseVisitor(object): - """Define the visiting of ``ClauseElements``.""" - - def visit_column(self, column):pass - def visit_table(self, column):pass - def visit_fromclause(self, fromclause):pass - def visit_bindparam(self, bindparam):pass - def visit_textclause(self, textclause):pass - def visit_compound(self, compound):pass - def visit_compound_select(self, compound):pass - def visit_binary(self, binary):pass - def visit_alias(self, alias):pass - def visit_select(self, select):pass - def visit_join(self, join):pass - def visit_null(self, null):pass - def visit_clauselist(self, list):pass - def visit_calculatedclause(self, calcclause):pass - def visit_function(self, func):pass - def visit_cast(self, cast):pass - def visit_label(self, label):pass - def visit_typeclause(self, typeclause):pass - -class VisitColumnMixin(object): - """a mixin that adds Column traversal to a ClauseVisitor""" + """A class that knows how to traverse and visit + ``ClauseElements``. + + Each ``ClauseElement``'s accept_visitor() method will call a + corresponding visit_XXXX() method here. Traversal of a + hierarchy of ``ClauseElements`` is achieved via the + ``traverse()`` method, which is passed the lead + ``ClauseElement``. + + By default, ``ClauseVisitor`` traverses all elements + fully. Options can be specified at the class level via the + ``__traverse_options__`` dictionary which will be passed + to the ``get_children()`` method of each ``ClauseElement``; + these options can indicate modifications to the set of + elements returned, such as to not return column collections + (column_collections=False) or to return Schema-level items + (schema_visitor=True).""" + __traverse_options__ = {} + def traverse(self, obj): + for n in obj.get_children(**self.__traverse_options__): + self.traverse(n) + obj.accept_visitor(self) + def visit_column(self, column): + pass def visit_table(self, table): - for c in table.c: - c.accept_visitor(self) - def visit_select(self, select): - for c in select.c: - c.accept_visitor(self) - def visit_compound_select(self, select): - for c in select.c: - c.accept_visitor(self) + pass + def visit_fromclause(self, fromclause): + pass + def visit_bindparam(self, bindparam): + pass + def visit_textclause(self, textclause): + pass + def visit_compound(self, compound): + pass + def visit_compound_select(self, compound): + pass + def visit_binary(self, binary): + pass def visit_alias(self, alias): - for c in alias.c: - c.accept_visitor(self) - + pass + def visit_select(self, select): + pass + def visit_join(self, join): + pass + def visit_null(self, null): + pass + def visit_clauselist(self, list): + pass + def visit_calculatedclause(self, calcclause): + pass + def visit_function(self, func): + pass + def visit_cast(self, cast): + pass + def visit_label(self, label): + pass + def visit_typeclause(self, typeclause): + pass + +class LoggingClauseVisitor(ClauseVisitor): + """extends ClauseVisitor to include debug logging of all traversal. + + To install this visitor, set logging.DEBUG for + 'sqlalchemy.sql.ClauseVisitor' **before** you import the + sqlalchemy.sql module. + """ + + def traverse(self, obj): + indent = getattr(self, '_indent', "") + self.logger.debug(indent + "START " + repr(obj)) + setattr(self, "_indent", indent + " ") + for n in obj.get_children(**self.__traverse_options__): + self.traverse(n) + obj.accept_visitor(self) + setattr(self, "_indent", indent) + self.logger.debug(indent+ "END " + repr(obj)) + +LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor) + +if logging.is_debug_enabled(LoggingClauseVisitor.logger): + ClauseVisitor=LoggingClauseVisitor + +class NoColumnVisitor(ClauseVisitor): + """a ClauseVisitor that will not traverse the exported Column + collections on Table, Alias, Select, and CompoundSelect objects + (i.e. their 'columns' or 'c' attribute). + + this is useful because most traversals don't need those columns, or + in the case of ANSICompiler it traverses them explicitly; so + skipping their traversal here greatly cuts down on method call overhead. + """ + + __traverse_options__ = {'column_collections':False} + class Executor(object): - """Represent a *thing that can produce Compiled objects and execute them*.""" + """Interface representing a *thing that can produce Compiled objects + and execute them*.""" def execute_compiled(self, compiled, parameters, echo=None, **kwargs): """Execute a Compiled object.""" @@ -539,7 +598,7 @@ class Compiled(ClauseVisitor): The ``__str__`` method of the ``Compiled`` object should produce the actual text of the statement. ``Compiled`` objects are - specific to the database library that created them, and also may + specific to their underlying database dialect, and also may or may not be specific to the columns referenced within a particular set of bind parameters. In no case should the ``Compiled`` object be dependent on the actual values of those @@ -547,7 +606,7 @@ class Compiled(ClauseVisitor): defaults. """ - def __init__(self, dialect, statement, parameters, engine=None): + def __init__(self, dialect, statement, parameters, engine=None, traversal=None): """Construct a new Compiled object. statement @@ -570,7 +629,7 @@ class Compiled(ClauseVisitor): engine Optional Engine to compile this statement against. """ - + ClauseVisitor.__init__(self, traversal=traversal) self.dialect = dialect self.statement = statement self.parameters = parameters @@ -578,7 +637,7 @@ class Compiled(ClauseVisitor): self.can_execute = statement.supports_execution() def compile(self): - self.statement.accept_visitor(self) + self.traverse(self.statement) self.after_compile() def __str__(self): @@ -649,7 +708,19 @@ class ClauseElement(object): """ raise NotImplementedError(repr(self)) - + + def get_children(self, **kwargs): + """return immediate child elements of this ``ClauseElement``. + + this is used for visit traversal. + + **kwargs may contain flags that change the collection + that is returned, for example to return a subset of items + in order to cut down on larger traversals, or to return + child items from a different context (such as schema-level + collections instead of clause-level).""" + return [] + def supports_execution(self): """Return True if this clause element represents a complete executable statement. @@ -1058,16 +1129,38 @@ class FromClause(Selectable): def _get_all_embedded_columns(self): ret = [] - class FindCols(VisitColumnMixin, ClauseVisitor): + class FindCols(ClauseVisitor): def visit_column(self, col): ret.append(col) - self.accept_visitor(FindCols()) + FindCols().traverse(self) return ret def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False): - """Given a ``ColumnElement``, return the ``ColumnElement`` - object from this ``Selectable`` which corresponds to that - original ``Column`` via a proxy relationship. + """Given a ``ColumnElement``, return the exported + ``ColumnElement`` object from this ``Selectable`` which + corresponds to that original ``Column`` via a common + anscestor column. + + column + the target ``ColumnElement`` to be matched + + raiseerr + if True, raise an error if the given ``ColumnElement`` + could not be matched. if False, non-matches will + return None. + + keys_ok + if the ``ColumnElement`` cannot be matched, attempt to + match based on the string "key" property of the column + alone. This makes the search much more liberal. + + require_embedded + only return corresponding columns for the given + ``ColumnElement``, if the given ``ColumnElement`` is + actually present within a sub-element of this + ``FromClause``. Normally the column will match if + it merely shares a common anscestor with one of + the exported columns of this ``FromClause``. """ if require_embedded and column not in util.Set(self._get_all_embedded_columns()): @@ -1258,11 +1351,14 @@ class _TextClause(ClauseElement): if bindparams is not None: for b in bindparams: self.bindparams[b.key] = b - columns = property(lambda s:[]) - def accept_visitor(self, visitor): - for item in self.bindparams.values(): - item.accept_visitor(visitor) + columns = property(lambda s:[]) + + def get_children(self, **kwargs): + return self.bindparams.values() + + def accept_visitor(self, visitor): visitor.visit_textclause(self) + def _get_from_objects(self): return [] def supports_execution(self): @@ -1296,9 +1392,9 @@ class ClauseList(ClauseElement): if _is_literal(clause): clause = _TextClause(str(clause)) self.clauses.append(clause) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_clauselist(self) def _get_from_objects(self): f = [] @@ -1338,9 +1434,9 @@ class _CompoundClause(ClauseList): clause.parens = True ClauseList.append(self, clause) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_compound(self) def _get_from_objects(self): @@ -1384,9 +1480,9 @@ class _CalculatedClause(ClauseList, ColumnElement): clauses = [clause.copy_container() for clause in self.clauses] return _CalculatedClause(type=self.type, engine=self._engine, *clauses) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_calculatedclause(self) def _bind_param(self, obj): @@ -1432,9 +1528,9 @@ class _Function(_CalculatedClause, FromClause): clauses = [clause.copy_container() for clause in self.clauses] return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) + def get_children(self, **kwargs): + return self.clauses def accept_visitor(self, visitor): - for c in self.clauses: - c.accept_visitor(visitor) visitor.visit_function(self) class _Cast(ColumnElement): @@ -1445,9 +1541,9 @@ class _Cast(ColumnElement): self.clause = clause self.typeclause = _TypeClause(self.type) + def get_children(self, **kwargs): + return self.clause, self.typeclause def accept_visitor(self, visitor): - self.clause.accept_visitor(visitor) - self.typeclause.accept_visitor(visitor) visitor.visit_cast(self) def _get_from_objects(self): @@ -1494,9 +1590,9 @@ class _BinaryClause(ClauseElement): return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) def _get_from_objects(self): return self.left._get_from_objects() + self.right._get_from_objects() + def get_children(self, **kwargs): + return self.left, self.right def accept_visitor(self, visitor): - self.left.accept_visitor(visitor) - self.right.accept_visitor(visitor) visitor.visit_binary(self) def swap(self): c = self.left @@ -1589,12 +1685,12 @@ class Join(FromClause): def _get_folded_equivalents(self, equivs=None): if equivs is None: equivs = util.Set() - class LocateEquivs(ClauseVisitor): + class LocateEquivs(NoColumnVisitor): def visit_binary(self, binary): if binary.operator == '=' and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) - self.onclause.accept_visitor(LocateEquivs()) + LocateEquivs().traverse(self.onclause) collist = [] if isinstance(self.left, Join): left = self.left._get_folded_equivalents(equivs) @@ -1636,10 +1732,9 @@ class Join(FromClause): return select(collist, whereclause, from_obj=[self], **kwargs) + def get_children(self, **kwargs): + return self.left, self.right, self.onclause def accept_visitor(self, visitor): - self.left.accept_visitor(visitor) - self.right.accept_visitor(visitor) - self.onclause.accept_visitor(visitor) visitor.visit_join(self) engine = property(lambda s:s.left.engine or s.right.engine) @@ -1692,8 +1787,11 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns + def get_children(self, **kwargs): + for c in self.c: + yield c + yield self.selectable def accept_visitor(self, visitor): - self.selectable.accept_visitor(visitor) visitor.visit_alias(self) def _get_from_objects(self): @@ -1717,9 +1815,10 @@ class _Label(ColumnElement): key = property(lambda s: s.name) _label = property(lambda s: s.name) orig_set = property(lambda s:s.obj.orig_set) - + + def get_children(self, **kwargs): + return self.obj, def accept_visitor(self, visitor): - self.obj.accept_visitor(visitor) visitor.visit_label(self) def _get_from_objects(self): @@ -1841,6 +1940,11 @@ class TableClause(FromClause): original_columns = property(_orig_columns) + def get_children(self, column_collections=True, **kwargs): + if column_collections: + return [c for c in self.c] + else: + return [] def accept_visitor(self, visitor): visitor.visit_table(self) @@ -1964,11 +2068,10 @@ class CompoundSelect(_SelectBaseMixin, FromClause): col.orig_set = colset return col + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.c) or []) + \ + [self.order_by_clause, self.group_by_clause] + list(self.selects) def accept_visitor(self, visitor): - self.order_by_clause.accept_visitor(visitor) - self.group_by_clause.accept_visitor(visitor) - for s in self.selects: - s.accept_visitor(visitor) visitor.visit_compound_select(self) def _find_engine(self): @@ -2028,9 +2131,9 @@ class Select(_SelectBaseMixin, FromClause): self.order_by(*(order_by or [None])) self.group_by(*(group_by or [None])) for c in self.order_by_clause: - c.accept_visitor(self.__correlator) + self.__correlator.traverse(c) for c in self.group_by_clause: - c.accept_visitor(self.__correlator) + self.__correlator.traverse(c) for f in from_obj: self.append_from(f) @@ -2044,13 +2147,14 @@ class Select(_SelectBaseMixin, FromClause): self.append_having(having) - class _CorrelatedVisitor(ClauseVisitor): + class _CorrelatedVisitor(NoColumnVisitor): """Visit a clause, locate any ``Select`` clauses, and tell them that they should correlate their ``FROM`` list to that of their parent. """ def __init__(self, select, is_where): + NoColumnVisitor.__init__(self) self.select = select self.is_where = is_where @@ -2084,12 +2188,12 @@ class Select(_SelectBaseMixin, FromClause): # if the column is a Select statement itself, # accept visitor - column.accept_visitor(self.__correlator) + self.__correlator.traverse(column) # visit the FROM objects of the column looking for more Selects for f in column._get_from_objects(): if f is not self: - f.accept_visitor(self.__correlator) + self.__correlator.traverse(f) self._process_froms(column, False) def _make_proxy(self, selectable, name): if self.is_scalar: @@ -2127,7 +2231,7 @@ class Select(_SelectBaseMixin, FromClause): def _append_condition(self, attribute, condition): if type(condition) == str: condition = _TextClause(condition) - condition.accept_visitor(self.__wherecorrelator) + self.__wherecorrelator.traverse(condition) self._process_froms(condition, False) if getattr(self, attribute) is not None: setattr(self, attribute, and_(getattr(self, attribute), condition)) @@ -2146,7 +2250,7 @@ class Select(_SelectBaseMixin, FromClause): def append_from(self, fromclause): if type(fromclause) == str: fromclause = FromClause(fromclause) - fromclause.accept_visitor(self.__correlator) + self.__correlator.traverse(fromclause) self._process_froms(fromclause, True) def _locate_oid_column(self): @@ -2169,16 +2273,14 @@ class Select(_SelectBaseMixin, FromClause): return f froms = property(_calc_froms, doc="""A collection containing all elements of the FROM clause""") + + def get_children(self, column_collections=True, **kwargs): + return (column_collections and list(self.columns) or []) + \ + list(self.froms) + \ + [x for x in (self.whereclause, self.having) if x is not None] + \ + [self.order_by_clause, self.group_by_clause] def accept_visitor(self, visitor): - for f in self.froms: - f.accept_visitor(visitor) - if self.whereclause is not None: - self.whereclause.accept_visitor(visitor) - if self.having is not None: - self.having.accept_visitor(visitor) - self.order_by_clause.accept_visitor(visitor) - self.group_by_clause.accept_visitor(visitor) visitor.visit_select(self) def union(self, other, **kwargs): @@ -2259,10 +2361,12 @@ class _Insert(_UpdateBase): self.select = None self.parameters = self._process_colparams(values) - def accept_visitor(self, visitor): + def get_children(self, **kwargs): if self.select is not None: - self.select.accept_visitor(visitor) - + return self.select, + else: + return () + def accept_visitor(self, visitor): visitor.visit_insert(self) class _Update(_UpdateBase): @@ -2271,9 +2375,12 @@ class _Update(_UpdateBase): self.whereclause = whereclause self.parameters = self._process_colparams(values) - def accept_visitor(self, visitor): + def get_children(self, **kwargs): if self.whereclause is not None: - self.whereclause.accept_visitor(visitor) + return self.whereclause, + else: + return () + def accept_visitor(self, visitor): visitor.visit_update(self) class _Delete(_UpdateBase): @@ -2281,7 +2388,10 @@ class _Delete(_UpdateBase): self.table = table self.whereclause = whereclause - def accept_visitor(self, visitor): + def get_children(self, **kwargs): if self.whereclause is not None: - self.whereclause.accept_visitor(visitor) + return self.whereclause, + else: + return () + def accept_visitor(self, visitor): visitor.visit_delete(self) diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 70fc85702e..1d185bbc5f 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -51,7 +51,7 @@ class TableCollection(object): tuples.append( ( parent_table, child_table ) ) vis = TVisitor() for table in self.tables: - table.accept_schema_visitor(vis) + vis.traverse(table) sorter = topological.QueueDependencySorter( tuples, self.tables ) head = sorter.sort() sequence = [] @@ -64,21 +64,21 @@ class TableCollection(object): return sequence -class TableFinder(TableCollection, sql.ClauseVisitor): +class TableFinder(TableCollection, sql.NoColumnVisitor): """Given a ``Clause``, locate all the ``Tables`` within it into a list.""" def __init__(self, table, check_columns=False): TableCollection.__init__(self) self.check_columns = check_columns if table is not None: - table.accept_visitor(self) + self.traverse(table) def visit_table(self, table): self.tables.append(table) def visit_column(self, column): if self.check_columns: - column.table.accept_visitor(self) + self.traverse(column.table) class ColumnFinder(sql.ClauseVisitor): def __init__(self): @@ -103,7 +103,7 @@ class ColumnsInClause(sql.ClauseVisitor): if self.selectable.c.get(column.key) is column: self.result = True -class AbstractClauseProcessor(sql.ClauseVisitor): +class AbstractClauseProcessor(sql.NoColumnVisitor): """Traverse a clause and attempt to convert the contents of container elements to a converted element. @@ -132,7 +132,7 @@ class AbstractClauseProcessor(sql.ClauseVisitor): if elem is not None: list_[i] = elem else: - list_[i].accept_visitor(self) + self.traverse(list_[i]) def visit_compound(self, compound): self.visit_clauselist(compound) @@ -198,7 +198,7 @@ class ClauseAdapter(AbstractClauseProcessor): s = table1.alias('foo') - calling ``condition.accept_visitor(ClauseAdapter(s))`` converts + calling ``ClauseAdapter(s).traverse(condition)`` converts condition to read:: s.c.col1 == table2.c.col1 diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 388ed30c80..51a3d35c67 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -500,8 +500,8 @@ class SchemaTest(PersistTest): def foo(s, p): buf.write(s) gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None) - table1.accept_schema_visitor(gen) - table2.accept_schema_visitor(gen) + gen.traverse(table1) + gen.traverse(table2) buf = buf.getvalue() print buf assert buf.index("CREATE TABLE someschema.table1") > -1 diff --git a/test/sql/constraints.py b/test/sql/constraints.py index 79ccee4da2..231a491b52 100644 --- a/test/sql/constraints.py +++ b/test/sql/constraints.py @@ -177,7 +177,7 @@ class ConstraintTest(testbase.AssertMixin): capt.append(repr(parameters)) connection.proxy(statement, parameters) schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection) - events.accept_schema_visitor(schemagen) + schemagen.traverse(events) assert capt[0].strip().startswith('CREATE TABLE events')