From 3f93103a5ef9128b7b300c51d41dea43dd843834 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 24 Nov 2007 00:55:39 +0000 Subject: [PATCH] - all kinds of cleanup, tiny-to-slightly-significant speed improvements --- lib/sqlalchemy/databases/maxdb.py | 31 ++++----- lib/sqlalchemy/databases/oracle.py | 13 ++-- lib/sqlalchemy/databases/sybase.py | 7 +- lib/sqlalchemy/engine/base.py | 28 ++++---- lib/sqlalchemy/engine/default.py | 9 +-- lib/sqlalchemy/orm/attributes.py | 9 ++- lib/sqlalchemy/orm/mapper.py | 9 +-- lib/sqlalchemy/orm/properties.py | 24 +++---- lib/sqlalchemy/orm/query.py | 10 +-- lib/sqlalchemy/orm/strategies.py | 21 ++---- lib/sqlalchemy/orm/sync.py | 8 +-- lib/sqlalchemy/orm/unitofwork.py | 19 ++++-- lib/sqlalchemy/orm/util.py | 12 ++-- lib/sqlalchemy/sql/compiler.py | 78 +++++++++++---------- lib/sqlalchemy/sql/expression.py | 106 +++++++++++++++-------------- lib/sqlalchemy/sql/util.py | 2 + lib/sqlalchemy/sql/visitors.py | 12 ++++ test/profiling/compiler.py | 2 +- test/profiling/zoomark.py | 2 +- 19 files changed, 194 insertions(+), 208 deletions(-) diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py index 09cc9f0b68..3ca1bd61cc 100644 --- a/lib/sqlalchemy/databases/maxdb.py +++ b/lib/sqlalchemy/databases/maxdb.py @@ -480,6 +480,21 @@ class MaxDBDialect(default.DefaultDialect): super(MaxDBDialect, self).__init__(**kw) self._raise_known = _raise_known_sql_errors + if self.dbapi is None: + self.dbapi_type_map = {} + else: + self.dbapi_type_map = { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + def dbapi(cls): from sapdb import dbapi as _dbapi return _dbapi @@ -498,22 +513,6 @@ class MaxDBDialect(default.DefaultDialect): else: return sqltypes.adapt_type(typeobj, colspecs) - def dbapi_type_map(self): - if self.dbapi is None: - return {} - else: - return { - 'Long Binary': MaxBlob(), - 'Long byte_t': MaxBlob(), - 'Long Unicode': MaxText(), - 'Timestamp': MaxTimestamp(), - 'Date': MaxDate(), - 'Time': MaxTime(), - datetime.datetime: MaxTimestamp(), - datetime.date: MaxDate(), - datetime.time: MaxTime(), - } - def create_execution_context(self, connection, **kw): return MaxDBExecutionContext(self, connection, **kw) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 71c82bfd86..9c9c54bf66 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -251,24 +251,19 @@ class OracleDialect(default.DefaultDialect): self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' ) self.auto_setinputsizes = auto_setinputsizes self.auto_convert_lobs = auto_convert_lobs - - if self.dbapi is not None: - self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] - else: + if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__: + self.dbapi_type_map = {} self.ORACLE_BINARY_TYPES = [] - - def dbapi_type_map(self): - if self.dbapi is None or not self.auto_convert_lobs: - return {} else: # only use this for LOB objects. using it for strings, dates # etc. leads to a little too much magic, reflection doesn't know if it should # expect encoded strings or unicodes, etc. - return { + self.dbapi_type_map = { self.dbapi.CLOB: OracleText(), self.dbapi.BLOB: OracleBinary(), self.dbapi.BINARY: OracleRaw(), } + self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)] def dbapi(cls): import cx_Oracle diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py index 022f547463..87045d1926 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -639,8 +639,7 @@ class SybaseSQLDialect_mxodbc(SybaseSQLDialect): def __init__(self, **params): super(SybaseSQLDialect_mxodbc, self).__init__(**params) - def dbapi_type_map(self): - return {'getdate' : SybaseDate_mxodbc()} + self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()} def import_dbapi(cls): #import mx.ODBC.Windows as module @@ -686,9 +685,7 @@ class SybaseSQLDialect_mxodbc(SybaseSQLDialect): class SybaseSQLDialect_pyodbc(SybaseSQLDialect): def __init__(self, **params): super(SybaseSQLDialect_pyodbc, self).__init__(**params) - - def dbapi_type_map(self): - return {'getdate' : SybaseDate_pyodbc()} + self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()} def import_dbapi(cls): import mypyodbc as module diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6af0ce0d39..9c7e70ba92 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -87,6 +87,17 @@ class Dialect(object): supports_pk_autoincrement Indicates if the dialect should allow the database to passively assign a primary key column value. + + dbapi_type_map + A mapping of DB-API type objects present in this Dialect's + DB-API implmentation mapped to TypeEngine implementations used + by the dialect. + + This is used to apply types to result sets based on the DB-API + types present in cursor.description; it only takes effect for + result sets against textual statements where no explicit + typemap was present. + """ def create_connect_args(self, url): @@ -99,21 +110,6 @@ class Dialect(object): raise NotImplementedError() - def dbapi_type_map(self): - """Returns a DB-API to sqlalchemy.types mapping. - - A mapping of DB-API type objects present in this Dialect's - DB-API implmentation mapped to TypeEngine implementations used - by the dialect. - - This is used to apply types to result sets based on the DB-API - types present in cursor.description; it only takes effect for - result sets against textual statements where no explicit - typemap was present. Constructed SQL statements always have - type information explicitly embedded. - """ - - raise NotImplementedError() def type_descriptor(self, typeobj): """Transform a generic type to a database-specific type. @@ -1339,7 +1335,7 @@ class ResultProxy(object): metadata = self.cursor.description if metadata is not None: - typemap = self.dialect.dbapi_type_map() + typemap = self.dialect.dbapi_type_map for i, item in enumerate(metadata): # sqlite possibly prepending table name to colnames so strip diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 198f6742bf..a91d65b81f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -33,7 +33,8 @@ class DefaultDialect(base.Dialect): supports_sane_multi_rowcount = True preexecute_pk_sequences = False supports_pk_autoincrement = True - + dbapi_type_map = {} + def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs): self.convert_unicode = convert_unicode self.encoding = encoding @@ -59,12 +60,6 @@ class DefaultDialect(base.Dialect): property(lambda s: s.preexecute_sequences, doc=( "Proxy to deprecated preexecute_sequences attribute."))) - def dbapi_type_map(self): - # most DB-APIs have problems with this (such as, psycocpg2 types - # are unhashable). So far Oracle can return it. - - return {} - def create_execution_context(self, connection, **kwargs): return DefaultExecutionContext(self, connection, **kwargs) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 27f4b017c3..123a99c9a8 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -686,7 +686,7 @@ class InstanceState(object): self._strong_obj = None -class InstanceDict(UserDict.UserDict): +class WeakInstanceDict(UserDict.UserDict): """similar to WeakValueDictionary, but wired towards 'state' objects.""" def __init__(self, *args, **kw): @@ -802,7 +802,12 @@ class InstanceDict(UserDict.UserDict): def copy(self): raise NotImplementedError() - + def all_states(self): + return self.data.values() + +class StrongInstanceDict(dict): + def all_states(self): + return [o._state for o in self.values()] class AttributeHistory(object): """Calculate the *history* of a particular attribute on a diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 42fe8f56fb..3bacb13e93 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -6,7 +6,7 @@ import weakref, warnings, operator from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression, visitors from sqlalchemy.sql import util as sqlutil from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter @@ -513,11 +513,9 @@ class Mapper(object): result[binary.right].add(binary.left) else: result[binary.right] = util.Set([binary.left]) - vis = mapperutil.BinaryVisitor(visit_binary) - for mapper in self.base_mapper.polymorphic_iterator(): if mapper.inherit_condition is not None: - vis.traverse(mapper.inherit_condition) + visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary) # TODO: matching of cols to foreign keys might better be generalized # into general column translation (i.e. corresponding_column) @@ -1472,11 +1470,10 @@ class Mapper(object): allconds = [] param_names = [] - visitor = mapperutil.BinaryVisitor(visit_binary) for mapper in self.iterate_to_root(): if mapper is base_mapper: break - allconds.append(visitor.traverse(mapper.inherit_condition, clone=True)) + allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary)) return sql.and_(*allconds), param_names diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 46224aac65..9e7815e38e 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -12,7 +12,7 @@ to handle flush-time dependency sorting and processing. """ from sqlalchemy import sql, schema, util, exceptions, logging -from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql import util as sql_util, visitors from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil @@ -439,9 +439,9 @@ class PropertyLoader(StrategizedProperty): self._opposite_side.add(binary.right) if binary.right in self.foreign_keys: self._opposite_side.add(binary.left) - mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin) + visitors.traverse(self.primaryjoin, visit_binary=visit_binary) if self.secondaryjoin is not None: - mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin) + visitors.traverse(self.secondaryjoin, visit_binary=visit_binary) else: self.foreign_keys = util.Set() self._opposite_side = util.Set() @@ -463,7 +463,7 @@ class PropertyLoader(StrategizedProperty): if f.references(binary.left.table): self.foreign_keys.add(binary.right) self._opposite_side.add(binary.left) - mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin) + visitors.traverse(self.primaryjoin, visit_binary=visit_binary) if len(self.foreign_keys) == 0: raise exceptions.ArgumentError( @@ -472,7 +472,7 @@ class PropertyLoader(StrategizedProperty): "'foreign_keys' argument to indicate which columns in " "the join condition are foreign." %(str(self.primaryjoin), str(self))) if self.secondaryjoin is not None: - mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin) + visitors.traverse(self.secondaryjoin, visit_binary=visit_binary) def _determine_direction(self): @@ -543,14 +543,13 @@ class PropertyLoader(StrategizedProperty): # in the "polymorphic" selectables. these are used to construct joins for both Query as well as # eager loading, and also are used to calculate "lazy loading" clauses. - # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, - # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge - # several "equivalent" columns (such as parent/child fk cols) into just one column. - - target_equivalents = self.mapper._get_equivalent_columns() - - # if the target mapper loads polymorphically, adapt the clauses to the target's selectable if self.loads_polymorphic: + + # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out, + # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge + # several "equivalent" columns (such as parent/child fk cols) into just one column. + target_equivalents = self.mapper._get_equivalent_columns() + if self.secondaryjoin: self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True) self.polymorphic_primaryjoin = self.primaryjoin @@ -560,6 +559,7 @@ class PropertyLoader(StrategizedProperty): elif self.direction is sync.MANYTOONE: self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True) self.polymorphic_secondaryjoin = None + # load "polymorphic" versions of the columns present in "remote_side" - this is # important for lazy-clause generation which goes off the polymorphic target selectable for c in list(self.remote_side): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 47a7022695..77f7fbe04d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -907,11 +907,8 @@ class Query(object): for o in order_by: cf.traverse(o) - s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, **self._select_args()) + s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args()) - if order_by: - s2.append_order_by(*util.to_list(order_by)) - s3 = s2.alias() self._primary_adapter = mapperutil.create_row_adapter(s3, self.table) @@ -926,13 +923,10 @@ class Query(object): statement.append_order_by(*context.eager_order_by) else: - statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args()) + statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args()) if context.eager_joins: statement.append_from(context.eager_joins, _copy_collection=False) - if order_by: - statement.append_order_by(*util.to_list(order_by)) - if context.eager_order_by: statement.append_order_by(*context.eager_order_by) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 2caee2dd4a..7911b93c8b 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -407,21 +407,11 @@ class LazyLoader(AbstractRelationLoader): else: return othercol in remote_side - def find_column_in_expr(expr): - if not isinstance(expr, sql.ColumnElement): - return None - columns = [] - class FindColumnInColumnClause(visitors.ClauseVisitor): - def visit_column(self, c): - columns.append(c) - FindColumnInColumnClause().traverse(expr) - return len(columns) and columns[0] or None - def visit_binary(binary): - leftcol = find_column_in_expr(binary.left) - rightcol = find_column_in_expr(binary.right) - if leftcol is None or rightcol is None: + if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement): return + leftcol = binary.left + rightcol = binary.right if should_bind(leftcol, rightcol): col = leftcol @@ -438,14 +428,13 @@ class LazyLoader(AbstractRelationLoader): reverse[leftcol] = binds[col] lazywhere = primaryjoin - li = mapperutil.BinaryVisitor(visit_binary) if not secondaryjoin or not reverse_direction: - lazywhere = li.traverse(lazywhere, clone=True) + lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary) if secondaryjoin is not None: if reverse_direction: - secondaryjoin = li.traverse(secondaryjoin, clone=True) + secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary) lazywhere = sql.and_(lazywhere, secondaryjoin) return (lazywhere, binds, reverse) _create_lazy_clause = classmethod(_create_lazy_clause) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 5b5a9e43b1..9575aa958f 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -79,7 +79,7 @@ class ClauseSynchronizer(object): self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary)) rules_added = len(self.syncrules) - BinaryVisitor(compile_binary).traverse(sqlclause) + visitors.traverse(sqlclause, visit_binary=compile_binary) if len(self.syncrules) == rules_added: raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause)) @@ -144,9 +144,3 @@ class SyncRule(object): SyncRule.logger = logging.class_logger(SyncRule) -class BinaryVisitor(visitors.ClauseVisitor): - def __init__(self, func): - self.func = func - - def visit_binary(self, binary): - self.func(binary) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 2cd7cb6f5d..cdffad266b 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -89,9 +89,9 @@ class UnitOfWork(object): def __init__(self, session): if session.weak_identity_map: - self.identity_map = attributes.InstanceDict() + self.identity_map = attributes.WeakInstanceDict() else: - self.identity_map = {} + self.identity_map = attributes.StrongInstanceDict() self.new = util.IdentitySet() #OrderedSet() self.deleted = util.IdentitySet() @@ -158,6 +158,7 @@ class UnitOfWork(object): ) ]) + def flush(self, session, objects=None): """create a dependency tree of all pending SQL operations within this unit of work and execute.""" @@ -166,12 +167,16 @@ class UnitOfWork(object): # communication with the mappers and relationships to fire off SQL # and synchronize attributes between related objects. - # detect persistent objects that have changes - dirty = self.locate_dirty() - + dirty = [x for x in self.identity_map.all_states() + if x.modified + or (getattr(x.class_, '_sa_has_mutable_scalars', False) and attribute_manager._is_modified(x)) + ] + if len(dirty) == 0 and len(self.deleted) == 0 and len(self.new) == 0: return - + + dirty = util.IdentitySet([x.obj() for x in dirty]).difference(self.deleted) + flush_context = UOWTransaction(self, session) if session.extension is not None: @@ -232,7 +237,7 @@ class UnitOfWork(object): the number of objects pruned. """ - if isinstance(self.identity_map, attributes.InstanceDict): + if isinstance(self.identity_map, attributes.WeakInstanceDict): return 0 ref_count = len(self.identity_map) dirty = self.locate_dirty() diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 55d4db98fe..a5e2d1e6e8 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -98,7 +98,10 @@ class TranslatingDict(dict): return ourcol def __getitem__(self, col): - return super(TranslatingDict, self).__getitem__(self.__translate_col(col)) + try: + return super(TranslatingDict, self).__getitem__(col) + except KeyError: + return super(TranslatingDict, self).__getitem__(self.__translate_col(col)) def has_key(self, col): return col in self @@ -172,13 +175,6 @@ class ExtensionCarrier(object): def __getattr__(self, key): return self.methods.get(key, self._pass) -class BinaryVisitor(visitors.ClauseVisitor): - def __init__(self, func): - self.func = func - - def visit_binary(self, binary): - self.func(binary) - class AliasedClauses(object): """Creates aliases of a mapped tables for usage in ORM queries. """ diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index dd6f0dddde..c1f3bc2a05 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -245,19 +245,25 @@ class DefaultCompiler(engine.Compiled): n = self.dialect.oid_column_name(column) if n is not None: if column.table is None or not column.table.named_with_column(): - return self.preparer.format_column(column, name=n) + return n else: - return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n) + return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n elif len(column.table.primary_key) != 0: pk = list(column.table.primary_key)[0] pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name)) - return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname) else: return None elif column.table is None or not column.table.named_with_column(): - return self.preparer.format_column(column, name=name) + if getattr(column, "is_literal", False): + return name + else: + return self.preparer.quote(column, name) else: - return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + if getattr(column, "is_literal", False): + return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name + else: + return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(column, name) def visit_fromclause(self, fromclause, **kwargs): @@ -588,7 +594,10 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, **kwargs): if asfrom: - return self.preparer.format_table(table) + if getattr(table, "schema", None): + return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name) + else: + return self.preparer.quote(table, table.name) else: return "" @@ -606,7 +615,7 @@ class DefaultCompiler(engine.Compiled): return ("INSERT INTO %s (%s) VALUES (%s)" % (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) + ', '.join([preparer.quote(c[0], c[0].name) for c in colparams]), ', '.join([c[1] for c in colparams]))) @@ -616,7 +625,7 @@ class DefaultCompiler(engine.Compiled): self.isupdate = True colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ') if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) @@ -831,7 +840,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint])) + self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint])) def visit_foreign_key_constraint(self, constraint): if constraint.use_alter and self.dialect.supports_alter: @@ -849,10 +858,11 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % preparer.format_constraint(constraint)) + table = list(constraint.elements)[0].column.table self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join([preparer.format_column(f.parent) for f in constraint.elements]), - preparer.format_table(list(constraint.elements)[0].column.table), - ', '.join([preparer.format_column(f.column) for f in constraint.elements]) + ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]), + preparer.format_table(table), + ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -864,7 +874,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint]))) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint]))) def visit_column(self, column): pass @@ -877,7 +887,7 @@ class SchemaGenerator(DDLBase): self.append("INDEX %s ON %s (%s)" \ % (preparer.format_index(index), preparer.format_table(index.table), - string.join([preparer.format_column(c) for c in index.columns], ', '))) + string.join([preparer.quote(c, c.name) for c in index.columns], ', '))) self.execute() class SchemaDropper(DDLBase): @@ -978,12 +988,12 @@ class IdentifierPreparer(object): or not self.legal_characters.match(unicode(value)) or (lc_value != value)) - def __generic_obj_format(self, obj, ident): + def quote(self, obj, ident): if getattr(obj, 'quote', False): return self.quote_identifier(ident) - try: + if ident in self.__strings: return self.__strings[ident] - except KeyError: + else: if self._requires_quotes(ident): self.__strings[ident] = self.quote_identifier(ident) else: @@ -994,45 +1004,49 @@ class IdentifierPreparer(object): return object.quote or self._requires_quotes(object.name) def format_sequence(self, sequence, use_schema=True): - name = self.__generic_obj_format(sequence, sequence.name) + name = self.quote(sequence, sequence.name) if use_schema and sequence.schema is not None: - name = self.__generic_obj_format(sequence, sequence.schema) + "." + name + name = self.quote(sequence, sequence.schema) + "." + name return name def format_label(self, label, name=None): - return self.__generic_obj_format(label, name or label.name) + return self.quote(label, name or label.name) def format_alias(self, alias, name=None): - return self.__generic_obj_format(alias, name or alias.name) + return self.quote(alias, name or alias.name) def format_savepoint(self, savepoint, name=None): - return self.__generic_obj_format(savepoint, name or savepoint.ident) + return self.quote(savepoint, name or savepoint.ident) def format_constraint(self, constraint): - return self.__generic_obj_format(constraint, constraint.name) + return self.quote(constraint, constraint.name) def format_index(self, index): - return self.__generic_obj_format(index, index.name) + return self.quote(index, index.name) def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name - result = self.__generic_obj_format(table, name) + result = self.quote(table, name) if use_schema and getattr(table, "schema", None): - result = self.__generic_obj_format(table, table.schema) + "." + result + result = self.quote(table, table.schema) + "." + result return result def format_column(self, column, use_table=False, name=None, table_name=None): - """Prepare a quoted column name.""" + """Prepare a quoted column name. + + deprecated. use preparer.quote(col, column.name) or combine with format_table() + """ + if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name) + return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name) else: - return self.__generic_obj_format(column, name) + return self.quote(column, name) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: @@ -1040,12 +1054,6 @@ class IdentifierPreparer(object): else: return name - def format_column_with_table(self, column, column_name=None, table_name=None): - """Prepare a quoted column name with table name.""" - - return self.format_column(column, use_table=True, name=column_name, table_name=table_name) - - def format_table_seq(self, table, use_schema=True): """Format table name and schema as a tuple.""" diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index c7ab342722..b3200a7eba 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -863,6 +863,16 @@ class ClauseElement(object): raise NotImplementedError(repr(self)) + def _aggregate_hide_froms(self, **modifiers): + """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces, taking into account + previous ClauseElements which this ClauseElement is a clone of.""" + + s = self + while s is not None: + for h in s._hide_froms(**modifiers): + yield h + s = getattr(s, '_is_clone_of', None) + def _hide_froms(self, **modifiers): """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces.""" @@ -2203,11 +2213,10 @@ class Join(FromClause): else: equivs[x] = util.Set([y]) - class BinaryVisitor(visitors.ClauseVisitor): - def visit_binary(self, binary): - if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): - add_equiv(binary.left, binary.right) - BinaryVisitor().traverse(self.onclause) + def visit_binary(binary): + if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column): + add_equiv(binary.left, binary.right) + visitors.traverse(self.onclause, visit_binary=visit_binary) for col in pkcol: for fk in col.foreign_keys: @@ -2719,8 +2728,8 @@ class _SelectBaseMixin(object): self._offset = offset self._bind = bind - self.append_order_by(*util.to_list(order_by, [])) - self.append_group_by(*util.to_list(group_by, [])) + self._order_by_clause = ClauseList(*util.to_list(order_by, [])) + self._group_by_clause = ClauseList(*util.to_list(group_by, [])) def as_scalar(self): """return a 'scalar' representation of this selectable, which can be used @@ -2967,30 +2976,41 @@ class Select(_SelectBaseMixin, FromClause): # usually called via a generative method, create a copy of each collection # by default - self._raw_columns = [] self.__correlate = util.Set() - self._froms = util.OrderedSet() - self._whereclause = None self._having = None self._prefixes = [] - if columns is not None: - for c in columns: - self.append_column(c, _copy_collection=False) - - if from_obj is not None: - for f in from_obj: - self.append_from(f, _copy_collection=False) + if columns: + self._raw_columns = [ + isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c + for c in + [_literal_as_column(c) for c in columns] + ] + else: + self._raw_columns = [] + + if from_obj: + self._froms = util.Set([ + _is_literal(f) and _TextFromClause(f) or f + for f in from_obj + ]) + else: + self._froms = util.Set() - if whereclause is not None: - self.append_whereclause(whereclause) + if whereclause: + self._whereclause = _literal_as_text(whereclause) + else: + self._whereclause = None - if having is not None: - self.append_having(having) + if having: + self._having = _literal_as_text(having) + else: + self._having = None - if prefixes is not None: - for p in prefixes: - self.append_prefix(p, _copy_collection=False) + if prefixes: + self._prefixes = [_literal_as_text(p) for p in prefixes] + else: + self._prefixes = [] _SelectBaseMixin.__init__(self, **kwargs) @@ -3003,48 +3023,30 @@ class Select(_SelectBaseMixin, FromClause): correlating. """ - froms = util.OrderedSet() + froms = util.Set() hide_froms = util.Set() for col in self._raw_columns: - for f in col._hide_froms(): - hide_froms.add(f) - while hasattr(f, '_is_clone_of'): - hide_froms.add(f._is_clone_of) - f = f._is_clone_of - for f in col._get_from_objects(): - froms.add(f) + hide_froms.update(col._aggregate_hide_froms()) + froms.update(col._get_from_objects()) if self._whereclause is not None: - for f in self._whereclause._get_from_objects(is_where=True): - froms.add(f) + froms.update(self._whereclause._get_from_objects(is_where=True)) - for elem in self._froms: - froms.add(elem) - for f in elem._get_from_objects(): - froms.add(f) - - for elem in froms: - for f in elem._hide_froms(): - hide_froms.add(f) - while hasattr(f, '_is_clone_of'): - hide_froms.add(f._is_clone_of) - f = f._is_clone_of + if self._froms: + froms.update(self._froms) + for elem in self._froms: + hide_froms.update(elem._aggregate_hide_froms()) froms = froms.difference(hide_froms) if len(froms) > 1: corr = self.__correlate if self._should_correlate and existing_froms is not None: - corr = existing_froms.union(corr) - - for f in list(corr): - while hasattr(f, '_is_clone_of'): - corr.add(f._is_clone_of) - f = f._is_clone_of + corr.update(existing_froms) f = froms.difference(corr) - if len(f) == 0: + if not f: raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) return f else: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 70d1940e62..3e2d4ec311 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -3,6 +3,7 @@ from sqlalchemy.sql import expression, visitors """Utility functions that build upon SQL and Schema constructs.""" +# TODO: replace with plain list. break out sorting funcs into module-level funcs class TableCollection(object): def __init__(self, tables=None): self.tables = tables or [] @@ -65,6 +66,7 @@ class TableCollection(object): return sequence +# TODO: replace with plain module-level func class TableFinder(TableCollection, visitors.NoColumnVisitor): """locate all Tables within a clause.""" diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 1a0629a17d..150ee9cc7b 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -98,3 +98,15 @@ class NoColumnVisitor(ClauseVisitor): """ __traverse_options__ = {'column_collections':False} + +def traverse(clause, **kwargs): + clone = kwargs.pop('clone', False) + class Vis(ClauseVisitor): + __traverse_options__ = kwargs.pop('traverse_options', {}) + def __getattr__(self, key): + if key in kwargs: + return kwargs[key] + else: + return None + return Vis().traverse(clause, clone=clone) + diff --git a/test/profiling/compiler.py b/test/profiling/compiler.py index 29e17db778..544e674f3e 100644 --- a/test/profiling/compiler.py +++ b/test/profiling/compiler.py @@ -24,7 +24,7 @@ class CompileTest(AssertMixin): t1.update().compile() # TODO: this is alittle high - @profiling.profiled('ctest_select', call_range=(170, 200), always=True) + @profiling.profiled('ctest_select', call_range=(130, 150), always=True) def test_select(self): s = select([t1], t1.c.c2==t2.c.c1) s.compile() diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py index 8503ddc1fc..d18502c72a 100644 --- a/test/profiling/zoomark.py +++ b/test/profiling/zoomark.py @@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin): legs.sort() @testing.supported('postgres') - @profiling.profiled('editing', call_range=(1200, 1290), always=True) + @profiling.profiled('editing', call_range=(1150, 1280), always=True) def test_6_editing(self): Zoo = metadata.tables['Zoo'] -- 2.47.2