use a common traversal function.
- TranslatingDict is finally gone, thanks to column.proxy_set simpleness...hooray !
- shoved "slice" use case on RowProxy into an exception case. knocks noticeable time off of large result set operations.
return self.__parent._has_key(self.__row, key)
def __getitem__(self, key):
- if isinstance(key, slice):
- indices = key.indices(len(self))
- return tuple([self.__parent._get_col(self.__row, i)
- for i in xrange(*indices)])
- else:
- return self.__parent._get_col(self.__row, key)
+ return self.__parent._get_col(self.__row, key)
def __getattr__(self, name):
try:
return self.dialect.supports_sane_multi_rowcount
def _get_col(self, row, key):
- rec = self._key_cache[key]
+ try:
+ rec = self._key_cache[key]
+ except TypeError:
+ # the 'slice' use case is very infrequent,
+ # so we use an exception catch to reduce conditionals in _get_col
+ if isinstance(key, slice):
+ indices = key.indices(len(row))
+ return tuple([self._get_col(row, i) for i in xrange(*indices)])
+
if rec[1]:
return rec[1](row[rec[2]])
else:
def _fetchone_impl(self):
return self.cursor.fetchone()
+
def _fetchmany_impl(self, size=None):
return self.cursor.fetchmany(size)
+
def _fetchall_impl(self):
return self.cursor.fetchall()
# locate all tables contained within the "table" passed in, which
# may be a join or other construct
- self.tables = sqlutil.TableFinder(self.mapped_table)
+ self.tables = sqlutil.find_tables(self.mapped_table)
if not self.tables:
raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
# table columns mapped to lists of MapperProperty objects
# using a list allows a single column to be defined as
# populating multiple object attributes
- self._columntoproperty = mapperutil.TranslatingDict(self.mapped_table)
+ self._columntoproperty = {} #mapperutil.TranslatingDict(self.mapped_table)
# load custom properties
if self._init_properties is not None:
self.columns[key] = col
for col in prop.columns:
- self._columntoproperty[col] = prop
+ for col in col.proxy_set:
+ self._columntoproperty[col] = prop
self.__props[key] = prop
for t in mapper.tables:
table_to_mapper.setdefault(t, mapper)
- for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
+ for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False):
# two lists to store parameters for each table/object pair located
insert = []
update = []
for t in mapper.tables:
table_to_mapper.setdefault(t, mapper)
- for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True):
+ for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
delete = {}
for (obj, connection) in tups:
mapper = object_mapper(obj)
mapper.extension.after_delete(mapper, connection, obj)
def _has_pks(self, table):
- try:
- pk = self.pks_by_table[table]
- if not pk:
- return False
- for k in pk:
+ if self.pks_by_table.get(table, None):
+ for k in self.pks_by_table[table]:
if k not in self._columntoproperty:
return False
else:
return True
- except KeyError:
+ else:
return False
def register_dependencies(self, uowcommit, *args, **kwargs):
# "inner" select statement where they'll be available to the enclosing
# statement's "order by"
- cf = sql_util.ColumnFinder()
+ cf = util.Set()
if order_by:
order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
for o in order_by:
- cf.traverse(o)
+ cf.update(sql_util.find_columns(o))
s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
# to use it in "order_by". ensure they are in the column criterion (particularly oid).
if self._distinct and order_by:
order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
- cf = sql_util.ColumnFinder()
+ cf = util.Set()
for o in order_by:
- cf.traverse(o)
+ cf.update(sql_util.find_columns(o))
[statement.append_column(c) for c in cf]
return None
elif isinstance(m, sql.ColumnElement):
aliases = []
- for table in sql_util.TableFinder(m, check_columns=True):
+ for table in sql_util.find_tables(m, check_columns=True):
for a in self._alias_ids.get(table, []):
aliases.append(a)
if len(aliases) > 1:
towrap = fromclause
break
elif isinstance(fromclause, sql.Join):
- if localparent.mapped_table in sql_util.TableFinder(fromclause, include_aliases=True):
+ if localparent.mapped_table in sql_util.find_tables(fromclause, include_aliases=True):
towrap = fromclause
break
else:
result.append(sql.select([col(name, table) for name in colnames], from_obj=[table]))
return sql.union_all(*result).alias(aliasname)
-class TranslatingDict(dict):
- """A dictionary that stores ``ColumnElement`` objects as keys.
-
- Incoming ``ColumnElement`` keys are translated against those of an
- underling ``FromClause`` for all operations. This way the columns
- from any ``Selectable`` that is derived from or underlying this
- ``TranslatingDict`` 's selectable can be used as keys.
- """
-
- def __init__(self, selectable):
- super(TranslatingDict, self).__init__()
- self.selectable = selectable
-
- def __translate_col(self, col):
- ourcol = self.selectable.corresponding_column(col, raiseerr=False)
- if ourcol is None:
- return col
- else:
- return ourcol
-
- def __getitem__(self, col):
- try:
- return super(TranslatingDict, self).__getitem__(col)
- except KeyError:
- return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
-
- def has_key(self, col):
- return col in self
-
- def __setitem__(self, col, value):
- return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value)
-
- def __contains__(self, col):
- return super(TranslatingDict, self).__contains__(self.__translate_col(col))
-
- def setdefault(self, col, value):
- return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
class ExtensionCarrier(object):
"""stores a collection of MapperExtension objects.
tables = self.tables.values()
else:
tables = util.Set(tables).intersection(self.tables.values())
- sorter = sql_util.TableCollection(list(tables))
- return iter(sorter.sort(reverse=reverse))
+ return iter(sql_util.sort_tables(tables, reverse=reverse))
def _get_parent(self):
return None
"""Utility functions that build upon SQL and Schema constructs."""
-# TODO: replace with plain list. break out sorting funcs into module-level funcs
-class TableCollection(object):
- def __init__(self, tables=None):
- self.tables = tables or []
-
- def __len__(self):
- return len(self.tables)
-
- def __getitem__(self, i):
- return self.tables[i]
-
- def __iter__(self):
- return iter(self.tables)
-
- def __contains__(self, obj):
- return obj in self.tables
-
- def __add__(self, obj):
- return self.tables + list(obj)
-
- def add(self, table):
- self.tables.append(table)
- if hasattr(self, '_sorted'):
- del self._sorted
-
- def sort(self, reverse=False):
- try:
- sorted = self._sorted
- except AttributeError, e:
- self._sorted = self._do_sort()
- sorted = self._sorted
- if reverse:
- x = sorted[:]
- x.reverse()
- return x
- else:
- return sorted
-
- def _do_sort(self):
- tuples = []
- class TVisitor(schema.SchemaVisitor):
- def visit_foreign_key(_self, fkey):
- if fkey.use_alter:
- return
- parent_table = fkey.column.table
- if parent_table in self:
- child_table = fkey.parent.table
- tuples.append( ( parent_table, child_table ) )
- vis = TVisitor()
- for table in self.tables:
- vis.traverse(table)
- sorter = topological.QueueDependencySorter( tuples, self.tables )
- head = sorter.sort()
- sequence = []
- def to_sequence( node, seq=sequence):
- seq.append( node.item )
- for child in node.children:
- to_sequence( child )
- if head is not None:
- to_sequence( head )
- return sequence
-
-
-# TODO: replace with plain module-level func
-class TableFinder(TableCollection, visitors.NoColumnVisitor):
- """locate all Tables within a clause."""
-
- def __init__(self, clause, check_columns=False, include_aliases=False):
- TableCollection.__init__(self)
- self.check_columns = check_columns
- self.include_aliases = include_aliases
- for clause in util.to_list(clause):
- self.traverse(clause)
-
- def visit_alias(self, alias):
- if self.include_aliases:
- self.tables.append(alias)
-
- def visit_table(self, table):
- self.tables.append(table)
-
- def visit_column(self, column):
- if self.check_columns:
- self.tables.append(column.table)
-
-class ColumnFinder(visitors.ClauseVisitor):
- def __init__(self):
- self.columns = util.Set()
-
- def visit_column(self, c):
- self.columns.add(c)
-
- def __iter__(self):
- return iter(self.columns)
-
-def find_columns(selectable):
- cf = ColumnFinder()
- cf.traverse(selectable)
- return iter(cf)
+def sort_tables(tables, reverse=False):
+ tuples = []
+ class TVisitor(schema.SchemaVisitor):
+ def visit_foreign_key(_self, fkey):
+ if fkey.use_alter:
+ return
+ parent_table = fkey.column.table
+ if parent_table in tables:
+ child_table = fkey.parent.table
+ tuples.append( ( parent_table, child_table ) )
+ vis = TVisitor()
+ for table in tables:
+ vis.traverse(table)
+ sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False)
+ if reverse:
+ sequence.reverse()
+ return sequence
+
+def find_tables(clause, check_columns=False, include_aliases=False):
+ tables = []
+ kwargs = {}
+ if include_aliases:
+ def visit_alias(alias):
+ tables.append(alias)
+ kwargs['visit_alias'] = visit_alias
+
+ if check_columns:
+ def visit_column(column):
+ tables.append(column.table)
+ kwargs['visit_column'] = visit_column
+
+ def visit_table(table):
+ tables.append(table)
+ kwargs['visit_table'] = visit_table
+
+ visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+ return tables
+
+def find_columns(clause):
+ cols = util.Set()
+ def visit_column(col):
+ cols.add(col)
+ visitors.traverse(clause, visit_column=visit_column)
+ return cols
class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns
self.tuples = tuples
self.allitems = allitems
- def sort(self, allow_self_cycles=True, allow_all_cycles=False):
+ def sort(self, allow_self_cycles=True, allow_all_cycles=False, create_tree=True):
(tuples, allitems) = (self.tuples, self.allitems)
#print "\n---------------------------------\n"
#print repr([t for t in tuples])
nodes = {}
edges = _EdgeCollection()
- for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
+ for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
if id(item) not in nodes:
node = _Node(item)
nodes[id(item)] = node
del nodes[id(node.item)]
for childnode in edges.pop_node(node):
queue.append(childnode)
- return self._create_batched_tree(output)
+ if create_tree:
+ return self._create_batched_tree(output)
+ else:
+ return [n.item for n in output]
def _create_batched_tree(self, nodes):