]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- decruftify old visitors used by orm, convert to functions that
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Nov 2007 01:59:29 +0000 (01:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Nov 2007 01:59:29 +0000 (01:59 +0000)
use a common traversal function.
- TranslatingDict is finally gone, thanks to column.proxy_set simpleness...hooray !
- shoved "slice" use case on RowProxy into an exception case.  knocks noticeable time off of large result set operations.

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/topological.py

index 9c7e70ba92efb12e3ba7a1e3accfb100feac886e..21977b689b9807017cfc08c2294eaa18016ff380 100644 (file)
@@ -1243,12 +1243,7 @@ class RowProxy(object):
         return self.__parent._has_key(self.__row, key)
 
     def __getitem__(self, key):
-        if isinstance(key, slice):
-            indices = key.indices(len(self))
-            return tuple([self.__parent._get_col(self.__row, i)
-                          for i in xrange(*indices)])
-        else:
-            return self.__parent._get_col(self.__row, key)
+        return self.__parent._get_col(self.__row, key)
 
     def __getattr__(self, name):
         try:
@@ -1474,7 +1469,15 @@ class ResultProxy(object):
         return self.dialect.supports_sane_multi_rowcount
 
     def _get_col(self, row, key):
-        rec = self._key_cache[key]
+        try:
+            rec = self._key_cache[key]
+        except TypeError:
+            # the 'slice' use case is very infrequent,
+            # so we use an exception catch to reduce conditionals in _get_col
+            if isinstance(key, slice):
+                indices = key.indices(len(row))
+                return tuple([self._get_col(row, i) for i in xrange(*indices)])
+
         if rec[1]:
             return rec[1](row[rec[2]])
         else:
@@ -1482,8 +1485,10 @@ class ResultProxy(object):
 
     def _fetchone_impl(self):
         return self.cursor.fetchone()
+        
     def _fetchmany_impl(self, size=None):
         return self.cursor.fetchmany(size)
+        
     def _fetchall_impl(self):
         return self.cursor.fetchall()
 
index 3bacb13e93a04ece039a37397b434f977f86c126..1414336ac68f5931c043836cdb3e013039409558 100644 (file)
@@ -392,7 +392,7 @@ class Mapper(object):
 
         # locate all tables contained within the "table" passed in, which
         # may be a join or other construct
-        self.tables = sqlutil.TableFinder(self.mapped_table)
+        self.tables = sqlutil.find_tables(self.mapped_table)
 
         if not self.tables:
             raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
@@ -576,7 +576,7 @@ class Mapper(object):
         # table columns mapped to lists of MapperProperty objects
         # using a list allows a single column to be defined as
         # populating multiple object attributes
-        self._columntoproperty = mapperutil.TranslatingDict(self.mapped_table)
+        self._columntoproperty = {} #mapperutil.TranslatingDict(self.mapped_table)
 
         # load custom properties
         if self._init_properties is not None:
@@ -663,7 +663,8 @@ class Mapper(object):
 
             self.columns[key] = col
             for col in prop.columns:
-                self._columntoproperty[col] = prop
+                for col in col.proxy_set:
+                    self._columntoproperty[col] = prop
             
         self.__props[key] = prop
 
@@ -995,7 +996,7 @@ class Mapper(object):
             for t in mapper.tables:
                 table_to_mapper.setdefault(t, mapper)
 
-        for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=False):
+        for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False):
             # two lists to store parameters for each table/object pair located
             insert = []
             update = []
@@ -1217,7 +1218,7 @@ class Mapper(object):
             for t in mapper.tables:
                 table_to_mapper.setdefault(t, mapper)
 
-        for table in sqlutil.TableCollection(list(table_to_mapper.keys())).sort(reverse=True):
+        for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
             delete = {}
             for (obj, connection) in tups:
                 mapper = object_mapper(obj)
@@ -1260,16 +1261,13 @@ class Mapper(object):
                     mapper.extension.after_delete(mapper, connection, obj)
 
     def _has_pks(self, table):
-        try:
-            pk = self.pks_by_table[table]
-            if not pk:
-                return False
-            for k in pk:
+        if self.pks_by_table.get(table, None):
+            for k in self.pks_by_table[table]:
                 if k not in self._columntoproperty:
                     return False
             else:
                 return True
-        except KeyError:
+        else:
             return False
 
     def register_dependencies(self, uowcommit, *args, **kwargs):
index 77f7fbe04d71d9c92893019e9c2905135f6d60a0..1214025849e95554ddfab53ddfd01103c6010e43 100644 (file)
@@ -900,12 +900,12 @@ class Query(object):
             # "inner" select statement where they'll be available to the enclosing
             # statement's "order by"
 
-            cf = sql_util.ColumnFinder()
+            cf = util.Set()
 
             if order_by:
                 order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
                 for o in order_by:
-                    cf.traverse(o)
+                    cf.update(sql_util.find_columns(o))
             
             s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
 
@@ -934,9 +934,9 @@ class Query(object):
             # to use it in "order_by".  ensure they are in the column criterion (particularly oid).
             if self._distinct and order_by:
                 order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
-                cf = sql_util.ColumnFinder()
+                cf = util.Set()
                 for o in order_by:
-                    cf.traverse(o)
+                    cf.update(sql_util.find_columns(o))
 
                 [statement.append_column(c) for c in cf]
 
@@ -976,7 +976,7 @@ class Query(object):
                 return None
         elif isinstance(m, sql.ColumnElement):
             aliases = []
-            for table in sql_util.TableFinder(m, check_columns=True):
+            for table in sql_util.find_tables(m, check_columns=True):
                 for a in self._alias_ids.get(table, []):
                     aliases.append(a)
             if len(aliases) > 1:
index 7911b93c8ba39e4870d00527a82535fc14ef7a08..0277218c730ded0928ab592f1efed0d938bf18ae 100644 (file)
@@ -495,7 +495,7 @@ class EagerLoader(AbstractRelationLoader):
                     towrap = fromclause
                     break
                 elif isinstance(fromclause, sql.Join):
-                    if localparent.mapped_table in sql_util.TableFinder(fromclause, include_aliases=True):
+                    if localparent.mapped_table in sql_util.find_tables(fromclause, include_aliases=True):
                         towrap = fromclause
                         break
             else:
index a5e2d1e6e880658e5e9cc8f7d6fe93a1c30d83e1..7be72dc3c18813352101fbe2cd0714024fe67044 100644 (file)
@@ -77,43 +77,6 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
             result.append(sql.select([col(name, table) for name in colnames], from_obj=[table]))
     return sql.union_all(*result).alias(aliasname)
 
-class TranslatingDict(dict):
-    """A dictionary that stores ``ColumnElement`` objects as keys.
-
-    Incoming ``ColumnElement`` keys are translated against those of an
-    underling ``FromClause`` for all operations.  This way the columns
-    from any ``Selectable`` that is derived from or underlying this
-    ``TranslatingDict`` 's selectable can be used as keys.
-    """
-
-    def __init__(self, selectable):
-        super(TranslatingDict, self).__init__()
-        self.selectable = selectable
-
-    def __translate_col(self, col):
-        ourcol = self.selectable.corresponding_column(col, raiseerr=False)
-        if ourcol is None:
-            return col
-        else:
-            return ourcol
-
-    def __getitem__(self, col):
-        try:
-            return super(TranslatingDict, self).__getitem__(col)
-        except KeyError:
-            return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
-
-    def has_key(self, col):
-        return col in self
-
-    def __setitem__(self, col, value):
-        return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value)
-
-    def __contains__(self, col):
-        return super(TranslatingDict, self).__contains__(self.__translate_col(col))
-
-    def setdefault(self, col, value):
-        return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
 
 class ExtensionCarrier(object):
     """stores a collection of MapperExtension objects.
index 21e36fe6f832acc70f3993487349d17a4b72d163..0e1e5f7a9caf8e7df48197463e7c8c45f049ddf2 100644 (file)
@@ -1143,8 +1143,7 @@ class MetaData(SchemaItem):
             tables = self.tables.values()
         else:
             tables = util.Set(tables).intersection(self.tables.values())
-        sorter = sql_util.TableCollection(list(tables))
-        return iter(sorter.sort(reverse=reverse))
+        return iter(sql_util.sort_tables(tables, reverse=reverse))
 
     def _get_parent(self):
         return None
index 3e2d4ec311c30f299664abe6820ff29c7dae029c..d3e89d57e65fc7b80144d674151e68586571283b 100644 (file)
@@ -3,105 +3,50 @@ from sqlalchemy.sql import expression, visitors
 
 """Utility functions that build upon SQL and Schema constructs."""
 
-# TODO: replace with plain list.  break out sorting funcs into module-level funcs
-class TableCollection(object):
-    def __init__(self, tables=None):
-        self.tables = tables or []
-
-    def __len__(self):
-        return len(self.tables)
-
-    def __getitem__(self, i):
-        return self.tables[i]
-
-    def __iter__(self):
-        return iter(self.tables)
-
-    def __contains__(self, obj):
-        return obj in self.tables
-
-    def __add__(self, obj):
-        return self.tables + list(obj)
-
-    def add(self, table):
-        self.tables.append(table)
-        if hasattr(self, '_sorted'):
-            del self._sorted
-
-    def sort(self, reverse=False):
-        try:
-            sorted = self._sorted
-        except AttributeError, e:
-            self._sorted = self._do_sort()
-            sorted = self._sorted
-        if reverse:
-            x = sorted[:]
-            x.reverse()
-            return x
-        else:
-            return sorted
-
-    def _do_sort(self):
-        tuples = []
-        class TVisitor(schema.SchemaVisitor):
-            def visit_foreign_key(_self, fkey):
-                if fkey.use_alter:
-                    return
-                parent_table = fkey.column.table
-                if parent_table in self:
-                    child_table = fkey.parent.table
-                    tuples.append( ( parent_table, child_table ) )
-        vis = TVisitor()
-        for table in self.tables:
-            vis.traverse(table)
-        sorter = topological.QueueDependencySorter( tuples, self.tables )
-        head =  sorter.sort()
-        sequence = []
-        def to_sequence( node, seq=sequence):
-            seq.append( node.item )
-            for child in node.children:
-                to_sequence( child )
-        if head is not None:
-            to_sequence( head )
-        return sequence
-
-
-# TODO: replace with plain module-level func
-class TableFinder(TableCollection, visitors.NoColumnVisitor):
-    """locate all Tables within a clause."""
-
-    def __init__(self, clause, check_columns=False, include_aliases=False):
-        TableCollection.__init__(self)
-        self.check_columns = check_columns
-        self.include_aliases = include_aliases
-        for clause in util.to_list(clause):
-            self.traverse(clause)
-
-    def visit_alias(self, alias):
-        if self.include_aliases:
-            self.tables.append(alias)
-            
-    def visit_table(self, table):
-        self.tables.append(table)
-
-    def visit_column(self, column):
-        if self.check_columns:
-            self.tables.append(column.table)
-
-class ColumnFinder(visitors.ClauseVisitor):
-    def __init__(self):
-        self.columns = util.Set()
-
-    def visit_column(self, c):
-        self.columns.add(c)
-
-    def __iter__(self):
-        return iter(self.columns)
-
-def find_columns(selectable):
-    cf = ColumnFinder()
-    cf.traverse(selectable)
-    return iter(cf)
+def sort_tables(tables, reverse=False):
+    tuples = []
+    class TVisitor(schema.SchemaVisitor):
+        def visit_foreign_key(_self, fkey):
+            if fkey.use_alter:
+                return
+            parent_table = fkey.column.table
+            if parent_table in tables:
+                child_table = fkey.parent.table
+                tuples.append( ( parent_table, child_table ) )
+    vis = TVisitor()
+    for table in tables:
+        vis.traverse(table)
+    sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False)
+    if reverse:
+        sequence.reverse()
+    return sequence
+
+def find_tables(clause, check_columns=False, include_aliases=False):
+    tables = []
+    kwargs = {}
+    if include_aliases:
+        def visit_alias(alias):
+            tables.append(alias)
+        kwargs['visit_alias']  = visit_alias
+    
+    if check_columns:
+        def visit_column(column):
+            tables.append(column.table)
+        kwargs['visit_column'] = visit_column
+    
+    def visit_table(table):
+        tables.append(table)
+    kwargs['visit_table'] = visit_table
+    
+    visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+    return tables
+
+def find_columns(clause):
+    cols = util.Set()
+    def visit_column(col):
+        cols.add(col)
+    visitors.traverse(clause, visit_column=visit_column)
+    return cols
     
 class ColumnsInClause(visitors.ClauseVisitor):
     """Given a selectable, visit clauses and determine if any columns
index a47968519df7ff33ab56cad1b76805eedf1218ae..ccded5d47d57e457eb9da65aabba4166ed0b05f9 100644 (file)
@@ -143,7 +143,7 @@ class QueueDependencySorter(object):
         self.tuples = tuples
         self.allitems = allitems
 
-    def sort(self, allow_self_cycles=True, allow_all_cycles=False):
+    def sort(self, allow_self_cycles=True, allow_all_cycles=False, create_tree=True):
         (tuples, allitems) = (self.tuples, self.allitems)
         #print "\n---------------------------------\n"
         #print repr([t for t in tuples])
@@ -152,7 +152,7 @@ class QueueDependencySorter(object):
 
         nodes = {}
         edges = _EdgeCollection()
-        for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
+        for item in list(allitems) + [t[0] for t in tuples] + [t[1] for t in tuples]:
             if id(item) not in nodes:
                 node = _Node(item)
                 nodes[id(item)] = node
@@ -205,7 +205,10 @@ class QueueDependencySorter(object):
             del nodes[id(node.item)]
             for childnode in edges.pop_node(node):
                 queue.append(childnode)
-        return self._create_batched_tree(output)
+        if create_tree:
+            return self._create_batched_tree(output)
+        else:
+            return [n.item for n in output]
 
 
     def _create_batched_tree(self, nodes):