auto-construction of joins which cross the same paths but
are querying divergent criteria. ClauseElements at the front
of filter_by() are removed (use filter()).
+ - improved support for custom column_property() attributes which
+ feature correlated subqueries...work better with eager loading now.
- along with recent speedups to ResultProxy, total number of
function calls significantly reduced for large loads.
test/perf/masseagerload.py reports 0.4 as having the fewest number
- added undefer_group() MapperOption, sets a set of "deferred"
columns joined by a "group" to load as "undeferred".
- sql
+ - significant architectural overhaul to SQL elements (ClauseElement).
+ all elements share a common "mutability" framework which allows a
+ consistent approach to in-place modifications of elements as well as
+ generative behavior. improves stability of the ORM which makes
+ heavy usage of mutations to SQL expressions.
+ - select() and union()'s now have "generative" behavior. methods like
+ order_by() and group_by() return a *new* instance - the original instance
+ is left unchanged. non-generative methods remain as well.
+ - the internals of select/union vastly simplified - all decision making
+ regarding "is subquery" and "correlation" pushed to SQL generation phase.
+ select() elements are now *never* mutated by their enclosing containers
+ or by any dialect's compilation process [ticket:52] [ticket:569]
- result sets from CRUD operations close their underlying cursor immediately.
will also autoclose the connection if defined for the operation; this
allows more efficient usage of connections for successive CRUD operations
"""
return ANSIIdentifierPreparer(self)
-class ANSICompiler(sql.Compiled):
+class ANSICompiler(engine.Compiled):
"""Default implementation of Compiled.
Compiles ClauseElements into ANSI-compliant SQL strings.
"""
- __traverse_options__ = {'column_collections':False}
+ __traverse_options__ = {'column_collections':False, 'entry':True}
def __init__(self, dialect, statement, parameters=None, **kwargs):
"""Construct a new ``ANSICompiler`` object.
correspond to the keys present in the parameters.
"""
- sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
+ super(ANSICompiler, self).__init__(dialect, statement, parameters, **kwargs)
# if we are insert/update. set to true when we visit an INSERT or UPDATE
self.isinsert = self.isupdate = False
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
-
+
+ # a dictionary containing attributes about all select()
+ # elements located within the clause, regarding which are subqueries, which are
+ # selected from, and which elements should be correlated to an enclosing select.
+ # used mostly to determine the list of FROM elements for each select statement, as well
+ # as some dialect-specific rules regarding subqueries.
+ self.correlate_state = {}
+
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
def get_str(self, obj):
return self.strings[obj]
-
+
+ def is_subquery(self, select):
+ return self.correlate_state[select].get('is_subquery', False)
+
def get_whereclause(self, obj):
return self.wheres.get(obj, None)
def visit_compound_select(self, cs):
text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
- group_by = self.get_str(cs.group_by_clause)
+ group_by = self.get_str(cs._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
text += self.order_by_clause(cs)
self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
self.strings[alias] = self.get_str(alias.original)
+ def enter_select(self, select):
+ select.calculate_correlations(self.correlate_state)
+ self.select_stack.append(select)
+
+ def enter_update(self, update):
+ update.calculate_correlations(self.correlate_state)
+
+ def enter_delete(self, delete):
+ delete.calculate_correlations(self.correlate_state)
+
+ def label_select_column(self, select, column):
+ """convert a column from a select's "columns" clause.
+
+ given a select() and a column element from its inner_columns collection, return a
+ Label object if this column should be labeled in the columns clause. Otherwise,
+ return None and the column will be used as-is.
+
+ The calling method will traverse the returned label to acquire its string
+ representation.
+ """
+
+ # SQLite doesnt like selecting from a subquery where the column
+ # names look like table.colname. so if column is in a "selected from"
+ # subquery, label it synoymously with its column name
+ if \
+ self.correlate_state[select].get('is_selected_from', False) and \
+ isinstance(column, sql._ColumnClause) and \
+ not column.is_literal and \
+ column.table is not None and \
+ not isinstance(column.table, sql.Select):
+ return column.label(column.name)
+ else:
+ return None
+
def visit_select(self, select):
# the actual list of columns to print in the SELECT column list.
inner_columns = util.OrderedDict()
-
- self.select_stack.append(select)
- for c in select._raw_columns:
- if hasattr(c, '_selectable'):
- s = c._selectable()
+
+ froms = select.get_display_froms(self.correlate_state)
+ for f in froms:
+ if f not in self.strings:
+ self.traverse(f)
+
+ for co in select.inner_columns:
+ if select.use_labels:
+ labelname = co._label
+ if labelname is not None:
+ l = co.label(labelname)
+ self.traverse(l)
+ inner_columns[labelname] = l
+ else:
+ self.traverse(co)
+ inner_columns[self.get_str(co)] = co
else:
- self.traverse(c)
- inner_columns[self.get_str(c)] = c
- continue
- for co in s.columns:
- if select.use_labels:
- labelname = co._label
- if labelname is not None:
- l = co.label(labelname)
- self.traverse(l)
- inner_columns[labelname] = l
- else:
- self.traverse(co)
- inner_columns[self.get_str(co)] = co
- # TODO: figure this out, a ColumnClause with a select as a parent
- # is different from any other kind of parent
- elif select.is_selected_from and isinstance(co, sql._ColumnClause) and not co.is_literal and co.table is not None and not isinstance(co.table, sql.Select):
- # SQLite doesnt like selecting from a subquery where the column
- # names look like table.colname, so add a label synonomous with
- # the column name
- l = co.label(co.name)
+ l = self.label_select_column(select, co)
+ if l is not None:
self.traverse(l)
inner_columns[self.get_str(l.obj)] = l
else:
self.traverse(co)
inner_columns[self.get_str(co)] = co
+
self.select_stack.pop(-1)
collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
text += self.visit_select_precolumns(select)
text += collist
- whereclause = select.whereclause
-
- froms = []
- for f in select.froms:
-
- if self.parameters is not None:
- # TODO: whack this feature in 0.4
- # look at our own parameters, see if they
- # are all present in the form of BindParamClauses. if
- # not, then append to the above whereclause column conditions
- # matching those keys
- for c in f.columns:
- if sql.is_column(c) and self.parameters.has_key(c.key) and not self.binds.has_key(c.key):
- value = self.parameters[c.key]
- else:
- continue
- clause = c==value
- if whereclause is not None:
- whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
- else:
- whereclause = clause
- self.traverse(whereclause)
+ whereclause = select._whereclause
+ from_strings = []
+ for f in froms:
# special thingy used by oracle to redefine a join
w = self.get_whereclause(f)
if w is not None:
t = self.get_from_text(f)
if t is not None:
- froms.append(t)
+ from_strings.append(t)
if len(froms):
text += " \nFROM "
- text += string.join(froms, ', ')
+ text += string.join(from_strings, ', ')
else:
text += self.default_from()
if t:
text += " \nWHERE " + t
- group_by = self.get_str(select.group_by_clause)
+ group_by = self.get_str(select._group_by_clause)
if group_by:
text += " GROUP BY " + group_by
- if select.having is not None:
- t = self.get_str(select.having)
+ if select._having is not None:
+ t = self.get_str(select._having)
if t:
text += " \nHAVING " + t
def visit_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just before column list."""
- return select.distinct and "DISTINCT " or ""
+ return select._distinct and "DISTINCT " or ""
def visit_select_postclauses(self, select):
"""Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
"""
- return (select.limit or select.offset) and self.limit_clause(select) or ""
+ return (select._limit or select._offset) and self.limit_clause(select) or ""
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.get_str(select._order_by_clause)
if order_by:
return " ORDER BY " + order_by
else:
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
def visit_table(self, table):
text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
- if update_stmt.whereclause:
- text += " WHERE " + self.get_str(update_stmt.whereclause)
+ if update_stmt._whereclause:
+ text += " WHERE " + self.get_str(update_stmt._whereclause)
self.strings[update_stmt] = text
if sql._is_literal(value):
value = sql.bindparam(c.key, value, type=c.type, unique=True)
values.append((c, value))
+
return values
def visit_delete(self, delete_stmt):
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
- if delete_stmt.whereclause:
- text += " WHERE " + self.get_str(delete_stmt.whereclause)
+ if delete_stmt._whereclause:
+ text += " WHERE " + self.get_str(delete_stmt._whereclause)
self.strings[delete_stmt] = text
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
if self.dialect.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
+ self.traverse_single(column.default)
#if column.onupdate is not None:
# column.onupdate.accept_visitor(visitor)
if column.primary_key:
first_pk = True
for constraint in column.constraints:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
if len(table.primary_key):
- table.primary_key.accept_visitor(self)
+ self.traverse_single(table.primary_key)
for constraint in [c for c in table.constraints if c is not table.primary_key]:
- constraint.accept_visitor(self)
+ self.traverse_single(constraint)
self.append("\n)%s\n\n" % self.post_create_table(table))
self.execute()
if hasattr(table, 'indexes'):
for index in table.indexes:
- index.accept_visitor(self)
+ self.traverse_single(index)
def post_create_table(self, table):
return ''
for alterable in self.find_alterables(collection):
self.drop_foreignkey(alterable)
for table in collection:
- table.accept_visitor(self)
+ self.traverse_single(table)
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name)
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
- column.default.accept_visitor(self)
+ self.traverse_single(column.default)
self.append("\nDROP TABLE " + self.preparer.format_table(table))
self.execute()
"""
result = ""
- if select.limit:
- result += " FIRST %d " % select.limit
- if select.offset:
- result +=" SKIP %d " % select.offset
- if select.distinct:
+ if select._limit:
+ result += " FIRST %d " % select._limit
+ if select._offset:
+ result +=" SKIP %d " % select._offset
+ if select._distinct:
result += " DISTINCT "
return result
return " from systables where tabname = 'systables' "
def visit_select_precolumns( self , select ):
- s = select.distinct and "DISTINCT " or ""
+ s = select._distinct and "DISTINCT " or ""
# only has limit
- if select.limit:
- off = select.offset or 0
- s += " FIRST %s " % ( select.limit + off )
+ if select._limit:
+ off = select._offset or 0
+ s += " FIRST %s " % ( select._limit + off )
else:
s += ""
return s
def visit_select(self, select):
- if select.offset:
- self.offset = select.offset
- self.limit = select.limit or 0
+ if select._offset:
+ self.offset = select._offset
+ self.limit = select._limit or 0
# the column in order by clause must in select too
def __label( c ):
* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
-* ``select.limit`` implemented as ``SELECT TOP n``
+* ``select._limit`` implemented as ``SELECT TOP n``
Known issues / TODO:
def visit_select_precolumns(self, select):
""" MS-SQL puts TOP, it's version of LIMIT here """
- s = select.distinct and "DISTINCT " or ""
- if select.limit:
- s += "TOP %s " % (select.limit,)
- if select.offset:
+ s = select._distinct and "DISTINCT " or ""
+ if select._limit:
+ s += "TOP %s " % (select._limit,)
+ if select._offset:
raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
return s
binary.left, binary.right = binary.right, binary.left
super(MSSQLCompiler, self).visit_binary(binary)
- def visit_select(self, select):
- # label function calls, so they return a name in cursor.description
- for i,c in enumerate(select._raw_columns):
- if isinstance(c, sql._Function):
- select._raw_columns[i] = c.label(c.name + "_" + hex(random.randint(0, 65535))[2:])
-
- super(MSSQLCompiler, self).visit_select(select)
+ def label_select_column(self, select, column):
+ if isinstance(column, sql._Function):
+ return co.label(co.name + "_" + hex(random.randint(0, 65535))[2:])
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column)
function_rewrites = {'current_date': 'getdate',
'length': 'len',
return ''
def order_by_clause(self, select):
- order_by = self.get_str(select.order_by_clause)
+ order_by = self.get_str(select._order_by_clause)
# MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not select.is_subquery or select.limit):
+ if order_by and (not self.is_subquery(select) or select._limit):
return " ORDER BY " + order_by
else:
return ""
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
- # striaght from the MySQL docs, I kid you not
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
+ # straight from the MySQL docs, I kid you not
text += " \n LIMIT 18446744073709551615"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
return text
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
the use_ansi flag is False.
"""
- def __init__(self, *args, **kwargs):
- super(OracleCompiler, self).__init__(*args, **kwargs)
- # we have to modify SELECT objects a little bit, so store state here
- self._select_state = {}
-
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
self._outertable = None
- self.wheres[join].accept_visitor(self)
+ self.traverse_single(self.wheres[join])
def visit_insert_sequence(self, column, sequence, parameters):
"""This is the `sequence` equivalent to ``ANSICompiler``'s
def _TODO_visit_compound_select(self, select):
"""Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
-
- if getattr(select, '_oracle_visit', False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_compound_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
- if not orderby:
- orderby = select.oid_column
- self.traverse(orderby)
- orderby = self.strings[orderby]
- class SelectVisitor(sql.NoColumnVisitor):
- def visit_select(self, select):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- SelectVisitor().traverse(select)
- limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
- else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
- self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
- else:
- ansisql.ANSICompiler.visit_compound_select(self, select)
+ pass
def visit_select(self, select):
"""Look for ``LIMIT`` and OFFSET in a select statement, and if
so tries to wrap it in a subquery with ``row_number()`` criterion.
"""
- # TODO: put a real copy-container on Select and copy, or somehow make this
- # not modify the Select statement
- if self._select_state.get((select, 'visit'), False):
- # cancel out the compiled order_by on the select
- if hasattr(select, "order_by_clause"):
- self.strings[select.order_by_clause] = ""
- ansisql.ANSICompiler.visit_select(self, select)
- return
-
- if select.limit is not None or select.offset is not None:
- self._select_state[(select, 'visit')] = True
+ if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
# to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.strings[select.order_by_clause]
+ orderby = self.strings[select._order_by_clause]
if not orderby:
orderby = select.oid_column
self.traverse(orderby)
orderby = self.strings[orderby]
- if not hasattr(select, '_oracle_visit'):
- select.append_column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
- select._oracle_visit = True
+
+ oldselect = select
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
+ select._oracle_visit = True
+
limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
- if select.offset is not None:
- limitselect.append_whereclause("ora_rn>%d" % select.offset)
- if select.limit is not None:
- limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
+ if select._offset is not None:
+ limitselect.append_whereclause("ora_rn>%d" % select._offset)
+ if select._limit is not None:
+ limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
else:
- limitselect.append_whereclause("ora_rn<=%d" % select.limit)
+ limitselect.append_whereclause("ora_rn<=%d" % select._limit)
self.traverse(limitselect)
- self.strings[select] = self.strings[limitselect]
- self.froms[select] = self.froms[limitselect]
+ self.strings[oldselect] = self.strings[limitselect]
+ self.froms[oldselect] = self.froms[limitselect]
else:
ansisql.ANSICompiler.visit_select(self, select)
return text
def visit_select_precolumns(self, select):
- if select.distinct:
- if type(select.distinct) == bool:
+ if select._distinct:
+ if type(select._distinct) == bool:
return "DISTINCT "
- if type(select.distinct) == list:
+ if type(select._distinct) == list:
dist_set = "DISTINCT ON ("
- for col in select.distinct:
+ for col in select._distinct:
dist_set += self.strings[col] + ", "
dist_set = dist_set[:-2] + ") "
return dist_set
- return "DISTINCT ON (" + str(select.distinct) + ") "
+ return "DISTINCT ON (" + str(select._distinct) + ") "
else:
return ""
def limit_clause(self, select):
text = ""
- if select.limit is not None:
- text += " \n LIMIT " + str(select.limit)
- if select.offset is not None:
- if select.limit is None:
+ if select._limit is not None:
+ text += " \n LIMIT " + str(select._limit)
+ if select._offset is not None:
+ if select._limit is None:
text += " \n LIMIT -1"
- text += " OFFSET " + str(select.offset)
+ text += " OFFSET " + str(select._offset)
else:
text += " OFFSET 0"
return text
raise NotImplementedError()
+class Compiled(sql.ClauseVisitor):
+ """Represent a compiled SQL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ def __init__(self, dialect, statement, parameters, engine=None):
+ """Construct a new ``Compiled`` object.
+
+ statement
+ ``ClauseElement`` to be compiled.
+
+ parameters
+ Optional dictionary indicating a set of bind parameters
+ specified with this ``Compiled`` object. These parameters
+ are the *default* values corresponding to the
+ ``ClauseElement``'s ``_BindParamClauses`` when the
+ ``Compiled`` is executed. In the case of an ``INSERT`` or
+ ``UPDATE`` statement, these parameters will also result in
+ the creation of new ``_BindParamClause`` objects for each
+ key and will also affect the generated column list in an
+ ``INSERT`` statement and the ``SET`` clauses of an
+ ``UPDATE`` statement. The keys of the parameter dictionary
+ can either be the string names of columns or
+ ``_ColumnClause`` objects.
+
+ engine
+ Optional Engine to compile this statement against.
+ """
+ self.dialect = dialect
+ self.statement = statement
+ self.parameters = parameters
+ self.engine = engine
+ self.can_execute = statement.supports_execution()
+
+ def compile(self):
+ self.traverse(self.statement)
+ self.after_compile()
+
+ def __str__(self):
+ """Return the string text of the generated SQL statement."""
+
+ raise NotImplementedError()
+
+ def get_params(self, **params):
+ """Deprecated. use construct_params(). (supports unicode names)
+ """
+
+ return self.construct_params(params)
+
+ def construct_params(self, params):
+ """Return the bind params for this compiled object.
+
+ Will start with the default parameters specified when this
+ ``Compiled`` object was first constructed, and will override
+ those values with those sent via `**params`, which are
+ key/value pairs. Each key should match one of the
+ ``_BindParamClause`` objects compiled into this object; either
+ the `key` or `shortname` property of the ``_BindParamClause``.
+ """
+ raise NotImplementedError()
+
+ def execute(self, *multiparams, **params):
+ """Execute this compiled object."""
+
+ e = self.engine
+ if e is None:
+ raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
+ return e.execute_compiled(self, *multiparams, **params)
+
+ def scalar(self, *multiparams, **params):
+ """Execute this compiled object and return the result's scalar value."""
+
+ return self.execute(*multiparams, **params).scalar()
+
-class Connectable(sql.Executor):
+class Connectable(object):
"""Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
def contextual_connect(self):
raise exceptions.InvalidRequestError("Unexecuteable object type: " + str(type(object)))
def execute_default(self, default, **kwargs):
- return default.accept_visitor(self.__engine.dialect.defaultrunner(self))
+ return self.__engine.dialect.defaultrunner(self).traverse_single(default)
def execute_text(self, statement, *multiparams, **params):
if len(multiparams) == 0:
else:
conn = connection
try:
- element.accept_visitor(visitorcallable(conn, **kwargs))
+ visitorcallable(conn, **kwargs).traverse(element)
finally:
if connection is None:
conn.close()
def get_column_default(self, column):
if column.default is not None:
- return column.default.accept_visitor(self)
+ return self.traverse_single(column.default)
else:
return None
def get_column_onupdate(self, column):
if column.onupdate is not None:
- return column.onupdate.accept_visitor(self)
+ return self.traverse_single(column.onupdate)
else:
return None
if isinstance(selectable, sql.Alias):
return _selectable_name(selectable.selectable)
elif isinstance(selectable, sql.Select):
- return ''.join([_selectable_name(s) for s in selectable.froms])
+ return ''.join([_selectable_name(s) for s in selectable.get_display_froms()])
elif isinstance(selectable, schema.Table):
return selectable.name.capitalize()
else:
return obj
def _deferred_inheritance_condition(self, needs_tables):
- cond = self.inherit_condition.copy_container()
+ cond = self.inherit_condition
param_names = []
def visit_binary(binary):
elif rightcol not in needs_tables:
binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
param_names.append(rightcol)
- mapperutil.BinaryVisitor(visit_binary).traverse(cond)
+ cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
return cond, param_names
def translate_row(self, tomapper, row):
# if the target mapper loads polymorphically, adapt the clauses to the target's selectable
if self.loads_polymorphic:
if self.secondaryjoin:
- self.polymorphic_secondaryjoin = self.secondaryjoin.copy_container()
- sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.polymorphic_secondaryjoin)
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
+ self.polymorphic_secondaryjoin = sql_util.ClauseAdapter(self.mapper.select_table).traverse(self.secondaryjoin, clone=True)
+ self.polymorphic_primaryjoin = self.primaryjoin
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, include=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.polymorphic_primaryjoin)
+ self.polymorphic_primaryjoin = sql_util.ClauseAdapter(self.mapper.select_table, exclude=self.foreign_keys, equivalents=target_equivalents).traverse(self.primaryjoin, clone=True)
self.polymorphic_secondaryjoin = None
# load "polymorphic" versions of the columns present in "remote_side" - this is
# important for lazy-clause generation which goes off the polymorphic target selectable
else:
raise exceptions.AssertionError(str(self) + ": Could not find corresponding column for " + str(c) + " in selectable " + str(self.mapper.select_table))
else:
- self.polymorphic_primaryjoin = self.primaryjoin.copy_container()
- self.polymorphic_secondaryjoin = self.secondaryjoin and self.secondaryjoin.copy_container() or None
+ self.polymorphic_primaryjoin = self.primaryjoin
+ self.polymorphic_secondaryjoin = self.secondaryjoin
def _post_init(self):
if logging.is_info_enabled(self.logger):
return self._parent_join_cache[(parent, primary, secondary)]
except KeyError:
parent_equivalents = parent._get_equivalent_columns()
- primaryjoin = self.polymorphic_primaryjoin.copy_container()
- if self.secondaryjoin is not None:
- secondaryjoin = self.polymorphic_secondaryjoin.copy_container()
- else:
- secondaryjoin = None
+ secondaryjoin = self.polymorphic_secondaryjoin
if self.direction is sync.ONETOMANY:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
elif self.direction is sync.MANYTOONE:
- sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
elif self.secondaryjoin:
- sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
+ primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True)
if secondaryjoin is not None:
if secondary and not primary:
else:
if prop.secondary:
if create_aliases:
- join = prop.get_join(mapper, primary=True, secondary=False).copy_container()
+ join = prop.get_join(mapper, primary=True, secondary=False)
secondary_alias = prop.secondary.alias()
if alias is not None:
- sql_util.ClauseAdapter(alias).traverse(join)
+ join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
sql_util.ClauseAdapter(secondary_alias).traverse(join)
clause = clause.join(secondary_alias, join)
alias = prop.select_table.alias()
- join = prop.get_join(mapper, primary=False).copy_container()
- sql_util.ClauseAdapter(secondary_alias).traverse(join)
+ join = prop.get_join(mapper, primary=False)
+ join = sql_util.ClauseAdapter(secondary_alias).traverse(join, clone=True)
sql_util.ClauseAdapter(alias).traverse(join)
clause = clause.join(alias, join)
else:
clause = clause.join(prop.select_table, prop.get_join(mapper, primary=False))
else:
if create_aliases:
- join = prop.get_join(mapper).copy_container()
+ join = prop.get_join(mapper)
if alias is not None:
- sql_util.ClauseAdapter(alias).traverse(join)
+ join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
alias = prop.select_table.alias()
- sql_util.ClauseAdapter(alias).traverse(join)
+ join = sql_util.ClauseAdapter(alias).traverse(join, clone=True)
clause = clause.join(alias, join)
else:
clause = clause.join(prop.select_table, prop.get_join(mapper))
For performance, only use subselect if `order_by` attribute is set.
"""
- ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj}
+ ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj}
if self._order_by is not False:
s1 = sql.select([col], self._criterion, **ops).alias('u')
# from there
context = QueryContext(self)
order_by = context.order_by
- group_by = context.group_by
from_obj = context.from_obj
lockmode = context.lockmode
- distinct = context.distinct
- limit = context.limit
- offset = context.offset
if order_by is False:
order_by = self.mapper.order_by
if order_by is False:
else:
cf = []
- s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args())
+ s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **context.select_args())
if order_by:
- s2.order_by(*util.to_list(order_by))
+ s2 = s2.order_by(*util.to_list(order_by))
s3 = s2.alias('tbl_row_count')
- crit = s3.primary_key==self.table.primary_key
+ crit = s3.primary_key==self.primary_key_columns
statement = sql.select([], crit, use_labels=True, for_update=for_update)
# now for the order by, convert the columns to their corresponding columns
# in the "rowcount" query, and tack that new order by onto the "rowcount" query
if order_by:
- statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
+ statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
else:
statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
if order_by:
- statement.order_by(*util.to_list(order_by))
+ statement.append_order_by(*util.to_list(order_by))
# for a DISTINCT query, you need the columns explicitly specified in order
# to use it in "order_by". ensure they are in the column criterion (particularly oid).
``QueryContext`` that can be applied to a ``sql.Select``
statement.
"""
- return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by}
+ return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by or None}
def accept_option(self, opt):
"""Accept a ``MapperOption`` which will process (modify) the
# based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper",
# probably via the query's own "mapper" property, and also use one of two "lazy" clauses,
# one against the "union" the other not
- for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]:
+ for primary_key in self.select_mapper.primary_key:
bind = self.lazyreverse[primary_key]
ident.append(params[bind.key])
return q.get(ident)
q = q.options(*options)
q = q.filter(self.lazywhere).params(**params)
+ result = q.all()
+ if self.uselist:
+ return result
+ else:
+ if len(result):
+ return result[0]
+ else:
+ return None
+
if self.uselist:
return q.all()
else:
sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True))
reverse[leftcol] = binds[col]
- lazywhere = primaryjoin.copy_container()
+ lazywhere = primaryjoin
li = mapperutil.BinaryVisitor(visit_binary)
if not secondaryjoin or not reverse_direction:
- li.traverse(lazywhere)
+ lazywhere = li.traverse(lazywhere, clone=True)
if secondaryjoin is not None:
- secondaryjoin = secondaryjoin.copy_container()
if reverse_direction:
- li.traverse(secondaryjoin)
+ secondaryjoin = li.traverse(secondaryjoin, clone=True)
lazywhere = sql.and_(lazywhere, secondaryjoin)
return (lazywhere, binds, reverse)
_create_lazy_clause = classmethod(_create_lazy_clause)
else:
aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
chain(sql_util.ClauseAdapter(self.eagersecondary))
- self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin.copy_container()
- aliasizer.traverse(self.eagersecondaryjoin)
- self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
- aliasizer.traverse(self.eagerprimary)
+ self.eagersecondaryjoin = eagerloader.polymorphic_secondaryjoin
+ self.eagersecondaryjoin = aliasizer.traverse(self.eagersecondaryjoin, clone=True)
+ self.eagerprimary = eagerloader.polymorphic_primaryjoin
+ self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True)
else:
- self.eagerprimary = eagerloader.polymorphic_primaryjoin.copy_container()
+ self.eagerprimary = eagerloader.polymorphic_primaryjoin
if parentclauses is not None:
aliasizer = sql_util.ClauseAdapter(self.eagertarget)
aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side))
else:
aliasizer = sql_util.ClauseAdapter(self.eagertarget)
- aliasizer.traverse(self.eagerprimary)
+ self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True)
if eagerloader.order_by:
self.eager_order_by = sql_util.ClauseAdapter(self.eagertarget).copy_and_process(util.to_list(eagerloader.order_by))
if column in self.extra_cols:
return self.extra_cols[column]
- aliased_column = column.copy_container()
- sql_util.ClauseAdapter(self.eagertarget).traverse(aliased_column)
+ aliased_column = column
+ # for column-level subqueries, swap out its selectable with our
+ # eager version as appropriate, and manually build the
+ # "correlation" list of the subquery.
+ class ModifySubquery(sql.ClauseVisitor):
+ def visit_select(s, select):
+ select._should_correlate = False
+ select.append_correlation(self.eagertarget)
+ aliased_column = sql_util.ClauseAdapter(self.eagertarget).chain(ModifySubquery()).traverse(aliased_column, clone=True)
alias = self._aliashash(column.name)
aliased_column = aliased_column.label(alias)
self._row_decorator.map[column] = alias
# this will locate the selectable inside of any containers it may be a part of (such
# as a join). if its inside of a join, we want to outer join on that join, not the
# selectable.
- for fromclause in statement.froms:
+ for fromclause in statement.get_display_froms():
if fromclause is localparent.mapped_table:
towrap = fromclause
break
break
else:
raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table))
-
+
try:
clauses = self.clauses[parentclauses]
except KeyError:
if self.secondaryjoin is not None:
statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin)
if self.order_by is False and self.secondary.default_order_by() is not None:
- statement.order_by(*clauses.eagersecondary.default_order_by())
+ statement.append_order_by(*clauses.eagersecondary.default_order_by())
else:
statement._outerjoin = towrap.outerjoin(clauses.eagertarget, clauses.eagerprimary)
if self.order_by is False and clauses.eagertarget.default_order_by() is not None:
- statement.order_by(*clauses.eagertarget.default_order_by())
+ statement.append_order_by(*clauses.eagertarget.default_order_by())
if clauses.eager_order_by:
- statement.order_by(*util.to_list(clauses.eager_order_by))
-
+ statement.append_order_by(*util.to_list(clauses.eager_order_by))
+
statement.append_from(statement._outerjoin)
+
for value in self.select_mapper.props.values():
value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
class SchemaItem(object):
"""Base class for items that define a database schema."""
+ __metaclass__ = sql._FigureVisitName
+
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
else:
return schema + "." + name
-class _TableSingleton(type):
+class _TableSingleton(sql._FigureVisitName):
"""A metaclass used by the ``Table`` object to provide singleton behavior."""
def __call__(self, name, metadata, *args, **kwargs):
column = property(lambda s: s._init_column())
- def accept_visitor(self, visitor):
- """Call the `visit_foreign_key` method on the given visitor."""
-
- visitor.visit_foreign_key(self)
-
def _get_parent(self):
return self.parent
super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
- def accept_visitor(self, visitor):
- return visitor.visit_passive_default(self)
-
def __repr__(self):
return "PassiveDefault(%s)" % repr(self.arg)
super(ColumnDefault, self).__init__(**kwargs)
self.arg = arg
- def accept_visitor(self, visitor):
- """Call the visit_column_default method on the given visitor."""
-
+ def _visit_name(self):
if self.for_update:
- return visitor.visit_column_onupdate(self)
+ return "column_onupdate"
else:
- return visitor.visit_column_default(self)
+ return "column_default"
+ __visit_name__ = property(_visit_name)
def __repr__(self):
return "ColumnDefault(%s)" % repr(self.arg)
def drop(self, connectable=None, checkfirst=True):
self.get_engine(connectable=connectable).drop(self, checkfirst=checkfirst)
- def accept_visitor(self, visitor):
- """Call the visit_seauence method on the given visitor."""
-
- return visitor.visit_sequence(self)
class Constraint(SchemaItem):
"""Represent a table-level ``Constraint`` such as a composite primary key, foreign key, or unique constraint.
super(CheckConstraint, self).__init__(name)
self.sqltext = sqltext
- def accept_visitor(self, visitor):
+ def _visit_name(self):
if isinstance(self.parent, Table):
- visitor.visit_check_constraint(self)
+ return "check_constraint"
else:
- visitor.visit_column_check_constraint(self)
+ return "column_check_constraint"
+ __visit_name__ = property(_visit_name)
def _set_parent(self, parent):
self.parent = parent
for (c, r) in zip(self.__colnames, self.__refcolnames):
self.append_element(c,r)
- def accept_visitor(self, visitor):
- visitor.visit_foreign_key_constraint(self)
-
def append_element(self, col, refcol):
fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
fk._set_parent(self.table.c[col])
for c in self.__colnames:
self.append_column(table.c[c])
- def accept_visitor(self, visitor):
- visitor.visit_primary_key_constraint(self)
-
def add(self, col):
self.append_column(col)
def append_column(self, col):
self.columns.add(col)
- def accept_visitor(self, visitor):
- visitor.visit_unique_constraint(self)
-
def copy(self):
return UniqueConstraint(name=self.name, *self.__colnames)
else:
self.get_engine().drop(self)
- def accept_visitor(self, visitor):
- visitor.visit_index(self)
-
def __str__(self):
return repr(self)
class MetaData(SchemaItem):
"""Represent a collection of Tables and their associated schema constructs."""
+ __visit_name__ = 'metadata'
+
def __init__(self, url=None, engine=None, **kwargs):
"""create a new MetaData object.
connectable = self.get_engine()
connectable.drop(self, checkfirst=checkfirst, tables=tables)
- def accept_visitor(self, visitor):
- visitor.visit_metadata(self)
-
def _derived_metadata(self):
return self
"""
+ __visit_name__ = 'metadata'
+
def __init__(self, engine_or_url, **kwargs):
from sqlalchemy.engine.url import URL
if isinstance(engine_or_url, basestring) or isinstance(engine_or_url, URL):
thread-local basis.
"""
+ __visit_name__ = 'metadata'
+
def __init__(self, threadlocal=True, **kwargs):
if threadlocal:
self.context = util.ThreadLocal()
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {'schema_visitor':True}
-
- def visit_schema(self, schema):
- """Visit a generic ``SchemaItem``."""
- pass
-
- def visit_table(self, table):
- """Visit a ``Table``."""
- pass
-
- def visit_column(self, column):
- """Visit a ``Column``."""
- pass
-
- def visit_foreign_key(self, join):
- """Visit a ``ForeignKey``."""
- pass
-
- def visit_index(self, index):
- """Visit an ``Index``."""
- pass
-
- def visit_passive_default(self, default):
- """Visit a passive default."""
- pass
-
- def visit_column_default(self, default):
- """Visit a ``ColumnDefault``."""
- pass
-
- def visit_column_onupdate(self, onupdate):
- """Visit a ``ColumnDefault`` with the `for_update` flag set."""
- pass
-
- def visit_sequence(self, sequence):
- """Visit a ``Sequence``."""
- pass
-
- def visit_primary_key_constraint(self, constraint):
- """Visit a ``PrimaryKeyConstraint``."""
- pass
-
- def visit_foreign_key_constraint(self, constraint):
- """Visit a ``ForeignKeyConstraint``."""
- pass
-
- def visit_unique_constraint(self, constraint):
- """Visit a ``UniqueConstraint``."""
- pass
-
- def visit_check_constraint(self, constraint):
- """Visit a ``CheckConstraint``."""
- pass
-
- def visit_column_check_constraint(self, constraint):
- """Visit a ``CheckConstraint`` on a ``Column``."""
- pass
-
-
from sqlalchemy import types as sqltypes
import string, re, random, sets
-
__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
- 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
- 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
+ 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
+ 'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
'between_', 'bindparam', 'case', 'cast', 'column', 'delete',
'desc', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
'insert', 'intersect', 'intersect_all', 'join', 'literal',
return Join(left, right, onclause, **kwargs)
-def select(columns=None, whereclause = None, from_obj = [], **kwargs):
+def select(columns=None, whereclause=None, from_obj=[], **kwargs):
"""Returns a ``SELECT`` clause element.
Similar functionality is also available via the ``select()`` method on any
"""
- return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
+ return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
def subquery(alias, *args, **kwargs):
"""Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
return Select(*args, **kwargs).alias(alias)
def insert(table, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Insert] clause element.
+ """Return an [sqlalchemy.sql#Insert] clause element.
Similar functionality is available via the ``insert()``
method on [sqlalchemy.schema#Table].
against the ``INSERT`` statement.
"""
- return _Insert(table, values, **kwargs)
+ return Insert(table, values, **kwargs)
def update(table, whereclause = None, values = None, **kwargs):
- """Return an [sqlalchemy.sql#_Update] clause element.
+ """Return an [sqlalchemy.sql#Update] clause element.
Similar functionality is available via the ``update()``
method on [sqlalchemy.schema#Table].
against the ``UPDATE`` statement.
"""
- return _Update(table, whereclause, values, **kwargs)
+ return Update(table, whereclause, values, **kwargs)
def delete(table, whereclause = None, **kwargs):
- """Return a [sqlalchemy.sql#_Delete] clause element.
+ """Return a [sqlalchemy.sql#Delete] clause element.
Similar functionality is available via the ``delete()``
method on [sqlalchemy.schema#Table].
"""
- return _Delete(table, whereclause, **kwargs)
+ return Delete(table, whereclause, **kwargs)
def and_(*clauses):
"""Join a list of clauses together using the ``AND`` operator.
provides similar functionality.
"""
- return _BinaryExpression(ctest, and_(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type)), 'BETWEEN')
+ return _BinaryExpression(ctest, and_(_literal_as_binds(cleft, type=ctest.type), _literal_as_binds(cright, type=ctest.type)), 'BETWEEN')
def between_(*args, **kwargs):
"""synonym for [sqlalchemy.sql#between()] (deprecated)."""
def _is_literal(element):
return not isinstance(element, ClauseElement)
-def _literals_as_text(element):
+def _literal_as_text(element):
if _is_literal(element):
return _TextClause(unicode(element))
else:
return element
-def _literals_as_binds(element, name='literal', type=None):
+def _literal_as_binds(element, name='literal', type=None):
if _is_literal(element):
if element is None:
return null()
these options can indicate modifications to the set of
elements returned, such as to not return column collections
(column_collections=False) or to return Schema-level items
- (schema_visitor=True)."""
+ (schema_visitor=True).
+
+ """
__traverse_options__ = {}
- def traverse(self, obj, stop_on=None):
- stack = [obj]
- traversal = []
- while len(stack) > 0:
- t = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, t)
- for c in t.get_children(**self.__traverse_options__):
- stack.append(c)
- for target in traversal:
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
+
+ def traverse_single(self, obj):
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj)
+
+ def traverse(self, obj, stop_on=None, clone=False):
+ if clone:
+ obj = obj._clone()
+
+ # entry flag indicates to also call a before-descent "enter_XXXX" method
+ entry = self.__traverse_options__.get('entry', False)
+
+ v = self
+ visitors = []
+ while v is not None:
+ visitors.append(v)
+ v = getattr(v, '_next', None)
+
+ def _trav(obj):
+ if stop_on is not None and obj in stop_on:
+ return
+ if entry:
+ for v in visitors:
+ meth = getattr(v, "enter_%s" % obj.__visit_name__, None)
+ if meth:
+ meth(obj)
+
+ for c in obj.get_children(clone=clone, **self.__traverse_options__):
+ _trav(c)
+
+ for v in visitors:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ meth(obj)
+ _trav(obj)
return obj
def chain(self, visitor):
tail = tail._next
tail._next = visitor
return self
-
- def visit_column(self, column):
- pass
- def visit_table(self, table):
- pass
- def visit_fromclause(self, fromclause):
- pass
- def visit_bindparam(self, bindparam):
- pass
- def visit_textclause(self, textclause):
- pass
- def visit_compound(self, compound):
- pass
- def visit_compound_select(self, compound):
- pass
- def visit_binary(self, binary):
- pass
- def visit_unary(self, unary):
- pass
- def visit_alias(self, alias):
- pass
- def visit_select(self, select):
- pass
- def visit_join(self, join):
- pass
- def visit_null(self, null):
- pass
- def visit_clauselist(self, list):
- pass
- def visit_calculatedclause(self, calcclause):
- pass
- def visit_grouping(self, gr):
- pass
- def visit_function(self, func):
- pass
- def visit_cast(self, cast):
- pass
- def visit_label(self, label):
- pass
- def visit_typeclause(self, typeclause):
- pass
-
-class LoggingClauseVisitor(ClauseVisitor):
- """extends ClauseVisitor to include debug logging of all traversal.
-
- To install this visitor, set logging.DEBUG for
- 'sqlalchemy.sql.ClauseVisitor' **before** you import the
- sqlalchemy.sql module.
- """
-
- def traverse(self, obj, stop_on=None):
- stack = [(obj, "")]
- traversal = []
- while len(stack) > 0:
- (t, indent) = stack.pop()
- if stop_on is None or t not in stop_on:
- traversal.insert(0, (t, indent))
- for c in t.get_children(**self.__traverse_options__):
- stack.append((c, indent + " "))
-
- for (target, indent) in traversal:
- self.logger.debug(indent + repr(target))
- v = self
- while v is not None:
- target.accept_visitor(v)
- v = getattr(v, '_next', None)
- return obj
-
-LoggingClauseVisitor.logger = logging.class_logger(ClauseVisitor)
-
-if logging.is_debug_enabled(LoggingClauseVisitor.logger):
- ClauseVisitor=LoggingClauseVisitor
class NoColumnVisitor(ClauseVisitor):
"""a ClauseVisitor that will not traverse the exported Column
"""
__traverse_options__ = {'column_collections':False}
-
-class Executor(object):
- """Interface representing a "thing that can produce Compiled objects
- and execute them"."""
- def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
- """Execute a Compiled object."""
-
- raise NotImplementedError()
-
- def compiler(self, statement, parameters, **kwargs):
- """Return a Compiled object for the given statement and parameters."""
-
- raise NotImplementedError()
-
-class Compiled(ClauseVisitor):
- """Represent a compiled SQL expression.
-
- The ``__str__`` method of the ``Compiled`` object should produce
- the actual text of the statement. ``Compiled`` objects are
- specific to their underlying database dialect, and also may
- or may not be specific to the columns referenced within a
- particular set of bind parameters. In no case should the
- ``Compiled`` object be dependent on the actual values of those
- bind parameters, even though it may reference those values as
- defaults.
- """
-
- def __init__(self, dialect, statement, parameters, engine=None):
- """Construct a new ``Compiled`` object.
-
- statement
- ``ClauseElement`` to be compiled.
-
- parameters
- Optional dictionary indicating a set of bind parameters
- specified with this ``Compiled`` object. These parameters
- are the *default* values corresponding to the
- ``ClauseElement``'s ``_BindParamClauses`` when the
- ``Compiled`` is executed. In the case of an ``INSERT`` or
- ``UPDATE`` statement, these parameters will also result in
- the creation of new ``_BindParamClause`` objects for each
- key and will also affect the generated column list in an
- ``INSERT`` statement and the ``SET`` clauses of an
- ``UPDATE`` statement. The keys of the parameter dictionary
- can either be the string names of columns or
- ``_ColumnClause`` objects.
-
- engine
- Optional Engine to compile this statement against.
- """
- self.dialect = dialect
- self.statement = statement
- self.parameters = parameters
- self.engine = engine
- self.can_execute = statement.supports_execution()
-
- def compile(self):
- self.traverse(self.statement)
- self.after_compile()
-
- def __str__(self):
- """Return the string text of the generated SQL statement."""
-
- raise NotImplementedError()
- def get_params(self, **params):
- """Deprecated. use construct_params(). (supports unicode names)
- """
-
- return self.construct_params(params)
-
- def construct_params(self, params):
- """Return the bind params for this compiled object.
-
- Will start with the default parameters specified when this
- ``Compiled`` object was first constructed, and will override
- those values with those sent via `**params`, which are
- key/value pairs. Each key should match one of the
- ``_BindParamClause`` objects compiled into this object; either
- the `key` or `shortname` property of the ``_BindParamClause``.
- """
- raise NotImplementedError()
+class _FigureVisitName(type):
+ def __init__(cls, clsname, bases, dict):
+ if not '__visit_name__' in cls.__dict__:
+ m = re.match(r'_?(\w+?)(?:Expression|Clause|Element|$)', clsname)
+ x = m.group(1)
+ x = re.sub(r'(?!^)[A-Z]', lambda m:'_'+m.group(0).lower(), x)
+ cls.__visit_name__ = x.lower()
+ super(_FigureVisitName, cls).__init__(clsname, bases, dict)
- def execute(self, *multiparams, **params):
- """Execute this compiled object."""
-
- e = self.engine
- if e is None:
- raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
- return e.execute_compiled(self, *multiparams, **params)
-
- def scalar(self, *multiparams, **params):
- """Execute this compiled object and return the result's scalar value."""
-
- return self.execute(*multiparams, **params).scalar()
-
class ClauseElement(object):
"""Base class for elements of a programmatically constructed SQL
expression.
"""
+ __metaclass__ = _FigureVisitName
+
+ def _clone(self):
+ # shallow copy. mutator operations always create
+ # clones of container objects.
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = self.__dict__.copy()
+ return c
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
"""Return objects represented in this ``ClauseElement`` that
should be added to the ``FROM`` list of a query, when this
``ClauseElement`` is placed in the column clause of a
raise NotImplementedError(repr(self))
- def _hide_froms(self):
+ def _hide_froms(self, **modifiers):
"""Return a list of ``FROM`` clause elements which this
``ClauseElement`` replaces.
"""
return self is other
- def accept_visitor(self, visitor):
- """Accept a ``ClauseVisitor`` and call the appropriate
- ``visit_xxx`` method.
- """
-
- raise NotImplementedError(repr(self))
-
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
"""return immediate child elements of this ``ClauseElement``.
this is used for visit traversal.
+ clone indicates child items should be _cloned(), replacing
+ the elements contained by this element, and the cloned
+ copy returned. this allows modifying traversals
+ to take place.
+
\**kwargs may contain flags that change the collection
that is returned, for example to return a subset of items
in order to cut down on larger traversals, or to return
return False
- def copy_container(self):
- """Return a copy of this ``ClauseElement``, if this
- ``ClauseElement`` contains other ``ClauseElements``.
-
- If this ``ClauseElement`` is not a container, it should return
- self. This is used to create copies of expression trees that
- still reference the same *leaf nodes*. The new structure can
- then be restructured without affecting the original.
- """
-
- return self
-
def _find_engine(self):
"""Default strategy for locating an engine within the clause element.
def _selectable(self):
return self
- def accept_visitor(self, visitor):
- raise NotImplementedError(repr(self))
-
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
clause of a ``SELECT`` statement.
"""
+ __visit_name__ = 'fromclause'
+
def __init__(self, name=None):
self.name = name
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
# this could also be [self], at the moment it doesnt matter to the Select object
return []
def default_order_by(self):
return [self.oid_column]
- def accept_visitor(self, visitor):
- visitor.visit_fromclause(self)
-
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
FindCols().traverse(self)
return ret
+ def is_derived_from(self, fromclause):
+ """return True if this FromClause is 'derived' from the given FromClause.
+
+ An example would be an Alias of a Table is derived from that Table."""
+
+ return False
+
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_embedded=False):
"""Given a ``ColumnElement``, return the exported
``ColumnElement`` object from this ``Selectable`` which
self._export_columns()
return getattr(self, name)
+ def _clone_from_clause(self):
+ # delete all the "generated" collections of columns for a newly cloned FromClause,
+ # so that they will be re-derived from the item.
+ # this is because FromClause subclasses, when cloned, need to reestablish new "proxied"
+ # columns that are linked to the new item
+ for attr in ('_columns', '_primary_key' '_foreign_keys', '_orig_cols', '_oid_column'):
+ if hasattr(self, attr):
+ delattr(self, attr)
+
columns = property(lambda s:s._get_exported_attribute('_columns'))
c = property(lambda s:s._get_exported_attribute('_columns'))
primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
self._primary_key = ColumnCollection()
self._foreign_keys = util.Set()
self._orig_cols = {}
- for co in self._adjusted_exportable_columns():
+ for co in self._flatten_exportable_columns():
cp = self._proxy_column(co)
for ci in cp.orig_set:
# note that some ambiguity is raised here, whereby a selectable might have more than
for ci in self.oid_column.orig_set:
self._orig_cols[ci] = self.oid_column
- def _adjusted_exportable_columns(self):
+ def _flatten_exportable_columns(self):
"""return the list of ColumnElements represented within this FromClause's _exportable_columns"""
export = self._exportable_columns()
for column in export:
- try:
+ if hasattr(column, '_selectable'):
s = column._selectable()
- except AttributeError:
+ else:
continue
for co in s.columns:
yield co
Public constructor is the ``bindparam()`` function.
"""
+ __visit_name__ = 'bindparam'
+
def __init__(self, key, value, shortname=None, type=None, unique=False):
"""Construct a _BindParamClause.
self.unique = unique
self.type = sqltypes.to_instance(type)
- def accept_visitor(self, visitor):
- visitor.visit_bindparam(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
- def copy_container(self):
- return _BindParamClause(self.key, self.value, self.shortname, self.type, unique=self.unique)
-
def typeprocess(self, value, dialect):
return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
Used by the ``Case`` statement.
"""
+ __visit_name__ = 'typeclause'
+
def __init__(self, type):
self.type = type
- def accept_visitor(self, visitor):
- visitor.visit_typeclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class _TextClause(ClauseElement):
Public constructor is the ``text()`` function.
"""
+ __visit_name__ = 'textclause'
+
def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
self._engine = engine
self.bindparams = {}
columns = property(lambda s:[])
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.bindparams = [b._clone() for b in self.bindparams]
+
return self.bindparams.values()
- def accept_visitor(self, visitor):
- visitor.visit_textclause(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
def supports_execution(self):
def __init__(self):
self.type = sqltypes.NULLTYPE
- def accept_visitor(self, visitor):
- visitor.visit_null(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return []
class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
-
+ __visit_name__ = 'clauselist'
+
def __init__(self, *clauses, **kwargs):
self.clauses = []
self.operator = kwargs.pop('operator', ',')
self.group = kwargs.pop('group', True)
self.group_contents = kwargs.pop('group_contents', True)
for c in clauses:
- if c is None: continue
+ if c is None:
+ continue
self.append(c)
def __iter__(self):
def __len__(self):
return len(self.clauses)
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return ClauseList(operator=self.operator, *clauses)
-
def self_group(self, against=None):
if self.group:
return _Grouping(self)
# TODO: not sure if i like the 'group_contents' flag. need to define the difference between
# a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ?
if self.group_contents:
- self.clauses.append(_literals_as_text(clause).self_group(against=self.operator))
+ self.clauses.append(_literal_as_text(clause).self_group(against=self.operator))
else:
- self.clauses.append(_literals_as_text(clause))
+ self.clauses.append(_literal_as_text(clause))
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.clauses = [clause._clone() for clause in self.clauses]
+
return self.clauses
- def accept_visitor(self, visitor):
- visitor.visit_clauselist(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
f = []
for c in self.clauses:
- f += c._get_from_objects()
+ f += c._get_from_objects(**modifiers)
return f
def self_group(self, against=None):
Extends ``ColumnElement`` to provide column-level comparison
operators.
"""
-
+ __visit_name__ = 'calculatedclause'
+
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type', None))
key = property(lambda self:self.name or "_calc_")
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
-
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.clause_expr = self.clause_expr._clone()
return self.clause_expr,
- def accept_visitor(self, visitor):
- visitor.visit_calculatedclause(self)
- def _get_from_objects(self):
- return self.clauses._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clauses._get_from_objects(**modifiers)
def _bind_param(self, obj):
return _BindParamClause(self.name, obj, type=self.type, unique=True)
key = property(lambda self:self.name)
-
- def append(self, clause):
- self.clauses.append(_literals_as_binds(clause, self.name))
-
- def copy_container(self):
- clauses = [clause.copy_container() for clause in self.clauses]
- return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self._clone_from_clause()
+ return _CalculatedClause.get_children(self, clone=clone, **kwargs)
- def accept_visitor(self, visitor):
- visitor.visit_function(self)
+ def append(self, clause):
+ self.clauses.append(_literal_as_binds(clause, self.name))
class _Cast(ColumnElement):
+
def __init__(self, clause, totype, **kwargs):
if not hasattr(clause, 'label'):
clause = literal(clause)
self.clause = clause
self.typeclause = _TypeClause(self.type)
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.clause = self.clause._clone()
+ self.typeclause = self.typeclause._clone()
+
return self.clause, self.typeclause
- def accept_visitor(self, visitor):
- visitor.visit_cast(self)
- def _get_from_objects(self):
- return self.clause._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.clause._get_from_objects(**modifiers)
def _make_proxy(self, selectable, name=None):
if name is not None:
self.operator = operator
self.modifier = modifier
- self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier)
+ self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
self.type = sqltypes.to_instance(type)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate)
-
- def _get_from_objects(self):
- return self.element._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.element._get_from_objects(**modifiers)
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.element = self.element._clone()
return self.element,
- def accept_visitor(self, visitor):
- visitor.visit_unary(self)
-
def compare(self, other):
"""Compare this ``_UnaryClause`` against the given ``ClauseElement``."""
self.modifier == other.modifier and
self.element.compare(other.element)
)
+
def _negate(self):
if self.negate is not None:
return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
"""Represent an expression that is ``LEFT <operator> RIGHT``."""
def __init__(self, left, right, operator, type=None, negate=None):
- self.left = _literals_as_text(left).self_group(against=operator)
- self.right = _literals_as_text(right).self_group(against=operator)
+ self.left = _literal_as_text(left).self_group(against=operator)
+ self.right = _literal_as_text(right).self_group(against=operator)
self.operator = operator
self.type = sqltypes.to_instance(type)
self.negate = negate
- def copy_container(self):
- return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
-
- def _get_from_objects(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.left = self.left._clone()
+ self.right = self.right._clone()
+
return self.left, self.right
- def accept_visitor(self, visitor):
- visitor.visit_binary(self)
-
def compare(self, other):
"""Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
return super(_BinaryExpression, self)._negate()
class _Exists(_UnaryExpression):
+ __visit_name__ = _UnaryExpression.__visit_name__
+
def __init__(self, *args, **kwargs):
kwargs['correlate'] = True
s = select(*args, **kwargs).self_group()
_UnaryExpression.__init__(self, s, operator="EXISTS")
- def _hide_froms(self):
- return self._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self._get_from_objects(**modifiers)
class Join(FromClause):
"""represent a ``JOIN`` construct between two ``FromClause``
def _init_primary_key(self):
pkcol = util.OrderedSet()
- for col in self._adjusted_exportable_columns():
+ for col in self._flatten_exportable_columns():
if col.primary_key:
pkcol.add(col)
for col in list(pkcol):
self._foreign_keys.add(f)
return column
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self._clone_from_clause()
+ self.left = self.left._clone()
+ self.right = self.right._clone()
+ self.onclause = self.onclause._clone()
+ self.__folded_equivalents = None
+ self._init_primary_key()
+ return self.left, self.right, self.onclause
+
def _match_primaries(self, primary, secondary):
crit = []
constraints = util.Set()
return select(collist, whereclause, from_obj=[self], **kwargs)
- def get_children(self, **kwargs):
- return self.left, self.right, self.onclause
-
- def accept_visitor(self, visitor):
- visitor.visit_join(self)
-
engine = property(lambda s:s.left.engine or s.right.engine)
def alias(self, name=None):
return self.select(use_labels=True, correlate=False).alias(name)
- def _hide_froms(self):
- return self.left._get_from_objects() + self.right._get_from_objects()
+ def _hide_froms(self, **modifiers):
+ return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
+ def _get_from_objects(self, **modifiers):
+ return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
class Alias(FromClause):
"""represent an alias, as typically applied to any
self.encodedname = alias.encode('ascii', 'backslashreplace')
self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
+ def is_derived_from(self, fromclause):
+ x = self.selectable
+ while isinstance(x, Alias):
+ x = x.selectable
+ if x is fromclause:
+ return True
+ return False
+
def supports_execution(self):
return self.original.supports_execution()
#return self.selectable._exportable_columns()
return self.selectable.columns
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self._clone_from_clause()
+ self.selectable = self.selectable._clone()
+ baseselectable = self.selectable
+ while isinstance(baseselectable, Alias):
+ baseselectable = baseselectable.selectable
+ self.original = baseselectable
for c in self.c:
yield c
yield self.selectable
- def accept_visitor(self, visitor):
- visitor.visit_alias(self)
-
def _get_from_objects(self):
return [self]
_label = property(lambda s: s.elem._label)
orig_set = property(lambda s:s.elem.orig_set)
- def copy_container(self):
- return _Grouping(self.elem.copy_container())
-
- def accept_visitor(self, visitor):
- visitor.visit_grouping(self)
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.elem = self.elem._clone()
return self.elem,
- def _hide_froms(self):
- return self.elem._hide_froms()
- def _get_from_objects(self):
- return self.elem._get_from_objects()
+
+ def _hide_froms(self, **modifiers):
+ return self.elem._hide_froms(**modifiers)
+
+ def _get_from_objects(self, **modifiers):
+ return self.elem._get_from_objects(**modifiers)
class _Label(ColumnElement):
"""represent a label, as typically applied to any column-level element
def _compare_self(self):
return self.obj
- def get_children(self, **kwargs):
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.obj = self.obj._clone()
return self.obj,
- def accept_visitor(self, visitor):
- visitor.visit_label(self)
+ def _get_from_objects(self, **modifiers):
+ return self.obj._get_from_objects(**modifiers)
- def _get_from_objects(self):
- return self.obj._get_from_objects()
-
- def _hide_froms(self):
- return self.obj._hide_froms()
+ def _hide_froms(self, **modifiers):
+ return self.obj._hide_froms(**modifiers)
def _make_proxy(self, selectable, name = None):
if isinstance(self.obj, Selectable):
self.__label = None
self.case_sensitive = case_sensitive
self.is_literal = is_literal
-
+
+ def _clone(self):
+ # ColumnClause is immutable
+ return self
+
def _get_label(self):
"""Generate a 'label' for this column.
else:
return super(_ColumnClause, self).label(name)
- def accept_visitor(self, visitor):
- visitor.visit_column(self)
-
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
if self.table is not None:
return [self.table]
else:
self.append_column(c)
self._oid_column = _ColumnClause('oid', self, _is_oid=True)
+ def _clone(self):
+ # TableClause is immutable
+ return self
+
def named_with_column(self):
return True
else:
return []
- def accept_visitor(self, visitor):
- visitor.visit_table(self)
-
def _exportable_columns(self):
raise NotImplementedError()
def delete(self, whereclause = None):
return delete(self, whereclause)
- def _get_from_objects(self):
+ def _get_from_objects(self, **modifiers):
return [self]
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
+ def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, connectable=None, scalar=False, engine=None):
+ self.use_labels = use_labels
+ self.for_update = for_update
+ self._limit = limit
+ self._offset = offset
+ self._engine = connectable or engine
+ self.is_scalar = scalar
+ if self.is_scalar:
+ # allow corresponding_column to return None
+ self.orig_set = util.Set()
+
+ self.append_order_by(*util.to_list(order_by, []))
+ self.append_group_by(*util.to_list(group_by, []))
+
def supports_execution(self):
return True
+ def _generate(self):
+ s = self._clone()
+ s._clone_from_clause()
+ return s
+
+ def limit(self, limit):
+ s = self._generate()
+ s._limit = limit
+ return s
+
+ def offset(self, offset):
+ s = self._generate()
+ s._offset = offset
+ return s
+
def order_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.order_by_clause = ClauseList()
- elif getattr(self, 'order_by_clause', None):
- self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses)))
- else:
- self.order_by_clause = ClauseList(*clauses)
+ s = self._generate()
+ s.append_order_by(*clauses)
+ return s
def group_by(self, *clauses):
- if len(clauses) == 1 and clauses[0] is None:
- self.group_by_clause = ClauseList()
- elif getattr(self, 'group_by_clause', None):
- self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
+ s = self._generate()
+ s.append_group_by(*clauses)
+ return s
+
+ def append_order_by(self, *clauses):
+ if clauses == [None]:
+ self._order_by_clause = ClauseList()
else:
- self.group_by_clause = ClauseList(*clauses)
+ if getattr(self, '_order_by_clause', None):
+ clauses = list(self._order_by_clause) + list(clauses)
+ self._order_by_clause = ClauseList(*clauses)
+ def append_group_by(self, *clauses):
+ if clauses == [None]:
+ self._group_by_clause = ClauseList()
+ else:
+ if getattr(self, '_group_by_clause', None):
+ clauses = list(self._group_by_clause) + list(clauses)
+ self._group_by_clause = ClauseList(*clauses)
+
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
- def _get_from_objects(self):
- if self.is_where or self.is_scalar:
+ def _get_from_objects(self, is_where=False, **modifiers):
+ if is_where or self.is_scalar:
return []
else:
return [self]
class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
- _SelectBaseMixin.__init__(self)
+ self._should_correlate = kwargs.pop('correlate', False)
self.keyword = keyword
- self.use_labels = kwargs.pop('use_labels', False)
- self.should_correlate = kwargs.pop('correlate', False)
- self.for_update = kwargs.pop('for_update', False)
- self.nowait = kwargs.pop('nowait', False)
- self.limit = kwargs.pop('limit', None)
- self.offset = kwargs.pop('offset', None)
- self.is_compound = True
- self.is_where = False
- self.is_scalar = False
- self.is_subquery = False
-
- self.selects = selects
+ self.selects = []
# some DBs do not like ORDER BY in the inner queries of a UNION, etc.
for s in selects:
- s.order_by(None)
+ if len(s._order_by_clause):
+ s = s.order_by(None)
+ self.selects.append(s)
- self.group_by(*kwargs.pop('group_by', [None]))
- self.order_by(*kwargs.pop('order_by', [None]))
- if len(kwargs):
- raise TypeError("invalid keyword argument(s) for CompoundSelect: %s" % repr(kwargs.keys()))
self._col_map = {}
+ _SelectBaseMixin.__init__(self, **kwargs)
+
name = property(lambda s:s.keyword + " statement")
def self_group(self, against=None):
col.orig_set = colset
return col
- def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.c) or []) + \
- [self.order_by_clause, self.group_by_clause] + list(self.selects)
- def accept_visitor(self, visitor):
- visitor.visit_compound_select(self)
+ def get_children(self, clone=False, column_collections=True, **kwargs):
+ if clone:
+ self._clone_from_clause()
+ self._col_map = {}
+ self.selects = [s._clone() for s in self.selects]
+ for attr in ('_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+ return (column_collections and list(self.c) or []) + \
+ [self._order_by_clause, self._group_by_clause] + list(self.selects)
+
def _find_engine(self):
for s in self.selects:
e = s._find_engine()
"""
- def __init__(self, columns=None, whereclause=None, from_obj=[],
- order_by=None, group_by=None, having=None,
- use_labels=False, distinct=False, for_update=False,
- engine=None, limit=None, offset=None, scalar=False,
- correlate=True):
+ def __init__(self, columns, whereclause=None, from_obj=None, distinct=False, having=None, correlate=True, **kwargs):
"""construct a Select object.
The public constructor for Select is the [sqlalchemy.sql#select()] function;
see that function for argument descriptions.
"""
- _SelectBaseMixin.__init__(self)
- self.__froms = util.OrderedSet()
- self.__hide_froms = util.Set([self])
- self.use_labels = use_labels
- self.whereclause = None
- self.having = None
- self._engine = engine
- self.limit = limit
- self.offset = offset
- self.for_update = for_update
- self.is_compound = False
- # indicates that this select statement should not expand its columns
- # into the column clause of an enclosing select, and should instead
- # act like a single scalar column
- self.is_scalar = scalar
- if scalar:
- # allow corresponding_column to return None
- self.orig_set = util.Set()
-
- # indicates if this select statement, as a subquery, should automatically correlate
- # its FROM clause to that of an enclosing select, update, or delete statement.
- # note that the "correlate" method can be used to explicitly add a value to be correlated.
- self.should_correlate = correlate
-
- # indicates if this select statement is a subquery inside another query
- self.is_subquery = False
-
- # indicates if this select statement is in the from clause of another query
- self.is_selected_from = False
+ self._should_correlate = correlate
+ self._distinct = distinct
- # indicates if this select statement is a subquery as a criterion
- # inside of a WHERE clause
- self.is_where = False
-
- self.distinct = distinct
self._raw_columns = []
- self.__correlated = {}
- self.__correlator = Select._CorrelatedVisitor(self, False)
- self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
- self.__fromvisitor = Select._FromVisitor(self)
-
-
- self.order_by_clause = self.group_by_clause = None
+ self.__correlate = util.Set()
+ self._froms = util.OrderedSet()
+ self._whereclause = None
+ self._having = None
if columns is not None:
for c in columns:
self.append_column(c)
- if order_by:
- order_by = util.to_list(order_by)
- if group_by:
- group_by = util.to_list(group_by)
- self.order_by(*(order_by or [None]))
- self.group_by(*(group_by or [None]))
- for c in self.order_by_clause:
- self.__correlator.traverse(c)
- for c in self.group_by_clause:
- self.__correlator.traverse(c)
-
- for f in from_obj:
- self.append_from(f)
-
- # whereclauses must be appended after the columns/FROM, since it affects
- # the correlation of subqueries. see test/sql/select.py SelectTest.testwheresubquery
+ if from_obj is not None:
+ for f in from_obj:
+ self.append_from(f)
+
if whereclause is not None:
self.append_whereclause(whereclause)
+
if having is not None:
self.append_having(having)
+ _SelectBaseMixin.__init__(self, **kwargs)
- class _CorrelatedVisitor(NoColumnVisitor):
- """Visit a clause, locate any ``Select`` clauses, and tell
- them that they should correlate their ``FROM`` list to that of
- their parent.
- """
-
- def __init__(self, select, is_where):
- NoColumnVisitor.__init__(self)
- self.select = select
- self.is_where = is_where
-
- def visit_compound_select(self, cs):
- self.visit_select(cs)
-
- def visit_column(self, c):
- pass
+ def get_display_froms(self, correlation_state=None):
+ froms = util.Set()
+ hide_froms = util.Set()
+
+ for col in self._raw_columns:
+ for f in col._hide_froms():
+ hide_froms.add(f)
+ for f in col._get_from_objects():
+ froms.add(f)
- def visit_table(self, c):
- pass
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_where = self.is_where
- select.is_subquery = True
- if not select.should_correlate:
- return
- [select.correlate(x) for x in self.select._Select__froms]
+ for elem in froms:
+ for f in elem._hide_froms():
+ hide_froms.add(f)
- class _FromVisitor(NoColumnVisitor):
- def __init__(self, select):
- NoColumnVisitor.__init__(self)
- self.select = select
+ froms = froms.difference(hide_froms)
+
+ if len(froms) > 1:
+ corr = self.__correlate
+ if correlation_state is not None:
+ corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+ return froms.difference(corr)
+ else:
+ return froms
+
+ def locate_all_froms(self):
+ froms = util.Set()
+ for col in self._raw_columns:
+ for f in col._get_from_objects():
+ froms.add(f)
+
+ if self._whereclause is not None:
+ for f in self._whereclause._get_from_objects(is_where=True):
+ froms.add(f)
+
+ for elem in self._froms:
+ froms.add(elem)
+ for f in elem._get_from_objects():
+ froms.add(f)
+ return froms
+
+ def calculate_correlations(self, correlation_state):
+ if self not in correlation_state:
+ correlation_state[self] = {}
+
+ display_froms = self.get_display_froms(correlation_state)
+
+ class CorrelatedVisitor(NoColumnVisitor):
+ def __init__(self, is_where=False, is_column=False, is_from=False):
+ self.is_where = is_where
+ self.is_column = is_column
+ self.is_from = is_from
+
+ def visit_compound_select(self, cs):
+ self.visit_select(cs)
+
+ def visit_select(s, select):
+ if select not in correlation_state:
+ correlation_state[select] = {}
+
+ if select is self:
+ return
+
+ select_state = correlation_state[select]
+ if s.is_from:
+ select_state['is_selected_from'] = True
+ if s.is_where:
+ select_state['is_where'] = True
+ select_state['is_subquery'] = True
+
+ if select._should_correlate:
+ corr = select_state.setdefault('correlate', util.Set())
+ # not crazy about this part. need to be clearer on what elements in the
+ # subquery correspond to elements in the enclosing query.
+ for f in display_froms:
+ corr.add(f)
+ for f2 in f._get_from_objects():
+ corr.add(f2)
+
+ col_vis = CorrelatedVisitor(is_column=True)
+ where_vis = CorrelatedVisitor(is_where=True)
+ from_vis = CorrelatedVisitor(is_from=True)
+
+ for col in self._raw_columns:
+ col_vis.traverse(col)
+ for f in col._get_from_objects():
+ if f is not self:
+ from_vis.traverse(f)
+
+ for col in list(self._order_by_clause) + list(self._group_by_clause):
+ col_vis.traverse(col)
+
+ if self._whereclause is not None:
+ where_vis.traverse(self._whereclause)
+ for f in self._whereclause._get_from_objects(is_where=True):
+ if f is not self:
+ from_vis.traverse(f)
+
+ for elem in self._froms:
+ from_vis.traverse(elem)
+
+ def _get_inner_columns(self):
+ for c in self._raw_columns:
+ if hasattr(c, '_selectable'):
+ for co in c._selectable().columns:
+ yield co
+ else:
+ yield c
+
+ inner_columns = property(_get_inner_columns)
+
+ def get_children(self, clone=False, column_collections=True, **kwargs):
+ if clone:
+ self._clone_from_clause()
+ self._raw_columns = [c._clone() for c in self._raw_columns]
+ self._recorrelate_froms([f._clone() for f in self._froms])
+ for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
+ if getattr(self, attr) is not None:
+ setattr(self, attr, getattr(self, attr)._clone())
+
+ return (column_collections and list(self.columns) or []) + \
+ list(self._froms) + \
+ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
+
+ def _recorrelate_froms(self, froms):
+ newcorrelate = util.Set()
+ for f in froms:
+ if f in self.__correlate:
+ newcorrelate.add(cl)
+ self.__correlate.remove(f)
+ self.__correlate = self.__correlate.union(newcorrelate)
+ self._froms = froms
+
+ def column(self, column):
+ s = self._generate()
+ s.append_column(column)
+ return s
+
+ def where(self, whereclause):
+ s = self._generate()
+ s.append_whereclause(whereclause)
+ return s
+
+ def having(self, having):
+ s = self._generate()
+ s.append_having(having)
+ return s
+
+ def distinct(self):
+ s = self._generate()
+ s.distinct = True
+ return s
+
+ def select_from(self, fromclause):
+ s = self._generate()
+ s.append_from(fromclause)
+ return s
+
+ def correlate_to(self, fromclause):
+ s = self._generate()
+ s.append_correlation(fromclause)
+ return s
+
+ def append_correlation(self, fromclause):
+ self.__correlate.add(fromclause)
- def visit_select(self, select):
- if select is self.select:
- return
- select.is_selected_from = True
- select.is_subquery = True
-
def append_column(self, column):
if _is_literal(column):
column = literal_column(str(column))
self._raw_columns.append(column)
- if self.is_scalar and not hasattr(self, 'type'):
- self.type = column.type
-
- # if the column is a Select statement itself,
- # accept visitor
- self.__correlator.traverse(column)
+ def append_whereclause(self, whereclause):
+ if self._whereclause is not None:
+ self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
+ else:
+ self._whereclause = _literal_as_text(whereclause)
+
+ def append_having(self, having):
+ if self._having is not None:
+ self._having = and_(self._having, _literal_as_text(having))
+ else:
+ self._having = _literal_as_text(having)
- # visit the FROM objects of the column looking for more Selects
- for f in column._get_from_objects():
- if f is not self:
- self.__correlator.traverse(f)
- self._process_froms(column, False)
+ def append_from(self, fromclause):
+ if _is_literal(fromclause):
+ fromclause = FromClause(fromclause)
+ self._froms.add(fromclause)
def _make_proxy(self, selectable, name):
if self.is_scalar:
- return self._raw_columns[0]._make_proxy(selectable, name)
+ return list(self.inner_columns)[0]._make_proxy(selectable, name)
else:
raise exceptions.InvalidRequestError("Not a scalar select statement")
else:
return label(name, self)
+ def _get_type(self):
+ if self.is_scalar:
+ return list(self.inner_columns)[0].type
+ else:
+ return None
+ type = property(_get_type)
+
def _exportable_columns(self):
return [c for c in self._raw_columns if isinstance(c, Selectable)]
else:
return column._make_proxy(self)
- def _process_froms(self, elem, asfrom):
- for f in elem._get_from_objects():
- self.__fromvisitor.traverse(f)
- self.__froms.add(f)
- if asfrom:
- self.__froms.add(elem)
- for f in elem._hide_froms():
- self.__hide_froms.add(f)
-
def self_group(self, against=None):
return _Grouping(self)
-
- def append_whereclause(self, whereclause):
- self._append_condition('whereclause', whereclause)
-
- def append_having(self, having):
- self._append_condition('having', having)
-
- def _append_condition(self, attribute, condition):
- if type(condition) == str:
- condition = _TextClause(condition)
- self.__wherecorrelator.traverse(condition)
- self._process_froms(condition, False)
- if getattr(self, attribute) is not None:
- setattr(self, attribute, and_(getattr(self, attribute), condition))
- else:
- setattr(self, attribute, condition)
-
- def correlate(self, from_obj):
- """Given a ``FROM`` object, correlate this ``SELECT`` statement to it.
-
- This basically means the given from object will not come out
- in this select statement's ``FROM`` clause when printed.
- """
-
- self.__correlated[from_obj] = from_obj
-
- def append_from(self, fromclause):
- if type(fromclause) == str:
- fromclause = FromClause(fromclause)
- self.__correlator.traverse(fromclause)
- self._process_froms(fromclause, True)
def _locate_oid_column(self):
- for f in self.__froms:
+ for f in self.locate_all_froms():
if f is self:
# we might be in our own _froms list if a column with us as the parent is attached,
# which includes textual columns.
else:
return None
- def _calc_froms(self):
- f = self.__froms.difference(self.__hide_froms)
- if (len(f) > 1):
- return f.difference(self.__correlated)
- else:
- return f
-
- froms = property(_calc_froms,
- doc="""A collection containing all elements
- of the ``FROM`` clause.""")
-
- def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.columns) or []) + \
- list(self.froms) + \
- [x for x in (self.whereclause, self.having, self.order_by_clause, self.group_by_clause) if x is not None]
-
- def accept_visitor(self, visitor):
- visitor.visit_select(self)
-
def union(self, other, **kwargs):
return union(self, other, **kwargs)
if self._engine is not None:
return self._engine
- for f in self.__froms:
+ for f in self._froms:
if f is self:
continue
e = f.engine
def supports_execution(self):
return True
- class _SelectCorrelator(NoColumnVisitor):
- def __init__(self, table):
- NoColumnVisitor.__init__(self)
- self.table = table
-
- def visit_select(self, select):
- if select.should_correlate:
- select.correlate(self.table)
-
- def _process_whereclause(self, whereclause):
- if whereclause is not None:
- _UpdateBase._SelectCorrelator(self.table).traverse(whereclause)
- return whereclause
-
+ def calculate_correlations(self, correlate_state):
+ class SelectCorrelator(NoColumnVisitor):
+ def visit_select(s, select):
+ if select._should_correlate:
+ select_state = correlate_state.setdefault(select, {})
+ corr = select_state.setdefault('correlate', util.Set())
+ corr.add(self.table)
+
+ vis = SelectCorrelator()
+
+ if self._whereclause is not None:
+ vis.traverse(self._whereclause)
+
+ if getattr(self, 'parameters', None) is not None:
+ for key, value in self.parameters.items():
+ if isinstance(value, ClauseElement):
+ vis.traverse(value)
+
def _process_colparams(self, parameters):
"""Receive the *values* of an ``INSERT`` or ``UPDATE``
statement and construct appropriate bind parameters.
i +=1
parameters = pp
- correlator = _UpdateBase._SelectCorrelator(self.table)
for key in parameters.keys():
value = parameters[key]
if isinstance(value, ClauseElement):
- correlator.traverse(value)
+ pass
elif _is_literal(value):
if _is_literal(key):
col = self.table.c[key]
def _find_engine(self):
return self.table.engine
-class _Insert(_UpdateBase):
+class Insert(_UpdateBase):
def __init__(self, table, values=None):
self.table = table
self.select = None
return self.select,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_insert(self)
-class _Update(_UpdateBase):
+class Update(_UpdateBase):
def __init__(self, table, whereclause, values=None):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
self.parameters = self._process_colparams(values)
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_update(self)
-class _Delete(_UpdateBase):
+class Delete(_UpdateBase):
def __init__(self, table, whereclause):
self.table = table
- self.whereclause = self._process_whereclause(whereclause)
+ self._whereclause = whereclause
def get_children(self, **kwargs):
- if self.whereclause is not None:
- return self.whereclause,
+ if self._whereclause is not None:
+ return self._whereclause,
else:
return ()
- def accept_visitor(self, visitor):
- visitor.visit_delete(self)
process the new list.
"""
- list_ = [o.copy_container() for o in list_]
+ list_ = list(list_)
self.process_list(list_)
return list_
if elem is not None:
list_[i] = elem
else:
- self.traverse(list_[i])
+ list_[i] = self.traverse(list_[i], clone=True)
def visit_grouping(self, grouping):
elem = self.convert_element(grouping.elem)
elem = self.convert_element(binary.right)
if elem is not None:
binary.right = elem
-
- # TODO: visit_select().
+
+ def visit_select(self, select):
+ fr = util.OrderedSet()
+ for elem in select._froms:
+ n = self.convert_element(elem)
+ if n is None:
+ fr.add(elem)
+ else:
+ fr.add(n)
+ select._recorrelate_froms(fr)
+
+ col = []
+ for elem in select._raw_columns:
+ n = self.convert_element(elem)
+ if n is None:
+ col.append(elem)
+ else:
+ col.append(n)
+ select._raw_columns = col
class ClauseAdapter(AbstractClauseProcessor):
"""Given a clause (like as in a WHERE criterion), locate columns
self.equivalents = equivalents
def convert_element(self, col):
+ if isinstance(col, sql.FromClause):
+ if self.selectable.is_derived_from(col):
+ return self.selectable
if not isinstance(col, sql.ColumnElement):
return None
if self.include is not None:
self[key] = value = self.creator(key)
return value
-def to_list(x):
+def to_list(x, default=None):
if x is None:
- return None
+ return default
if not isinstance(x, list) and not isinstance(x, tuple):
return [x]
else:
})
session = create_session()
query = session.query(tables.User)
- x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
print x.compile()
self.assert_result(list(x), tables.User, *tables.user_result[1:3])
def test_outerjointo_count(self):
})
session = create_session()
query = session.query(tables.User)
- x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+ x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
assert x==2
def test_from(self):
mapper(tables.User, tables.users, properties={
Column('data', String(30))
)
- join = polymorphic_union(
- {
- 'table3' : table1.join(table3),
- 'table2' : table1.join(table2),
- 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
- }, None, 'pjoin')
-
- # still with us so far ?
+ #join = polymorphic_union(
+ # {
+ # 'table3' : table1.join(table3),
+ # 'table2' : table1.join(table2),
+ # 'table1' : table1.select(table1.c.type.in_('table1', 'table1b')),
+ # }, None, 'pjoin')
+
+ join = table1.outerjoin(table2).outerjoin(table3).alias('pjoin')
+ #join = None
class Table1(object):
def __init__(self, name, data=None):
return "%s(%d, %s)" % (self.__class__.__name__, self.id, repr(str(self.data)))
try:
- # this is how the mapping used to work. insure that this raises an error now
+ # this is how the mapping used to work. ensure that this raises an error now
table1_mapper = mapper(Table1, table1,
select_table=join,
- polymorphic_on=join.c.type,
+ polymorphic_on=table1.c.type,
polymorphic_identity='table1',
properties={
'next': relation(Table1,
# exception now. since eager loading would never work for that relation anyway, its better that the user
# gets an exception instead of it silently not eager loading.
table1_mapper = mapper(Table1, table1,
- select_table=join,
- polymorphic_on=join.c.type,
+ #select_table=join,
+ polymorphic_on=table1.c.type,
polymorphic_identity='table1',
properties={
'next': relation(Table1,
polymorphic_identity='table2')
table3_mapper = mapper(Table3, table3, inherits=table1_mapper, polymorphic_identity='table3')
-
+
+ table1_mapper.compile()
+ assert table1_mapper.primary_key == [table1.c.id], table1_mapper.primary_key
+
def testone(self):
self.do_testlist([Table1, Table2, Table1, Table2])
'concat': column_property(f),
'count': column_property(select([func.count(addresses.c.address_id)], users.c.user_id==addresses.c.user_id, scalar=True).label('count'))
})
+
+ mapper(Address, addresses, properties={
+ 'user':relation(User, lazy=False)
+ })
sess = create_session()
l = sess.query(User).select()
assert l[0].concat == l[0].user_id * 2 == 14
assert l[1].concat == l[1].user_id * 2 == 16
- ### eager loads, not really working across all DBs, no column aliasing in place so
- # results still wont be good for larger situations
- clear_mappers()
- mapper(Address, addresses, properties={
- 'user':relation(User, lazy=False)
- })
-
- mapper(User, users, properties={
- 'concat': column_property(f),
- })
-
- for x in range(0, 2):
- sess.clear()
- l = sess.query(Address).select()
- for a in l:
- print "User", a.user.user_id, a.user.user_name, a.user.concat
- assert l[0].user.concat == l[0].user.user_id * 2 == 14
- assert l[1].user.concat == l[1].user.user_id * 2 == 16
+ for option in (None, eagerload('user')):
+ for x in range(0, 2):
+ sess.clear()
+ l = sess.query(Address)
+ if option:
+ l = l.options(option)
+ l = l.all()
+ for a in l:
+ print "User", a.user.user_id, a.user.user_name, a.user.concat, a.user.count
+ assert l[0].user.concat == l[0].user.user_id * 2 == 14
+ assert l[1].user.concat == l[1].user.user_id * 2 == 16
+ assert l[0].user.count == 1
+ assert l[1].user.count == 3
@testbase.unsupported('firebird')
"""test eager loading of a mapper which is against a select"""
s = select([orders], orders.c.isopen==1).alias('openorders')
+ print "SELECT:", id(s), str(s)
mapper(Order, s, properties={
'user':relation(User, lazy=False)
})
from testbase import PersistTest, AssertMixin
import unittest, sys, os
from sqlalchemy import *
+from sqlalchemy.orm import *
from testbase import Table, Column
import StringIO
import testbase
'sql.testtypes',
'sql.constraints',
+ 'sql.generative',
+
# SQL syntax
'sql.select',
'sql.selectable',
--- /dev/null
+import testbase
+from sqlalchemy import *
+
+class TraversalTest(testbase.AssertMixin):
+ """test ClauseVisitor's traversal, particularly its ability to copy and modify
+ a ClauseElement in place."""
+
+ def setUpAll(self):
+ global A, B
+
+ # establish two ficticious ClauseElements.
+ # define deep equality semantics as well as deep identity semantics.
+ class A(ClauseElement):
+ def __init__(self, expr):
+ self.expr = expr
+
+ def accept_visitor(self, visitor):
+ visitor.visit_a(self)
+
+ def is_other(self, other):
+ return other is self
+
+ def __eq__(self, other):
+ return other.expr == self.expr
+
+ def __ne__(self, other):
+ return other.expr != self.expr
+
+ def __str__(self):
+ return "A(%s)" % repr(self.expr)
+
+ class B(ClauseElement):
+ def __init__(self, *items):
+ self.items = items
+
+ def is_other(self, other):
+ if other is not self:
+ return False
+ for i1, i2 in zip(self.items, other.items):
+ if i1 is not i2:
+ return False
+ return True
+
+ def __eq__(self, other):
+ for i1, i2 in zip(self.items, other.items):
+ if i1 != i2:
+ return False
+ return True
+
+ def __ne__(self, other):
+ for i1, i2 in zip(self.items, other.items):
+ if i1 != i2:
+ return True
+ return False
+
+ def get_children(self, clone=False, **kwargs):
+ if clone:
+ self.items = [i._clone() for i in self.items]
+ return self.items
+
+ def accept_visitor(self, visitor):
+ visitor.visit_b(self)
+
+ def __str__(self):
+ return "B(%s)" % repr([str(i) for i in self.items])
+
+ def test_test_classes(self):
+ a1 = A("expr1")
+ struct = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct2 = B(a1, A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct3 = B(a1, A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+ assert a1.is_other(a1)
+ assert struct.is_other(struct)
+ assert struct == struct2
+ assert struct != struct3
+ assert not struct.is_other(struct2)
+ assert not struct.is_other(struct3)
+
+ def test_clone(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ pass
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=True)
+ assert struct == s2
+ assert not struct.is_other(s2)
+
+ def test_no_clone(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ pass
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=False)
+ assert struct == s2
+ assert struct.is_other(s2)
+
+ def test_change_in_place(self):
+ struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
+ struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
+
+ class Vis(ClauseVisitor):
+ def visit_a(self, a):
+ if a.expr == "expr2":
+ a.expr = "expr2modified"
+ def visit_b(self, b):
+ pass
+
+ vis = Vis()
+ s2 = vis.traverse(struct, clone=True)
+ assert struct != s2
+ assert struct is not s2
+ assert struct2 == s2
+
+ class Vis2(ClauseVisitor):
+ def visit_a(self, a):
+ if a.expr == "expr2b":
+ a.expr = "expr2bmodified"
+ def visit_b(self, b):
+ pass
+
+ vis2 = Vis2()
+ s3 = vis2.traverse(struct, clone=True)
+ assert struct != s3
+ assert struct3 == s3
+
+class ClauseTest(testbase.AssertMixin):
+ def setUpAll(self):
+ global t1, t2
+ t1 = table("table1",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+ t2 = table("table2",
+ column("col1"),
+ column("col2"),
+ column("col3"),
+ )
+
+ def test_binary(self):
+ clause = t1.c.col2 == t2.c.col2
+ assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+
+ def test_join(self):
+ clause = t1.join(t2, t1.c.col2==t2.c.col2)
+ c1 = str(clause)
+ assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+
+ class Vis(ClauseVisitor):
+ def visit_binary(self, binary):
+ binary.right = t2.c.col3
+
+ clause2 = Vis().traverse(clause, clone=True)
+ assert c1 == str(clause)
+ assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
+
+ def test_select(self):
+ s = t1.select()
+ s2 = select([s])
+ s2_assert = str(s2)
+ s3_assert = str(select([t1.select()], t1.c.col2==7))
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col2==7)
+ s3 = Vis().traverse(s2, clone=True)
+ assert str(s3) == s3_assert
+ assert str(s2) == s2_assert
+ print str(s2)
+ print str(s3)
+ Vis().traverse(s2)
+ assert str(s2) == s3_assert
+
+ print "------------------"
+
+ s4_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col3==9)))
+ class Vis(ClauseVisitor):
+ def visit_select(self, select):
+ select.append_whereclause(t1.c.col3==9)
+ s4 = Vis().traverse(s3, clone=True)
+ print str(s3)
+ print str(s4)
+ assert str(s4) == s4_assert
+ assert str(s3) == s3_assert
+
+ print "------------------"
+ s5_assert = str(select([t1.select()], and_(t1.c.col2==7, t1.c.col1==9)))
+ class Vis(ClauseVisitor):
+ def visit_binary(self, binary):
+ if binary.left is t1.c.col3:
+ binary.left = t1.c.col1
+ binary.right = bindparam("table1_col1")
+ s5 = Vis().traverse(s4, clone=True)
+ print str(s4)
+ print str(s5)
+ assert str(s5) == s5_assert
+ assert str(s4) == s4_assert
+
+
+if __name__ == '__main__':
+ testbase.main()
\ No newline at end of file
self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={})
def testwheresubquery(self):
+ s = select([addresses.c.street], addresses.c.user_id==users.c.user_id, correlate=True).alias('s')
+ self.runtest(
+ select([users, s.c.street], from_obj=[s]),
+ """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
+
# TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet.
#self.runtest(
# table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), ""
order_by = ['dist', places.c.nm]
)
self.runtest(q, "SELECT places.id, places.nm, main_zip.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = main_zip.zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = main_zip.zipcode)) AS dist FROM places, zips AS main_zip ORDER BY dist, places.nm")
+
+ a1 = table2.alias('t2alias')
+ s1 = select([a1.c.otherid], table1.c.myid==a1.c.otherid, scalar=True)
+ j1 = table1.join(table2, table1.c.myid==table2.c.otherid)
+ s2 = select([table1, s1], from_obj=[j1])
+ self.runtest(s2, "SELECT mytable.myid, mytable.name, mytable.description, (SELECT t2alias.otherid FROM myothertable AS t2alias WHERE mytable.myid = t2alias.otherid) FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
def testlabelcomparison(self):
x = func.lala(table1.c.myid).label('foo')
s.append_column("column2")
s.append_whereclause("column1=12")
s.append_whereclause("column2=19")
- s.order_by("column1")
+ s = s.order_by("column1")
s.append_from("table1")
self.runtest(s, "SELECT column1, column2 FROM table1 WHERE column1=12 AND column2=19 ORDER BY column1")
)
- def testlateargs(self):
- """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments
- are sent"""
-
- self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'})
-
- self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'})
-
- self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'})
-
def testcast(self):
tbl = table('casttest',
column('id', Integer),
def testdelete(self):
self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
-
+
+ def testcorrelateddelete(self):
+ # test a non-correlated WHERE clause
+ s = select([table2.c.othername], table2.c.otherid == 7)
+ u = delete(table1, table1.c.name==s)
+ self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = :myothertable_otherid)")
+
+ # test one that is actually correlated...
+ s = select([table2.c.othername], table2.c.otherid == table1.c.myid)
+ u = table1.delete(table1.c.name==s)
+ self.runtest(u, "DELETE FROM mytable WHERE mytable.name = (SELECT myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid)")
+
class SchemaTest(SQLTest):
def testselect(self):
# these tests will fail with the MS-SQL compiler since it will alias schema-qualified tables