]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- all kinds of cleanup, tiny-to-slightly-significant speed improvements
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Nov 2007 00:55:39 +0000 (00:55 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Nov 2007 00:55:39 +0000 (00:55 +0000)
19 files changed:
lib/sqlalchemy/databases/maxdb.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/profiling/compiler.py
test/profiling/zoomark.py

index 09cc9f0b68e91322bf0cb3ee172aa37002ee7809..3ca1bd61cc218d93c6574cd41fd7ca9fc4921748 100644 (file)
@@ -480,6 +480,21 @@ class MaxDBDialect(default.DefaultDialect):
         super(MaxDBDialect, self).__init__(**kw)
         self._raise_known = _raise_known_sql_errors
 
+        if self.dbapi is None:
+            self.dbapi_type_map = {}
+        else:
+            self.dbapi_type_map = {
+                'Long Binary': MaxBlob(),
+                'Long byte_t': MaxBlob(),
+                'Long Unicode': MaxText(),
+                'Timestamp': MaxTimestamp(),
+                'Date': MaxDate(),
+                'Time': MaxTime(),
+                datetime.datetime: MaxTimestamp(),
+                datetime.date: MaxDate(),
+                datetime.time: MaxTime(),
+            }
+
     def dbapi(cls):
         from sapdb import dbapi as _dbapi
         return _dbapi
@@ -498,22 +513,6 @@ class MaxDBDialect(default.DefaultDialect):
         else:
             return sqltypes.adapt_type(typeobj, colspecs)
 
-    def dbapi_type_map(self):
-        if self.dbapi is None:
-            return {}
-        else:
-            return {
-                'Long Binary': MaxBlob(),
-                'Long byte_t': MaxBlob(),
-                'Long Unicode': MaxText(),
-                'Timestamp': MaxTimestamp(),
-                'Date': MaxDate(),
-                'Time': MaxTime(),
-                datetime.datetime: MaxTimestamp(),
-                datetime.date: MaxDate(),
-                datetime.time: MaxTime(),
-            }
-
     def create_execution_context(self, connection, **kw):
         return MaxDBExecutionContext(self, connection, **kw)
 
index 71c82bfd86bd744efb0b45558e6cd77fd4846419..9c9c54bf667a3e441c7c83f03f3f4af1370f67f8 100644 (file)
@@ -251,24 +251,19 @@ class OracleDialect(default.DefaultDialect):
         self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
         self.auto_setinputsizes = auto_setinputsizes
         self.auto_convert_lobs = auto_convert_lobs
-        
-        if self.dbapi is not None:
-            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
-        else:
+        if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
+            self.dbapi_type_map = {}
             self.ORACLE_BINARY_TYPES = []
-
-    def dbapi_type_map(self):
-        if self.dbapi is None or not self.auto_convert_lobs:
-            return {}
         else:
             # only use this for LOB objects.  using it for strings, dates
             # etc. leads to a little too much magic, reflection doesn't know if it should
             # expect encoded strings or unicodes, etc.
-            return {
+            self.dbapi_type_map = {
                 self.dbapi.CLOB: OracleText(), 
                 self.dbapi.BLOB: OracleBinary(), 
                 self.dbapi.BINARY: OracleRaw(), 
             }
+            self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
 
     def dbapi(cls):
         import cx_Oracle
index 022f547463c5bef08d2a73b924a2a8051d3e938e..87045d192612d4fbf087b875df39ad150a3addec 100644 (file)
@@ -639,8 +639,7 @@ class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
     def __init__(self, **params):
         super(SybaseSQLDialect_mxodbc, self).__init__(**params)
 
-    def dbapi_type_map(self):
-        return {'getdate' : SybaseDate_mxodbc()}
+        self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()}
         
     def import_dbapi(cls):
         #import mx.ODBC.Windows as module
@@ -686,9 +685,7 @@ class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
 class SybaseSQLDialect_pyodbc(SybaseSQLDialect):
     def __init__(self, **params):
         super(SybaseSQLDialect_pyodbc, self).__init__(**params)
-
-    def dbapi_type_map(self):
-        return {'getdate' : SybaseDate_pyodbc()}
+        self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()}
 
     def import_dbapi(cls):
         import mypyodbc as module
index 6af0ce0d39a8003e57c661ef36096a2649e0f11f..9c7e70ba92efb12e3ba7a1e3accfb100feac886e 100644 (file)
@@ -87,6 +87,17 @@ class Dialect(object):
     supports_pk_autoincrement
       Indicates if the dialect should allow the database to passively assign
       a primary key column value.
+
+    dbapi_type_map
+      A mapping of DB-API type objects present in this Dialect's
+      DB-API implementation mapped to TypeEngine implementations used
+      by the dialect.
+
+      This is used to apply types to result sets based on the DB-API
+      types present in cursor.description; it only takes effect for
+      result sets against textual statements where no explicit
+      typemap was present.  
+      
     """
 
     def create_connect_args(self, url):
@@ -99,21 +110,6 @@ class Dialect(object):
 
         raise NotImplementedError()
 
-    def dbapi_type_map(self):
-        """Returns a DB-API to sqlalchemy.types mapping.
-
-        A mapping of DB-API type objects present in this Dialect's
-        DB-API implmentation mapped to TypeEngine implementations used
-        by the dialect.
-
-        This is used to apply types to result sets based on the DB-API
-        types present in cursor.description; it only takes effect for
-        result sets against textual statements where no explicit
-        typemap was present.  Constructed SQL statements always have
-        type information explicitly embedded.
-        """
-
-        raise NotImplementedError()
 
     def type_descriptor(self, typeobj):
         """Transform a generic type to a database-specific type.
@@ -1339,7 +1335,7 @@ class ResultProxy(object):
         metadata = self.cursor.description
 
         if metadata is not None:
-            typemap = self.dialect.dbapi_type_map()
+            typemap = self.dialect.dbapi_type_map
 
             for i, item in enumerate(metadata):
                 # sqlite possibly prepending table name to colnames so strip
index 198f6742bf9702189ad00d233258b5a5d488f3f7..a91d65b81f301c8aede1de80bcd06da870a3cd4d 100644 (file)
@@ -33,7 +33,8 @@ class DefaultDialect(base.Dialect):
     supports_sane_multi_rowcount = True
     preexecute_pk_sequences = False
     supports_pk_autoincrement = True
-
+    dbapi_type_map = {}
+    
     def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
         self.convert_unicode = convert_unicode
         self.encoding = encoding
@@ -59,12 +60,6 @@ class DefaultDialect(base.Dialect):
                     property(lambda s: s.preexecute_sequences, doc=(
                       "Proxy to deprecated preexecute_sequences attribute.")))
 
-    def dbapi_type_map(self):
-        # most DB-APIs have problems with this (such as, psycocpg2 types 
-        # are unhashable).  So far Oracle can return it.
-        
-        return {}
-    
     def create_execution_context(self, connection, **kwargs):
         return DefaultExecutionContext(self, connection, **kwargs)
 
index 27f4b017c30d16128034d7a0fa2d67e82c9c01ee..123a99c9a82ff907333c29bb113e2edc7d5d8511 100644 (file)
@@ -686,7 +686,7 @@ class InstanceState(object):
         self._strong_obj = None
         
 
-class InstanceDict(UserDict.UserDict):
+class WeakInstanceDict(UserDict.UserDict):
     """similar to WeakValueDictionary, but wired towards 'state' objects."""
     
     def __init__(self, *args, **kw):
@@ -802,7 +802,12 @@ class InstanceDict(UserDict.UserDict):
     def copy(self):
         raise NotImplementedError()
 
-    
+    def all_states(self):
+        return self.data.values()
+
+class StrongInstanceDict(dict):
+    def all_states(self):
+        return [o._state for o in self.values()]
     
 class AttributeHistory(object):
     """Calculate the *history* of a particular attribute on a
index 42fe8f56fb0666cf7b679e23761dea44d94130c2..3bacb13e93a04ece039a37397b434f977f86c126 100644 (file)
@@ -6,7 +6,7 @@
 
 import weakref, warnings, operator
 from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression
+from sqlalchemy.sql import expression, visitors
 from sqlalchemy.sql import util as sqlutil
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
@@ -513,11 +513,9 @@ class Mapper(object):
                     result[binary.right].add(binary.left)
                 else:
                     result[binary.right] = util.Set([binary.left])
-        vis = mapperutil.BinaryVisitor(visit_binary)
-
         for mapper in self.base_mapper.polymorphic_iterator():
             if mapper.inherit_condition is not None:
-                vis.traverse(mapper.inherit_condition)
+                visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary)
 
         # TODO: matching of cols to foreign keys might better be generalized
         # into general column translation (i.e. corresponding_column)
@@ -1472,11 +1470,10 @@ class Mapper(object):
         allconds = []
         param_names = []
 
-        visitor = mapperutil.BinaryVisitor(visit_binary)
         for mapper in self.iterate_to_root():
             if mapper is base_mapper:
                 break
-            allconds.append(visitor.traverse(mapper.inherit_condition, clone=True))
+            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
         
         return sql.and_(*allconds), param_names
 
index 46224aac658f5167c3c286e9c7ded824acd64742..9e7815e38e1468a207eecc706ece0268321fdd5b 100644 (file)
@@ -12,7 +12,7 @@ to handle flush-time dependency sorting and processing.
 """
 
 from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import util as sql_util, visitors
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
@@ -439,9 +439,9 @@ class PropertyLoader(StrategizedProperty):
                     self._opposite_side.add(binary.right)
                 if binary.right in self.foreign_keys:
                     self._opposite_side.add(binary.left)
-            mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin)
+            visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
             if self.secondaryjoin is not None:
-                mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin)
+                visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
         else:
             self.foreign_keys = util.Set()
             self._opposite_side = util.Set()
@@ -463,7 +463,7 @@ class PropertyLoader(StrategizedProperty):
                     if f.references(binary.left.table):
                         self.foreign_keys.add(binary.right)
                         self._opposite_side.add(binary.left)
-            mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin)
+            visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
 
             if len(self.foreign_keys) == 0:
                 raise exceptions.ArgumentError(
@@ -472,7 +472,7 @@ class PropertyLoader(StrategizedProperty):
                     "'foreign_keys' argument to indicate which columns in "
                     "the join condition are foreign." %(str(self.primaryjoin), str(self)))
             if self.secondaryjoin is not None:
-                mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin)
+                visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
 
 
     def _determine_direction(self):
@@ -543,14 +543,13 @@ class PropertyLoader(StrategizedProperty):
         # in the "polymorphic" selectables.  these are used to construct joins for both Query as well as
         # eager loading, and also are used to calculate "lazy loading" clauses.
 
-        # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
-        # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
-        # several "equivalent" columns (such as parent/child fk cols) into just one column.
-
-        target_equivalents = self.mapper._get_equivalent_columns()
-            
-        # if the target mapper loads polymorphically, adapt the clauses to the target's selectable
         if self.loads_polymorphic:
+
+            # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
+            # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
+            # several "equivalent" columns (such as parent/child fk cols) into just one column.
+            target_equivalents = self.mapper._get_equivalent_columns()
+
             if self.secondaryjoin:
                 self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
                 self.polymorphic_primaryjoin = self.primaryjoin
@@ -560,6 +559,7 @@ class PropertyLoader(StrategizedProperty):
                 elif self.direction is sync.MANYTOONE:
                     self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
                 self.polymorphic_secondaryjoin = None
+
             # load "polymorphic" versions of the columns present in "remote_side" - this is
             # important for lazy-clause generation which goes off the polymorphic target selectable
             for c in list(self.remote_side):
index 47a7022695db113b9fa666d805e909578707327a..77f7fbe04d71d9c92893019e9c2905135f6d60a0 100644 (file)
@@ -907,11 +907,8 @@ class Query(object):
                 for o in order_by:
                     cf.traverse(o)
             
-            s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, **self._select_args())
+            s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
 
-            if order_by:
-                s2.append_order_by(*util.to_list(order_by))
-            
             s3 = s2.alias()
                 
             self._primary_adapter = mapperutil.create_row_adapter(s3, self.table)
@@ -926,13 +923,10 @@ class Query(object):
 
             statement.append_order_by(*context.eager_order_by)
         else:
-            statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
+            statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args())
             if context.eager_joins:
                 statement.append_from(context.eager_joins, _copy_collection=False)
 
-            if order_by:
-                statement.append_order_by(*util.to_list(order_by))
-
             if context.eager_order_by:
                 statement.append_order_by(*context.eager_order_by)
                 
index 2caee2dd4a1b5b1f38ab83828790af7028739595..7911b93c8ba39e4870d00527a82535fc14ef7a08 100644 (file)
@@ -407,21 +407,11 @@ class LazyLoader(AbstractRelationLoader):
             else:
                 return othercol in remote_side
 
-        def find_column_in_expr(expr):
-            if not isinstance(expr, sql.ColumnElement):
-                return None
-            columns = []
-            class FindColumnInColumnClause(visitors.ClauseVisitor):
-                def visit_column(self, c):
-                    columns.append(c)
-            FindColumnInColumnClause().traverse(expr)
-            return len(columns) and columns[0] or None
-        
         def visit_binary(binary):
-            leftcol = find_column_in_expr(binary.left)
-            rightcol = find_column_in_expr(binary.right)
-            if leftcol is None or rightcol is None:
+            if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
                 return
+            leftcol = binary.left
+            rightcol = binary.right
             
             if should_bind(leftcol, rightcol):
                 col = leftcol
@@ -438,14 +428,13 @@ class LazyLoader(AbstractRelationLoader):
                 reverse[leftcol] = binds[col]
 
         lazywhere = primaryjoin
-        li = mapperutil.BinaryVisitor(visit_binary)
         
         if not secondaryjoin or not reverse_direction:
-            lazywhere = li.traverse(lazywhere, clone=True)
+            lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary)
         
         if secondaryjoin is not None:
             if reverse_direction:
-                secondaryjoin = li.traverse(secondaryjoin, clone=True)
+                secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
         return (lazywhere, binds, reverse)
     _create_lazy_clause = classmethod(_create_lazy_clause)
index 5b5a9e43b1285839a99003b82d20409666181055..9575aa958feb476dcabc6e5715bfe9e5bc222086 100644 (file)
@@ -79,7 +79,7 @@ class ClauseSynchronizer(object):
                         self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
 
         rules_added = len(self.syncrules)
-        BinaryVisitor(compile_binary).traverse(sqlclause)
+        visitors.traverse(sqlclause, visit_binary=compile_binary)
         if len(self.syncrules) == rules_added:
             raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
 
@@ -144,9 +144,3 @@ class SyncRule(object):
 
 SyncRule.logger = logging.class_logger(SyncRule)
 
-class BinaryVisitor(visitors.ClauseVisitor):
-    def __init__(self, func):
-        self.func = func
-
-    def visit_binary(self, binary):
-        self.func(binary)
index 2cd7cb6f5dd958c9c46f74258a19537b2919b66a..cdffad266b3a6d13129e56ec26411e8efc78ef8b 100644 (file)
@@ -89,9 +89,9 @@ class UnitOfWork(object):
 
     def __init__(self, session):
         if session.weak_identity_map:
-            self.identity_map = attributes.InstanceDict()
+            self.identity_map = attributes.WeakInstanceDict()
         else:
-            self.identity_map = {}
+            self.identity_map = attributes.StrongInstanceDict()
 
         self.new = util.IdentitySet() #OrderedSet()
         self.deleted = util.IdentitySet()
@@ -158,6 +158,7 @@ class UnitOfWork(object):
             )
             ])
 
+        
     def flush(self, session, objects=None):
         """create a dependency tree of all pending SQL operations within this unit of work and execute."""
         
@@ -166,12 +167,16 @@ class UnitOfWork(object):
         # communication with the mappers and relationships to fire off SQL
         # and synchronize attributes between related objects.
 
-        # detect persistent objects that have changes
-        dirty = self.locate_dirty()
-
+        dirty = [x for x in self.identity_map.all_states()
+            if x.modified
+            or (getattr(x.class_, '_sa_has_mutable_scalars', False) and attribute_manager._is_modified(x))
+        ]
+        
         if len(dirty) == 0 and len(self.deleted) == 0 and len(self.new) == 0:
             return
-
+            
+        dirty = util.IdentitySet([x.obj() for x in dirty]).difference(self.deleted)
+        
         flush_context = UOWTransaction(self, session)
 
         if session.extension is not None:
@@ -232,7 +237,7 @@ class UnitOfWork(object):
         the number of objects pruned.
         """
 
-        if isinstance(self.identity_map, attributes.InstanceDict):
+        if isinstance(self.identity_map, attributes.WeakInstanceDict):
             return 0
         ref_count = len(self.identity_map)
         dirty = self.locate_dirty()
index 55d4db98fee66b1767ed9094111c42a5b4f189d0..a5e2d1e6e880658e5e9cc8f7d6fe93a1c30d83e1 100644 (file)
@@ -98,7 +98,10 @@ class TranslatingDict(dict):
             return ourcol
 
     def __getitem__(self, col):
-        return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
+        try:
+            return super(TranslatingDict, self).__getitem__(col)
+        except KeyError:
+            return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
 
     def has_key(self, col):
         return col in self
@@ -172,13 +175,6 @@ class ExtensionCarrier(object):
     def __getattr__(self, key):
         return self.methods.get(key, self._pass)
 
-class BinaryVisitor(visitors.ClauseVisitor):
-    def __init__(self, func):
-        self.func = func
-
-    def visit_binary(self, binary):
-        self.func(binary)
-
 class AliasedClauses(object):
     """Creates aliases of a mapped tables for usage in ORM queries.
     """
index dd6f0ddddeb819c4cf2038fb78e69a1fea16023c..c1f3bc2a05daa18ea339a48c491937d69f875d41 100644 (file)
@@ -245,19 +245,25 @@ class DefaultCompiler(engine.Compiled):
             n = self.dialect.oid_column_name(column)
             if n is not None:
                 if column.table is None or not column.table.named_with_column():
-                    return self.preparer.format_column(column, name=n)
+                    return n
                 else:
-                    return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n)
+                    return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
             elif len(column.table.primary_key) != 0:
                 pk = list(column.table.primary_key)[0]
                 pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
-                return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
+                return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
             else:
                 return None
         elif column.table is None or not column.table.named_with_column():
-            return self.preparer.format_column(column, name=name)
+            if getattr(column, "is_literal", False):
+                return name
+            else:
+                return self.preparer.quote(column, name)
         else:
-            return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
+            if getattr(column, "is_literal", False):
+                return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+            else:
+                return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(column, name)
 
 
     def visit_fromclause(self, fromclause, **kwargs):
@@ -588,7 +594,10 @@ class DefaultCompiler(engine.Compiled):
 
     def visit_table(self, table, asfrom=False, **kwargs):
         if asfrom:
-            return self.preparer.format_table(table)
+            if getattr(table, "schema", None):
+                return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+            else:
+                return self.preparer.quote(table, table.name)
         else:
             return ""
 
@@ -606,7 +615,7 @@ class DefaultCompiler(engine.Compiled):
 
         return ("INSERT INTO %s (%s) VALUES (%s)" %
                 (preparer.format_table(insert_stmt.table),
-                 ', '.join([preparer.format_column(c[0])
+                 ', '.join([preparer.quote(c[0], c[0].name)
                             for c in colparams]),
                  ', '.join([c[1] for c in colparams])))
     
@@ -616,7 +625,7 @@ class DefaultCompiler(engine.Compiled):
         self.isupdate = True
         colparams = self._get_colparams(update_stmt)
 
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
 
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
@@ -831,7 +840,7 @@ class SchemaGenerator(DDLBase):
         if constraint.name is not None:
             self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
         self.append("PRIMARY KEY ")
-        self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
+        self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
 
     def visit_foreign_key_constraint(self, constraint):
         if constraint.use_alter and self.dialect.supports_alter:
@@ -849,10 +858,11 @@ class SchemaGenerator(DDLBase):
         if constraint.name is not None:
             self.append("CONSTRAINT %s " %
                         preparer.format_constraint(constraint))
+        table = list(constraint.elements)[0].column.table
         self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
-            ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
-            preparer.format_table(list(constraint.elements)[0].column.table),
-            ', '.join([preparer.format_column(f.column) for f in constraint.elements])
+            ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+            preparer.format_table(table),
+            ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
         ))
         if constraint.ondelete is not None:
             self.append(" ON DELETE %s" % constraint.ondelete)
@@ -864,7 +874,7 @@ class SchemaGenerator(DDLBase):
         if constraint.name is not None:
             self.append("CONSTRAINT %s " %
                         self.preparer.format_constraint(constraint))
-        self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
+        self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
 
     def visit_column(self, column):
         pass
@@ -877,7 +887,7 @@ class SchemaGenerator(DDLBase):
         self.append("INDEX %s ON %s (%s)" \
                     % (preparer.format_index(index),
                        preparer.format_table(index.table),
-                       string.join([preparer.format_column(c) for c in index.columns], ', ')))
+                       string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
         self.execute()
 
 class SchemaDropper(DDLBase):
@@ -978,12 +988,12 @@ class IdentifierPreparer(object):
                 or not self.legal_characters.match(unicode(value))
                 or (lc_value != value))
 
-    def __generic_obj_format(self, obj, ident):
+    def quote(self, obj, ident):
         if getattr(obj, 'quote', False):
             return self.quote_identifier(ident)
-        try:
+        if ident in self.__strings:
             return self.__strings[ident]
-        except KeyError:
+        else:
             if self._requires_quotes(ident):
                 self.__strings[ident] = self.quote_identifier(ident)
             else:
@@ -994,45 +1004,49 @@ class IdentifierPreparer(object):
         return object.quote or self._requires_quotes(object.name)
 
     def format_sequence(self, sequence, use_schema=True):
-        name = self.__generic_obj_format(sequence, sequence.name)
+        name = self.quote(sequence, sequence.name)
         if use_schema and sequence.schema is not None:
-            name = self.__generic_obj_format(sequence, sequence.schema) + "." + name
+            name = self.quote(sequence, sequence.schema) + "." + name
         return name
 
     def format_label(self, label, name=None):
-        return self.__generic_obj_format(label, name or label.name)
+        return self.quote(label, name or label.name)
 
     def format_alias(self, alias, name=None):
-        return self.__generic_obj_format(alias, name or alias.name)
+        return self.quote(alias, name or alias.name)
 
     def format_savepoint(self, savepoint, name=None):
-        return self.__generic_obj_format(savepoint, name or savepoint.ident)
+        return self.quote(savepoint, name or savepoint.ident)
 
     def format_constraint(self, constraint):
-        return self.__generic_obj_format(constraint, constraint.name)
+        return self.quote(constraint, constraint.name)
 
     def format_index(self, index):
-        return self.__generic_obj_format(index, index.name)
+        return self.quote(index, index.name)
 
     def format_table(self, table, use_schema=True, name=None):
         """Prepare a quoted table and schema name."""
 
         if name is None:
             name = table.name
-        result = self.__generic_obj_format(table, name)
+        result = self.quote(table, name)
         if use_schema and getattr(table, "schema", None):
-            result = self.__generic_obj_format(table, table.schema) + "." + result
+            result = self.quote(table, table.schema) + "." + result
         return result
 
     def format_column(self, column, use_table=False, name=None, table_name=None):
-        """Prepare a quoted column name."""
+        """Prepare a quoted column name.
+        
+        deprecated.  use preparer.quote(col, column.name) or combine with format_table()
+        """
+        
         if name is None:
             name = column.name
         if not getattr(column, 'is_literal', False):
             if use_table:
-                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
+                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
             else:
-                return self.__generic_obj_format(column, name)
+                return self.quote(column, name)
         else:
             # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
             if use_table:
@@ -1040,12 +1054,6 @@ class IdentifierPreparer(object):
             else:
                 return name
 
-    def format_column_with_table(self, column, column_name=None, table_name=None):
-        """Prepare a quoted column name with table name."""
-        
-        return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
-
-
     def format_table_seq(self, table, use_schema=True):
         """Format table name and schema as a tuple."""
 
index c7ab342722f08e43c6693a289c7976b0fc341be7..b3200a7eba339a0a5ae41f738638b45fcce93be3 100644 (file)
@@ -863,6 +863,16 @@ class ClauseElement(object):
 
         raise NotImplementedError(repr(self))
 
+    def _aggregate_hide_froms(self, **modifiers):
+        """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces, taking into account
+        previous ClauseElements which this ClauseElement is a clone of."""
+        
+        s = self
+        while s is not None:
+            for h in s._hide_froms(**modifiers):
+                yield h
+            s = getattr(s, '_is_clone_of', None)
+            
     def _hide_froms(self, **modifiers):
         """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces."""
 
@@ -2203,11 +2213,10 @@ class Join(FromClause):
                 else:
                     equivs[x] = util.Set([y])
 
-        class BinaryVisitor(visitors.ClauseVisitor):
-            def visit_binary(self, binary):
-                if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
-                    add_equiv(binary.left, binary.right)
-        BinaryVisitor().traverse(self.onclause)
+        def visit_binary(binary):
+            if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+                add_equiv(binary.left, binary.right)
+        visitors.traverse(self.onclause, visit_binary=visit_binary)
 
         for col in pkcol:
             for fk in col.foreign_keys:
@@ -2719,8 +2728,8 @@ class _SelectBaseMixin(object):
         self._offset = offset
         self._bind = bind
 
-        self.append_order_by(*util.to_list(order_by, []))
-        self.append_group_by(*util.to_list(group_by, []))
+        self._order_by_clause = ClauseList(*util.to_list(order_by, []))
+        self._group_by_clause = ClauseList(*util.to_list(group_by, []))
 
     def as_scalar(self):
         """return a 'scalar' representation of this selectable, which can be used
@@ -2967,30 +2976,41 @@ class Select(_SelectBaseMixin, FromClause):
         # usually called via a generative method, create a copy of each collection
         # by default
 
-        self._raw_columns = []
         self.__correlate = util.Set()
-        self._froms = util.OrderedSet()
-        self._whereclause = None
         self._having = None
         self._prefixes = []
 
-        if columns is not None:
-            for c in columns:
-                self.append_column(c, _copy_collection=False)
-
-        if from_obj is not None:
-            for f in from_obj:
-                self.append_from(f, _copy_collection=False)
+        if columns:
+            self._raw_columns = [
+                isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
+                for c in 
+                [_literal_as_column(c) for c in columns]
+            ]
+        else:
+            self._raw_columns = []
+            
+        if from_obj:
+            self._froms = util.Set([
+                _is_literal(f) and _TextFromClause(f) or f
+                for f in from_obj
+            ])
+        else:
+            self._froms = util.Set()
 
-        if whereclause is not None:
-            self.append_whereclause(whereclause)
+        if whereclause:
+            self._whereclause = _literal_as_text(whereclause)
+        else:
+            self._whereclause = None
 
-        if having is not None:
-            self.append_having(having)
+        if having:
+            self._having = _literal_as_text(having)
+        else:
+            self._having = None
 
-        if prefixes is not None:
-            for p in prefixes:
-                self.append_prefix(p, _copy_collection=False)
+        if prefixes:
+            self._prefixes = [_literal_as_text(p) for p in prefixes]
+        else:
+            self._prefixes = []
 
         _SelectBaseMixin.__init__(self, **kwargs)
 
@@ -3003,48 +3023,30 @@ class Select(_SelectBaseMixin, FromClause):
         correlating.
         """
 
-        froms = util.OrderedSet()
+        froms = util.Set()
         hide_froms = util.Set()
 
         for col in self._raw_columns:
-            for f in col._hide_froms():
-                hide_froms.add(f)
-                while hasattr(f, '_is_clone_of'):
-                    hide_froms.add(f._is_clone_of)
-                    f = f._is_clone_of
-            for f in col._get_from_objects():
-                froms.add(f)
+            hide_froms.update(col._aggregate_hide_froms())
+            froms.update(col._get_from_objects())
 
         if self._whereclause is not None:
-            for f in self._whereclause._get_from_objects(is_where=True):
-                froms.add(f)
+            froms.update(self._whereclause._get_from_objects(is_where=True))
 
-        for elem in self._froms:
-            froms.add(elem)
-            for f in elem._get_from_objects():
-                froms.add(f)
-        
-        for elem in froms:
-            for f in elem._hide_froms():
-                hide_froms.add(f)
-                while hasattr(f, '_is_clone_of'):
-                    hide_froms.add(f._is_clone_of)
-                    f = f._is_clone_of
+        if self._froms:
+            froms.update(self._froms)
+            for elem in self._froms:
+                hide_froms.update(elem._aggregate_hide_froms())
         
         froms = froms.difference(hide_froms)
         
         if len(froms) > 1:
             corr = self.__correlate
             if self._should_correlate and existing_froms is not None:
-                corr = existing_froms.union(corr)
-                
-            for f in list(corr):
-                while hasattr(f, '_is_clone_of'):
-                    corr.add(f._is_clone_of)
-                    f = f._is_clone_of
+                corr.update(existing_froms)
                 
             f = froms.difference(corr)
-            if len(f) == 0:
+            if not f:
                 raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
             return f
         else:
index 70d1940e62e7b842e0fb0d5b4b5b161c25eb4b11..3e2d4ec311c30f299664abe6820ff29c7dae029c 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.sql import expression, visitors
 
 """Utility functions that build upon SQL and Schema constructs."""
 
+# TODO: replace with plain list.  break out sorting funcs into module-level funcs
 class TableCollection(object):
     def __init__(self, tables=None):
         self.tables = tables or []
@@ -65,6 +66,7 @@ class TableCollection(object):
         return sequence
 
 
+# TODO: replace with plain module-level func
 class TableFinder(TableCollection, visitors.NoColumnVisitor):
     """locate all Tables within a clause."""
 
index 1a0629a17dcd3145475cfec3359e2a261330f9ec..150ee9cc7b497c101cedbb34df47dcb5b8fc29a4 100644 (file)
@@ -98,3 +98,15 @@ class NoColumnVisitor(ClauseVisitor):
     """
     
     __traverse_options__ = {'column_collections':False}
+
+def traverse(clause, **kwargs):
+    clone = kwargs.pop('clone', False)
+    class Vis(ClauseVisitor):
+        __traverse_options__ = kwargs.pop('traverse_options', {})
+        def __getattr__(self, key):
+            if key in kwargs:
+                return kwargs[key]
+            else:
+                return None
+    return Vis().traverse(clause, clone=clone)
+
index 29e17db778cbc566343f6862d6509d87dc22afde..544e674f3e2966fe03973431666cd8067339b358 100644 (file)
@@ -24,7 +24,7 @@ class CompileTest(AssertMixin):
         t1.update().compile()
 
     # TODO: this is alittle high
-    @profiling.profiled('ctest_select', call_range=(170, 200), always=True)        
+    @profiling.profiled('ctest_select', call_range=(130, 150), always=True)        
     def test_select(self):
         s = select([t1], t1.c.c2==t2.c.c1)
         s.compile()
index 8503ddc1fc4402610068af6c1eea9917b0a7fb40..d18502c72aaa1a7c8dedfc90f23aedac9141a56f 100644 (file)
@@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin):
             legs.sort()
     
     @testing.supported('postgres')
-    @profiling.profiled('editing', call_range=(1200, 1290), always=True)
+    @profiling.profiled('editing', call_range=(1150, 1280), always=True)
     def test_6_editing(self):
         Zoo = metadata.tables['Zoo']