From 6f604f911640d92f705fc6611bfaa3e2600c4ee1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 24 Nov 2007 01:59:29 +0000 Subject: [PATCH] - decruftify old visitors used by orm, convert to functions that use a common traversal function. - TranslatingDict is finally gone, thanks to column.proxy_set simpleness...hooray ! - shoved "slice" use case on RowProxy into an exception case. knocks noticeable time off of large result set operations. --- lib/sqlalchemy/engine/base.py | 19 ++-- lib/sqlalchemy/orm/mapper.py | 20 ++--- lib/sqlalchemy/orm/query.py | 10 +-- lib/sqlalchemy/orm/strategies.py | 2 +- lib/sqlalchemy/orm/util.py | 37 -------- lib/sqlalchemy/schema.py | 3 +- lib/sqlalchemy/sql/util.py | 143 ++++++++++--------------------- lib/sqlalchemy/topological.py | 9 +- 8 files changed, 78 insertions(+), 165 deletions(-) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 9c7e70ba92..21977b689b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1243,12 +1243,7 @@ class RowProxy(object): return self.__parent._has_key(self.__row, key) def __getitem__(self, key): - if isinstance(key, slice): - indices = key.indices(len(self)) - return tuple([self.__parent._get_col(self.__row, i) - for i in xrange(*indices)]) - else: - return self.__parent._get_col(self.__row, key) + return self.__parent._get_col(self.__row, key) def __getattr__(self, name): try: @@ -1474,7 +1469,15 @@ class ResultProxy(object): return self.dialect.supports_sane_multi_rowcount def _get_col(self, row, key): - rec = self._key_cache[key] + try: + rec = self._key_cache[key] + except TypeError: + # the 'slice' use case is very infrequent, + # so we use an exception catch to reduce conditionals in _get_col + if isinstance(key, slice): + indices = key.indices(len(row)) + return tuple([self._get_col(row, i) for i in xrange(*indices)]) + if rec[1]: return rec[1](row[rec[2]]) else: @@ -1482,8 +1485,10 @@ class ResultProxy(object): def _fetchone_impl(self): return self.cursor.fetchone() + def _fetchmany_impl(self, size=None): return self.cursor.fetchmany(size) + def _fetchall_impl(self): return self.cursor.fetchall() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 3bacb13e93..1414336ac6 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -392,7 +392,7 @@ class Mapper(object): # locate all tables contained within the "table" passed in, which # may be a join or other construct - self.tables = sqlutil.TableFinder(self.mapped_table) + self.tables = sqlutil.find_tables(self.mapped_table) if not self.tables: raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) @@ -576,7 +576,7 @@ class Mapper(object): # table columns mapped to lists of MapperProperty objects # using a list allows a single column to be defined as # populating multiple object attributes - self._columntoproperty = mapperutil.TranslatingDict(self.mapped_table) + self._columntoproperty = {} #mapperutil.TranslatingDict(self.mapped_table) # load custom properties if self._init_properties is not None: @@ -663,7 +663,8 @@ class Mapper(object): self.columns[key] = col for col in prop.columns: - self._columntoproperty[col] = prop + for col in col.proxy_set: + self._columntoproperty[col] = prop self.__props[key] = prop @@ -995,7 +996,7 @@ class Mapper(object): for t in mapper.tables: table_to_mapper.setdefault(t, mapper) - for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False): + for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False): # two lists to store parameters for each table/object pair located insert = [] update = [] @@ -1217,7 +1218,7 @@ class Mapper(object): for t in mapper.tables: table_to_mapper.setdefault(t, mapper) - for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True): + for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True): delete = {} for (obj, connection) in tups: mapper = object_mapper(obj) @@ -1260,16 +1261,13 @@ class Mapper(object): mapper.extension.after_delete(mapper, connection, obj) def _has_pks(self, table): - try: - pk = self.pks_by_table[table] - if not pk: - return False - for k in pk: + if self.pks_by_table.get(table, None): + for k in self.pks_by_table[table]: if k not in self._columntoproperty: return False else: return True - except KeyError: + else: return False def register_dependencies(self, uowcommit, *args, **kwargs): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 77f7fbe04d..1214025849 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -900,12 +900,12 @@ class Query(object): # "inner" select statement where they'll be available to the enclosing # statement's "order by" - cf = sql_util.ColumnFinder() + cf = util.Set() if order_by: order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] for o in order_by: - cf.traverse(o) + cf.update(sql_util.find_columns(o)) s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args()) @@ -934,9 +934,9 @@ class Query(object): # to use it in "order_by". ensure they are in the column criterion (particularly oid). if self._distinct and order_by: order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] - cf = sql_util.ColumnFinder() + cf = util.Set() for o in order_by: - cf.traverse(o) + cf.update(sql_util.find_columns(o)) [statement.append_column(c) for c in cf] @@ -976,7 +976,7 @@ class Query(object): return None elif isinstance(m, sql.ColumnElement): aliases = [] - for table in sql_util.TableFinder(m, check_columns=True): + for table in sql_util.find_tables(m, check_columns=True): for a in self._alias_ids.get(table, []): aliases.append(a) if len(aliases) > 1: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 7911b93c8b..0277218c73 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -495,7 +495,7 @@ class EagerLoader(AbstractRelationLoader): towrap = fromclause break elif isinstance(fromclause, sql.Join): - if localparent.mapped_table in sql_util.TableFinder(fromclause, include_aliases=True): + if localparent.mapped_table in sql_util.find_tables(fromclause, include_aliases=True): towrap = fromclause break else: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index a5e2d1e6e8..7be72dc3c1 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -77,43 +77,6 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'): result.append(sql.select([col(name, table) for name in colnames], from_obj=[table])) return sql.union_all(*result).alias(aliasname) -class TranslatingDict(dict): - """A dictionary that stores ``ColumnElement`` objects as keys. - - Incoming ``ColumnElement`` keys are translated against those of an - underling ``FromClause`` for all operations. This way the columns - from any ``Selectable`` that is derived from or underlying this - ``TranslatingDict`` 's selectable can be used as keys. - """ - - def __init__(self, selectable): - super(TranslatingDict, self).__init__() - self.selectable = selectable - - def __translate_col(self, col): - ourcol = self.selectable.corresponding_column(col, raiseerr=False) - if ourcol is None: - return col - else: - return ourcol - - def __getitem__(self, col): - try: - return super(TranslatingDict, self).__getitem__(col) - except KeyError: - return super(TranslatingDict, self).__getitem__(self.__translate_col(col)) - - def has_key(self, col): - return col in self - - def __setitem__(self, col, value): - return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value) - - def __contains__(self, col): - return super(TranslatingDict, self).__contains__(self.__translate_col(col)) - - def setdefault(self, col, value): - return super(TranslatingDict, self).setdefault(self.__translate_col(col), value) class ExtensionCarrier(object): """stores a collection of MapperExtension objects. diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 21e36fe6f8..0e1e5f7a9c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1143,8 +1143,7 @@ class MetaData(SchemaItem): tables = self.tables.values() else: tables = util.Set(tables).intersection(self.tables.values()) - sorter = sql_util.TableCollection(list(tables)) - return iter(sorter.sort(reverse=reverse)) + return iter(sql_util.sort_tables(tables, reverse=reverse)) def _get_parent(self): return None diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 3e2d4ec311..d3e89d57e6 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -3,105 +3,50 @@ from sqlalchemy.sql import expression, visitors """Utility functions that build upon SQL and Schema constructs.""" -# TODO: replace with plain list. break out sorting funcs into module-level funcs -class TableCollection(object): - def __init__(self, tables=None): - self.tables = tables or [] - - def __len__(self): - return len(self.tables) - - def __getitem__(self, i): - return self.tables[i] - - def __iter__(self): - return iter(self.tables) - - def __contains__(self, obj): - return obj in self.tables - - def __add__(self, obj): - return self.tables + list(obj) - - def add(self, table): - self.tables.append(table) - if hasattr(self, '_sorted'): - del self._sorted - - def sort(self, reverse=False): - try: - sorted = self._sorted - except AttributeError, e: - self._sorted = self._do_sort() - sorted = self._sorted - if reverse: - x = sorted[:] - x.reverse() - return x - else: - return sorted - - def _do_sort(self): - tuples = [] - class TVisitor(schema.SchemaVisitor): - def visit_foreign_key(_self, fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in self: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) - vis = TVisitor() - for table in self.tables: - vis.traverse(table) - sorter = topological.QueueDependencySorter( tuples, self.tables ) - head = sorter.sort() - sequence = [] - def to_sequence( node, seq=sequence): - seq.append( node.item ) - for child in node.children: - to_sequence( child ) - if head is not None: - to_sequence( head ) - return sequence - - -# TODO: replace with plain module-level func -class TableFinder(TableCollection, visitors.NoColumnVisitor): - """locate all Tables within a clause.""" - - def __init__(self, clause, check_columns=False, include_aliases=False): - TableCollection.__init__(self) - self.check_columns = check_columns - self.include_aliases = include_aliases - for clause in util.to_list(clause): - self.traverse(clause) - - def visit_alias(self, alias): - if self.include_aliases: - self.tables.append(alias) - - def visit_table(self, table): - self.tables.append(table) - - def visit_column(self, column): - if self.check_columns: - self.tables.append(column.table) - -class ColumnFinder(visitors.ClauseVisitor): - def __init__(self): - self.columns = util.Set() - - def visit_column(self, c): - self.columns.add(c) - - def __iter__(self): - return iter(self.columns) - -def find_columns(selectable): - cf = ColumnFinder() - cf.traverse(selectable) - return iter(cf) +def sort_tables(tables, reverse=False): + tuples = [] + class TVisitor(schema.SchemaVisitor): + def visit_foreign_key(_self, fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + vis = TVisitor() + for table in tables: + vis.traverse(table) + sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False) + if reverse: + sequence.reverse() + return sequence + +def find_tables(clause, check_columns=False, include_aliases=False): + tables = [] + kwargs = {} + if include_aliases: + def visit_alias(alias): + tables.append(alias) + kwargs['visit_alias'] = visit_alias + + if check_columns: + def visit_column(column): + tables.append(column.table) + kwargs['visit_column'] = visit_column + + def visit_table(table): + tables.append(table) + kwargs['visit_table'] = visit_table + + visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) + return tables + +def find_columns(clause): + cols = util.Set() + def visit_column(col): + cols.add(col) + visitors.traverse(clause, visit_column=visit_column) + return cols class ColumnsInClause(visitors.ClauseVisitor): """Given a selectable, visit clauses and determine if any columns diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index a47968519d..ccded5d47d 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -143,7 +143,7 @@ class QueueDependencySorter(object): self.tuples = tuples self.allitems = allitems - def sort(self, allow_self_cycles=True, allow_all_cycles=False): + def sort(self, allow_self_cycles=True, allow_all_cycles=False, create_tree=True): (tuples, allitems) = (self.tuples, self.allitems) #print "\n---------------------------------\n" #print repr([t for t in tuples]) @@ -152,7 +152,7 @@ class QueueDependencySorter(object): nodes = {} edges = _EdgeCollection() - for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]: + for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]: if id(item) not in nodes: node = _Node(item) nodes[id(item)] = node @@ -205,7 +205,10 @@ class QueueDependencySorter(object): del nodes[id(node.item)] for childnode in edges.pop_node(node): queue.append(childnode) - return self._create_batched_tree(output) + if create_tree: + return self._create_batched_tree(output) + else: + return [n.item for n in output] def _create_batched_tree(self, nodes): -- 2.47.2