super(MaxDBDialect, self).__init__(**kw)
self._raise_known = _raise_known_sql_errors
+ if self.dbapi is None:
+ self.dbapi_type_map = {}
+ else:
+ self.dbapi_type_map = {
+ 'Long Binary': MaxBlob(),
+ 'Long byte_t': MaxBlob(),
+ 'Long Unicode': MaxText(),
+ 'Timestamp': MaxTimestamp(),
+ 'Date': MaxDate(),
+ 'Time': MaxTime(),
+ datetime.datetime: MaxTimestamp(),
+ datetime.date: MaxDate(),
+ datetime.time: MaxTime(),
+ }
+
def dbapi(cls):
from sapdb import dbapi as _dbapi
return _dbapi
else:
return sqltypes.adapt_type(typeobj, colspecs)
- def dbapi_type_map(self):
- if self.dbapi is None:
- return {}
- else:
- return {
- 'Long Binary': MaxBlob(),
- 'Long byte_t': MaxBlob(),
- 'Long Unicode': MaxText(),
- 'Timestamp': MaxTimestamp(),
- 'Date': MaxDate(),
- 'Time': MaxTime(),
- datetime.datetime: MaxTimestamp(),
- datetime.date: MaxDate(),
- datetime.time: MaxTime(),
- }
-
def create_execution_context(self, connection, **kw):
return MaxDBExecutionContext(self, connection, **kw)
self.supports_timestamp = self.dbapi is None or hasattr(self.dbapi, 'TIMESTAMP' )
self.auto_setinputsizes = auto_setinputsizes
self.auto_convert_lobs = auto_convert_lobs
-
- if self.dbapi is not None:
- self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
- else:
+ if self.dbapi is None or not self.auto_convert_lobs or not 'CLOB' in self.dbapi.__dict__:
+ self.dbapi_type_map = {}
self.ORACLE_BINARY_TYPES = []
-
- def dbapi_type_map(self):
- if self.dbapi is None or not self.auto_convert_lobs:
- return {}
else:
# only use this for LOB objects. using it for strings, dates
# etc. leads to a little too much magic, reflection doesn't know if it should
# expect encoded strings or unicodes, etc.
- return {
+ self.dbapi_type_map = {
self.dbapi.CLOB: OracleText(),
self.dbapi.BLOB: OracleBinary(),
self.dbapi.BINARY: OracleRaw(),
}
+ self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB"] if hasattr(self.dbapi, k)]
def dbapi(cls):
import cx_Oracle
def __init__(self, **params):
super(SybaseSQLDialect_mxodbc, self).__init__(**params)
- def dbapi_type_map(self):
- return {'getdate' : SybaseDate_mxodbc()}
+ self.dbapi_type_map = {'getdate' : SybaseDate_mxodbc()}
def import_dbapi(cls):
#import mx.ODBC.Windows as module
class SybaseSQLDialect_pyodbc(SybaseSQLDialect):
def __init__(self, **params):
super(SybaseSQLDialect_pyodbc, self).__init__(**params)
-
- def dbapi_type_map(self):
- return {'getdate' : SybaseDate_pyodbc()}
+ self.dbapi_type_map = {'getdate' : SybaseDate_pyodbc()}
def import_dbapi(cls):
import mypyodbc as module
supports_pk_autoincrement
Indicates if the dialect should allow the database to passively assign
a primary key column value.
+
+ dbapi_type_map
+ A mapping of DB-API type objects present in this Dialect's
+ DB-API implementation mapped to TypeEngine implementations used
+ by the dialect.
+
+ This is used to apply types to result sets based on the DB-API
+ types present in cursor.description; it only takes effect for
+ result sets against textual statements where no explicit
+ typemap was present.
+
"""
def create_connect_args(self, url):
raise NotImplementedError()
- def dbapi_type_map(self):
- """Returns a DB-API to sqlalchemy.types mapping.
-
- A mapping of DB-API type objects present in this Dialect's
- DB-API implmentation mapped to TypeEngine implementations used
- by the dialect.
-
- This is used to apply types to result sets based on the DB-API
- types present in cursor.description; it only takes effect for
- result sets against textual statements where no explicit
- typemap was present. Constructed SQL statements always have
- type information explicitly embedded.
- """
-
- raise NotImplementedError()
def type_descriptor(self, typeobj):
"""Transform a generic type to a database-specific type.
metadata = self.cursor.description
if metadata is not None:
- typemap = self.dialect.dbapi_type_map()
+ typemap = self.dialect.dbapi_type_map
for i, item in enumerate(metadata):
# sqlite possibly prepending table name to colnames so strip
supports_sane_multi_rowcount = True
preexecute_pk_sequences = False
supports_pk_autoincrement = True
-
+ dbapi_type_map = {}
+
def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
self.convert_unicode = convert_unicode
self.encoding = encoding
property(lambda s: s.preexecute_sequences, doc=(
"Proxy to deprecated preexecute_sequences attribute.")))
- def dbapi_type_map(self):
- # most DB-APIs have problems with this (such as, psycocpg2 types
- # are unhashable). So far Oracle can return it.
-
- return {}
-
def create_execution_context(self, connection, **kwargs):
return DefaultExecutionContext(self, connection, **kwargs)
self._strong_obj = None
-class InstanceDict(UserDict.UserDict):
+class WeakInstanceDict(UserDict.UserDict):
"""similar to WeakValueDictionary, but wired towards 'state' objects."""
def __init__(self, *args, **kw):
def copy(self):
raise NotImplementedError()
-
+ def all_states(self):
+ return self.data.values()
+
+class StrongInstanceDict(dict):
+ def all_states(self):
+ return [o._state for o in self.values()]
class AttributeHistory(object):
"""Calculate the *history* of a particular attribute on a
import weakref, warnings, operator
from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression
+from sqlalchemy.sql import expression, visitors
from sqlalchemy.sql import util as sqlutil
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
result[binary.right].add(binary.left)
else:
result[binary.right] = util.Set([binary.left])
- vis = mapperutil.BinaryVisitor(visit_binary)
-
for mapper in self.base_mapper.polymorphic_iterator():
if mapper.inherit_condition is not None:
- vis.traverse(mapper.inherit_condition)
+ visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary)
# TODO: matching of cols to foreign keys might better be generalized
# into general column translation (i.e. corresponding_column)
allconds = []
param_names = []
- visitor = mapperutil.BinaryVisitor(visit_binary)
for mapper in self.iterate_to_root():
if mapper is base_mapper:
break
- allconds.append(visitor.traverse(mapper.inherit_condition, clone=True))
+ allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
return sql.and_(*allconds), param_names
"""
from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql import util as sql_util, visitors
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
self._opposite_side.add(binary.right)
if binary.right in self.foreign_keys:
self._opposite_side.add(binary.left)
- mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin)
+ visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
if self.secondaryjoin is not None:
- mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin)
+ visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
else:
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
if f.references(binary.left.table):
self.foreign_keys.add(binary.right)
self._opposite_side.add(binary.left)
- mapperutil.BinaryVisitor(visit_binary).traverse(self.primaryjoin)
+ visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
if len(self.foreign_keys) == 0:
raise exceptions.ArgumentError(
"'foreign_keys' argument to indicate which columns in "
"the join condition are foreign." %(str(self.primaryjoin), str(self)))
if self.secondaryjoin is not None:
- mapperutil.BinaryVisitor(visit_binary).traverse(self.secondaryjoin)
+ visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
def _determine_direction(self):
# in the "polymorphic" selectables. these are used to construct joins for both Query as well as
# eager loading, and also are used to calculate "lazy loading" clauses.
- # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
- # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
- # several "equivalent" columns (such as parent/child fk cols) into just one column.
-
- target_equivalents = self.mapper._get_equivalent_columns()
-
- # if the target mapper loads polymorphically, adapt the clauses to the target's selectable
if self.loads_polymorphic:
+
+ # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
+ # first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
+ # several "equivalent" columns (such as parent/child fk cols) into just one column.
+ target_equivalents = self.mapper._get_equivalent_columns()
+
if self.secondaryjoin:
self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
self.polymorphic_primaryjoin = self.primaryjoin
elif self.direction is sync.MANYTOONE:
self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
self.polymorphic_secondaryjoin = None
+
# load "polymorphic" versions of the columns present in "remote_side" - this is
# important for lazy-clause generation which goes off the polymorphic target selectable
for c in list(self.remote_side):
for o in order_by:
cf.traverse(o)
- s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, **self._select_args())
+ s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, order_by=util.to_list(order_by), **self._select_args())
- if order_by:
- s2.append_order_by(*util.to_list(order_by))
-
s3 = s2.alias()
self._primary_adapter = mapperutil.create_row_adapter(s3, self.table)
statement.append_order_by(*context.eager_order_by)
else:
- statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
+ statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=util.to_list(order_by), **self._select_args())
if context.eager_joins:
statement.append_from(context.eager_joins, _copy_collection=False)
- if order_by:
- statement.append_order_by(*util.to_list(order_by))
-
if context.eager_order_by:
statement.append_order_by(*context.eager_order_by)
else:
return othercol in remote_side
- def find_column_in_expr(expr):
- if not isinstance(expr, sql.ColumnElement):
- return None
- columns = []
- class FindColumnInColumnClause(visitors.ClauseVisitor):
- def visit_column(self, c):
- columns.append(c)
- FindColumnInColumnClause().traverse(expr)
- return len(columns) and columns[0] or None
-
def visit_binary(binary):
- leftcol = find_column_in_expr(binary.left)
- rightcol = find_column_in_expr(binary.right)
- if leftcol is None or rightcol is None:
+ if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
return
+ leftcol = binary.left
+ rightcol = binary.right
if should_bind(leftcol, rightcol):
col = leftcol
reverse[leftcol] = binds[col]
lazywhere = primaryjoin
- li = mapperutil.BinaryVisitor(visit_binary)
if not secondaryjoin or not reverse_direction:
- lazywhere = li.traverse(lazywhere, clone=True)
+ lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary)
if secondaryjoin is not None:
if reverse_direction:
- secondaryjoin = li.traverse(secondaryjoin, clone=True)
+ secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
lazywhere = sql.and_(lazywhere, secondaryjoin)
return (lazywhere, binds, reverse)
_create_lazy_clause = classmethod(_create_lazy_clause)
self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
rules_added = len(self.syncrules)
- BinaryVisitor(compile_binary).traverse(sqlclause)
+ visitors.traverse(sqlclause, visit_binary=compile_binary)
if len(self.syncrules) == rules_added:
raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
SyncRule.logger = logging.class_logger(SyncRule)
-class BinaryVisitor(visitors.ClauseVisitor):
- def __init__(self, func):
- self.func = func
-
- def visit_binary(self, binary):
- self.func(binary)
def __init__(self, session):
if session.weak_identity_map:
- self.identity_map = attributes.InstanceDict()
+ self.identity_map = attributes.WeakInstanceDict()
else:
- self.identity_map = {}
+ self.identity_map = attributes.StrongInstanceDict()
self.new = util.IdentitySet() #OrderedSet()
self.deleted = util.IdentitySet()
)
])
+
def flush(self, session, objects=None):
"""create a dependency tree of all pending SQL operations within this unit of work and execute."""
# communication with the mappers and relationships to fire off SQL
# and synchronize attributes between related objects.
- # detect persistent objects that have changes
- dirty = self.locate_dirty()
-
+ dirty = [x for x in self.identity_map.all_states()
+ if x.modified
+ or (getattr(x.class_, '_sa_has_mutable_scalars', False) and attribute_manager._is_modified(x))
+ ]
+
if len(dirty) == 0 and len(self.deleted) == 0 and len(self.new) == 0:
return
-
+
+ dirty = util.IdentitySet([x.obj() for x in dirty]).difference(self.deleted)
+
flush_context = UOWTransaction(self, session)
if session.extension is not None:
the number of objects pruned.
"""
- if isinstance(self.identity_map, attributes.InstanceDict):
+ if isinstance(self.identity_map, attributes.WeakInstanceDict):
return 0
ref_count = len(self.identity_map)
dirty = self.locate_dirty()
return ourcol
def __getitem__(self, col):
- return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
+ try:
+ return super(TranslatingDict, self).__getitem__(col)
+ except KeyError:
+ return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
def has_key(self, col):
return col in self
def __getattr__(self, key):
return self.methods.get(key, self._pass)
-class BinaryVisitor(visitors.ClauseVisitor):
- def __init__(self, func):
- self.func = func
-
- def visit_binary(self, binary):
- self.func(binary)
-
class AliasedClauses(object):
"""Creates aliases of a mapped tables for usage in ORM queries.
"""
n = self.dialect.oid_column_name(column)
if n is not None:
if column.table is None or not column.table.named_with_column():
- return self.preparer.format_column(column, name=n)
+ return n
else:
- return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)), n)
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + n
elif len(column.table.primary_key) != 0:
pk = list(column.table.primary_key)[0]
pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
- return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(pk, pkname)
else:
return None
elif column.table is None or not column.table.named_with_column():
- return self.preparer.format_column(column, name=name)
+ if getattr(column, "is_literal", False):
+ return name
+ else:
+ return self.preparer.quote(column, name)
else:
- return self.preparer.format_column_with_table(column, column_name=name, table_name=ANONYMOUS_LABEL.sub(self._process_anon, column.table.name))
+ if getattr(column, "is_literal", False):
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+ else:
+ return self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + self.preparer.quote(column, name)
def visit_fromclause(self, fromclause, **kwargs):
def visit_table(self, table, asfrom=False, **kwargs):
if asfrom:
- return self.preparer.format_table(table)
+ if getattr(table, "schema", None):
+ return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+ else:
+ return self.preparer.quote(table, table.name)
else:
return ""
return ("INSERT INTO %s (%s) VALUES (%s)" %
(preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
+ ', '.join([preparer.quote(c[0], c[0].name)
for c in colparams]),
', '.join([c[1] for c in colparams])))
self.isupdate = True
colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
text += " WHERE " + self.process(update_stmt._whereclause)
if constraint.name is not None:
self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join([self.preparer.format_column(c) for c in constraint]))
+ self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.dialect.supports_alter:
if constraint.name is not None:
self.append("CONSTRAINT %s " %
preparer.format_constraint(constraint))
+ table = list(constraint.elements)[0].column.table
self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join([preparer.format_column(f.parent) for f in constraint.elements]),
- preparer.format_table(list(constraint.elements)[0].column.table),
- ', '.join([preparer.format_column(f.column) for f in constraint.elements])
+ ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+ preparer.format_table(table),
+ ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
))
if constraint.ondelete is not None:
self.append(" ON DELETE %s" % constraint.ondelete)
if constraint.name is not None:
self.append("CONSTRAINT %s " %
self.preparer.format_constraint(constraint))
- self.append(" UNIQUE (%s)" % (', '.join([self.preparer.format_column(c) for c in constraint])))
+ self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
def visit_column(self, column):
pass
self.append("INDEX %s ON %s (%s)" \
% (preparer.format_index(index),
preparer.format_table(index.table),
- string.join([preparer.format_column(c) for c in index.columns], ', ')))
+ string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
self.execute()
class SchemaDropper(DDLBase):
or not self.legal_characters.match(unicode(value))
or (lc_value != value))
- def __generic_obj_format(self, obj, ident):
+ def quote(self, obj, ident):
if getattr(obj, 'quote', False):
return self.quote_identifier(ident)
- try:
+ if ident in self.__strings:
return self.__strings[ident]
- except KeyError:
+ else:
if self._requires_quotes(ident):
self.__strings[ident] = self.quote_identifier(ident)
else:
return object.quote or self._requires_quotes(object.name)
def format_sequence(self, sequence, use_schema=True):
- name = self.__generic_obj_format(sequence, sequence.name)
+ name = self.quote(sequence, sequence.name)
if use_schema and sequence.schema is not None:
- name = self.__generic_obj_format(sequence, sequence.schema) + "." + name
+ name = self.quote(sequence, sequence.schema) + "." + name
return name
def format_label(self, label, name=None):
- return self.__generic_obj_format(label, name or label.name)
+ return self.quote(label, name or label.name)
def format_alias(self, alias, name=None):
- return self.__generic_obj_format(alias, name or alias.name)
+ return self.quote(alias, name or alias.name)
def format_savepoint(self, savepoint, name=None):
- return self.__generic_obj_format(savepoint, name or savepoint.ident)
+ return self.quote(savepoint, name or savepoint.ident)
def format_constraint(self, constraint):
- return self.__generic_obj_format(constraint, constraint.name)
+ return self.quote(constraint, constraint.name)
def format_index(self, index):
- return self.__generic_obj_format(index, index.name)
+ return self.quote(index, index.name)
def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
if name is None:
name = table.name
- result = self.__generic_obj_format(table, name)
+ result = self.quote(table, name)
if use_schema and getattr(table, "schema", None):
- result = self.__generic_obj_format(table, table.schema) + "." + result
+ result = self.quote(table, table.schema) + "." + result
return result
def format_column(self, column, use_table=False, name=None, table_name=None):
- """Prepare a quoted column name."""
+ """Prepare a quoted column name.
+
+ deprecated. use preparer.quote(col, column.name) or combine with format_table()
+ """
+
if name is None:
name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
+ return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
else:
- return self.__generic_obj_format(column, name)
+ return self.quote(column, name)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
else:
return name
- def format_column_with_table(self, column, column_name=None, table_name=None):
- """Prepare a quoted column name with table name."""
-
- return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
-
-
def format_table_seq(self, table, use_schema=True):
"""Format table name and schema as a tuple."""
raise NotImplementedError(repr(self))
+ def _aggregate_hide_froms(self, **modifiers):
+ """Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces, taking into account
+ previous ClauseElements which this ClauseElement is a clone of."""
+
+ s = self
+ while s is not None:
+ for h in s._hide_froms(**modifiers):
+ yield h
+ s = getattr(s, '_is_clone_of', None)
+
def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this ``ClauseElement`` replaces."""
else:
equivs[x] = util.Set([y])
- class BinaryVisitor(visitors.ClauseVisitor):
- def visit_binary(self, binary):
- if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
- add_equiv(binary.left, binary.right)
- BinaryVisitor().traverse(self.onclause)
+ def visit_binary(binary):
+ if binary.operator == operators.eq and isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+ add_equiv(binary.left, binary.right)
+ visitors.traverse(self.onclause, visit_binary=visit_binary)
for col in pkcol:
for fk in col.foreign_keys:
self._offset = offset
self._bind = bind
- self.append_order_by(*util.to_list(order_by, []))
- self.append_group_by(*util.to_list(group_by, []))
+ self._order_by_clause = ClauseList(*util.to_list(order_by, []))
+ self._group_by_clause = ClauseList(*util.to_list(group_by, []))
def as_scalar(self):
"""return a 'scalar' representation of this selectable, which can be used
# usually called via a generative method, create a copy of each collection
# by default
- self._raw_columns = []
self.__correlate = util.Set()
- self._froms = util.OrderedSet()
- self._whereclause = None
self._having = None
self._prefixes = []
- if columns is not None:
- for c in columns:
- self.append_column(c, _copy_collection=False)
-
- if from_obj is not None:
- for f in from_obj:
- self.append_from(f, _copy_collection=False)
+ if columns:
+ self._raw_columns = [
+ isinstance(c, _ScalarSelect) and c.self_group(against=operators.comma_op) or c
+ for c in
+ [_literal_as_column(c) for c in columns]
+ ]
+ else:
+ self._raw_columns = []
+
+ if from_obj:
+ self._froms = util.Set([
+ _is_literal(f) and _TextFromClause(f) or f
+ for f in from_obj
+ ])
+ else:
+ self._froms = util.Set()
- if whereclause is not None:
- self.append_whereclause(whereclause)
+ if whereclause:
+ self._whereclause = _literal_as_text(whereclause)
+ else:
+ self._whereclause = None
- if having is not None:
- self.append_having(having)
+ if having:
+ self._having = _literal_as_text(having)
+ else:
+ self._having = None
- if prefixes is not None:
- for p in prefixes:
- self.append_prefix(p, _copy_collection=False)
+ if prefixes:
+ self._prefixes = [_literal_as_text(p) for p in prefixes]
+ else:
+ self._prefixes = []
_SelectBaseMixin.__init__(self, **kwargs)
correlating.
"""
- froms = util.OrderedSet()
+ froms = util.Set()
hide_froms = util.Set()
for col in self._raw_columns:
- for f in col._hide_froms():
- hide_froms.add(f)
- while hasattr(f, '_is_clone_of'):
- hide_froms.add(f._is_clone_of)
- f = f._is_clone_of
- for f in col._get_from_objects():
- froms.add(f)
+ hide_froms.update(col._aggregate_hide_froms())
+ froms.update(col._get_from_objects())
if self._whereclause is not None:
- for f in self._whereclause._get_from_objects(is_where=True):
- froms.add(f)
+ froms.update(self._whereclause._get_from_objects(is_where=True))
- for elem in self._froms:
- froms.add(elem)
- for f in elem._get_from_objects():
- froms.add(f)
-
- for elem in froms:
- for f in elem._hide_froms():
- hide_froms.add(f)
- while hasattr(f, '_is_clone_of'):
- hide_froms.add(f._is_clone_of)
- f = f._is_clone_of
+ if self._froms:
+ froms.update(self._froms)
+ for elem in self._froms:
+ hide_froms.update(elem._aggregate_hide_froms())
froms = froms.difference(hide_froms)
if len(froms) > 1:
corr = self.__correlate
if self._should_correlate and existing_froms is not None:
- corr = existing_froms.union(corr)
-
- for f in list(corr):
- while hasattr(f, '_is_clone_of'):
- corr.add(f._is_clone_of)
- f = f._is_clone_of
+ corr.update(existing_froms)
f = froms.difference(corr)
- if len(f) == 0:
+ if not f:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
return f
else:
"""Utility functions that build upon SQL and Schema constructs."""
+# TODO: replace with plain list. break out sorting funcs into module-level funcs
class TableCollection(object):
def __init__(self, tables=None):
self.tables = tables or []
return sequence
+# TODO: replace with plain module-level func
class TableFinder(TableCollection, visitors.NoColumnVisitor):
"""locate all Tables within a clause."""
"""
__traverse_options__ = {'column_collections':False}
+
+def traverse(clause, **kwargs):
+ clone = kwargs.pop('clone', False)
+ class Vis(ClauseVisitor):
+ __traverse_options__ = kwargs.pop('traverse_options', {})
+ def __getattr__(self, key):
+ if key in kwargs:
+ return kwargs[key]
+ else:
+ return None
+ return Vis().traverse(clause, clone=clone)
+
t1.update().compile()
# TODO: this is alittle high
- @profiling.profiled('ctest_select', call_range=(170, 200), always=True)
+ @profiling.profiled('ctest_select', call_range=(130, 150), always=True)
def test_select(self):
s = select([t1], t1.c.c2==t2.c.c1)
s.compile()
legs.sort()
@testing.supported('postgres')
- @profiling.profiled('editing', call_range=(1200, 1290), always=True)
+ @profiling.profiled('editing', call_range=(1150, 1280), always=True)
def test_6_editing(self):
Zoo = metadata.tables['Zoo']