and "correct" version which will copy all elements exactly once, except for those which were
replaced with target elements. It also can match a wider variety of target elements including
joins and selects on identity alone.
orderby = self.process(select._order_by_clause)
if not orderby:
orderby = select.oid_column
- self.traverse(orderby)
orderby = self.process(orderby)
oldselect = select
executors = {
expression._Function : _execute_function,
expression.ClauseElement : _execute_clauseelement,
- visitors.ClauseVisitor : _execute_compiled,
+ Compiled : _execute_compiled,
schema.SchemaItem:_execute_default,
str.__mro__[-2] : _execute_text
}
# adapt the given WHERECLAUSE to adjust instances of this query's mapped
# table to be that of our select_table,
# which may be the "polymorphic" selectable used by our mapper.
- sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table]))
+ whereclause = sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table]))
# if extra entities, adapt the criterion to those as well
for m in self._entities:
operators.isnot : 'IS NOT'
}
-class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
+class DefaultCompiler(engine.Compiled):
"""Default implementation of Compiled.
- Compiles ClauseElements into SQL strings.
+ Compiles ClauseElements into SQL strings. Uses a similar visit
+ paradigm as visitors.ClauseVisitor but implements its own traversal.
"""
__traverse_options__ = {'column_collections':False, 'entry':True}
if stack:
self.stack.append(stack)
try:
- return self.traverse_single(obj, **kwargs)
+ meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kwargs)
finally:
if stack:
self.stack.pop(-1)
# TODO: use UnaryExpression for this instead ?
modifier = _FunctionGenerator(group=False)
-
+def _clone(element):
+ return element._clone()
+
def _compound_select(keyword, *selects, **kwargs):
return CompoundSelect(keyword, *selects, **kwargs)
return self is other
- def _copy_internals(self):
+ def _copy_internals(self, clone=_clone):
"""Reassign internal elements to be clones of themselves.
Called during a copy-and-traverse operation on newly
An example would be an Alias of a Table is derived from that Table.
"""
-
- return False
+ return fromclause is self
def replace_selectable(self, old, alias):
"""replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
columns = property(lambda s:[])
- def _copy_internals(self):
- self.bindparams = dict([(b.key, b._clone()) for b in self.bindparams.values()])
+ def _copy_internals(self, clone=_clone):
+ self.bindparams = dict([(b.key, clone(b)) for b in self.bindparams.values()])
def get_children(self, **kwargs):
return self.bindparams.values()
else:
self.clauses.append(_literal_as_text(clause))
- def _copy_internals(self):
- self.clauses = [clause._clone() for clause in self.clauses]
+ def _copy_internals(self, clone=_clone):
+ self.clauses = [clone(clause) for clause in self.clauses]
def get_children(self, **kwargs):
return self.clauses
key = property(lambda self:self.name or "_calc_")
- def _copy_internals(self):
- self.clause_expr = self.clause_expr._clone()
+ def _copy_internals(self, clone=_clone):
+ self.clause_expr = clone(self.clause_expr)
def clauses(self):
if isinstance(self.clause_expr, _Grouping):
key = property(lambda self:self.name)
columns = property(lambda self:[self])
- def _copy_internals(self):
- _CalculatedClause._copy_internals(self)
+ def _copy_internals(self, clone=_clone):
+ _CalculatedClause._copy_internals(self, clone=clone)
self._clone_from_clause()
def get_children(self, **kwargs):
self.typeclause = _TypeClause(self.type)
self._distance = 0
- def _copy_internals(self):
- self.clause = self.clause._clone()
- self.typeclause = self.typeclause._clone()
+ def _copy_internals(self, clone=_clone):
+ self.clause = clone(self.clause)
+ self.typeclause = clone(self.typeclause)
def get_children(self, **kwargs):
return self.clause, self.typeclause
def _get_from_objects(self, **modifiers):
return self.element._get_from_objects(**modifiers)
- def _copy_internals(self):
- self.element = self.element._clone()
+ def _copy_internals(self, clone=_clone):
+ self.element = clone(self.element)
def get_children(self, **kwargs):
return self.element,
def _get_from_objects(self, **modifiers):
return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
- def _copy_internals(self):
- self.left = self.left._clone()
- self.right = self.right._clone()
+ def _copy_internals(self, clone=_clone):
+ self.left = clone(self.left)
+ self.right = clone(self.right)
def get_children(self, **kwargs):
return self.left, self.right
self._foreign_keys.add(f)
return column
- def _copy_internals(self):
+ def _copy_internals(self, clone=_clone):
self._clone_from_clause()
- self.left = self.left._clone()
- self.right = self.right._clone()
- self.onclause = self.onclause._clone()
+ self.left = clone(self.left)
+ self.right = clone(self.right)
+ self.onclause = clone(self.onclause)
self.__folded_equivalents = None
self._init_primary_key()
self.oid_column = None
def is_derived_from(self, fromclause):
- x = self.selectable
- while True:
- if x is fromclause:
- return True
- if isinstance(x, Alias):
- x = x.selectable
- else:
- break
- return False
+ return self.selectable.is_derived_from(fromclause)
def supports_execution(self):
return self.original.supports_execution()
#return self.selectable._exportable_columns()
return self.selectable.columns
- def _copy_internals(self):
- self._clone_from_clause()
- self.selectable = self.selectable._clone()
- baseselectable = self.selectable
- while isinstance(baseselectable, Alias):
- baseselectable = baseselectable.selectable
- self.original = baseselectable
+ def _clone(self):
+ # Alias is immutable
+ return self
+
+ def _copy_internals(self, clone=_clone):
+ pass
def get_children(self, **kwargs):
for c in self.c:
key = property(lambda s: s.elem.key)
_label = property(lambda s: s.elem._label)
- def _copy_internals(self):
- self.elem = self.elem._clone()
+ def _copy_internals(self, clone=_clone):
+ self.elem = clone(self.elem)
def get_children(self, **kwargs):
return self.elem,
def _hide_froms(self, **modifiers):
return self.elem._hide_froms(**modifiers)
- def _copy_internals(self):
- self.elem = self.elem._clone()
+ def _copy_internals(self, clone=_clone):
+ self.elem = clone(self.elem)
def _get_from_objects(self, **modifiers):
return self.elem._get_from_objects(**modifiers)
def expression_element(self):
return self.obj
- def _copy_internals(self):
- self.obj = self.obj._clone()
+ def _copy_internals(self, clone=_clone):
+ self.obj = clone(self.obj)
def get_children(self, **kwargs):
return self.obj,
col.orig_set = colset
return col
- def _copy_internals(self):
+ def _copy_internals(self, clone=_clone):
self._clone_from_clause()
self._col_map = {}
- self.selects = [s._clone() for s in self.selects]
+ self.selects = [clone(s) for s in self.selects]
for attr in ('_order_by_clause', '_group_by_clause'):
if getattr(self, attr) is not None:
- setattr(self, attr, getattr(self, attr)._clone())
+ setattr(self, attr, clone(getattr(self, attr)))
def get_children(self, column_collections=True, **kwargs):
return (column_collections and list(self.c) or []) + \
inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""")
- def _copy_internals(self):
+ def is_derived_from(self, fromclause):
+ for f in self.locate_all_froms():
+ if f.is_derived_from(fromclause):
+ return True
+ return False
+
+ def _copy_internals(self, clone=_clone):
self._clone_from_clause()
- self._raw_columns = [c._clone() for c in self._raw_columns]
- self._recorrelate_froms([(f, f._clone()) for f in self._froms])
+ self._raw_columns = [clone(c) for c in self._raw_columns]
+ self._recorrelate_froms([(f, clone(f)) for f in self._froms])
for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
if getattr(self, attr) is not None:
- setattr(self, attr, getattr(self, attr)._clone())
+ setattr(self, attr, clone(getattr(self, attr)))
def get_children(self, column_collections=True, **kwargs):
"""return child elements as per the ClauseElement specification."""
else:
return ()
- def _copy_internals(self):
+ def _copy_internals(self, clone=_clone):
self.parameters = self.parameters.copy()
def values(self, v):
else:
return ()
- def _copy_internals(self):
- self._whereclause = self._whereclause._clone()
+ def _copy_internals(self, clone=_clone):
+ self._whereclause = clone(self._whereclause)
self.parameters = self.parameters.copy()
def values(self, v):
else:
return ()
- def _copy_internals(self):
- self._whereclause = self._whereclause._clone()
+ def _copy_internals(self, clone=_clone):
+ self._whereclause = clone(self._whereclause)
class _IdentifiedClause(ClauseElement):
def __init__(self, ident):
"""Utility functions that build upon SQL and Schema constructs."""
-
class TableCollection(object):
def __init__(self, tables=None):
self.tables = tables or []
if self.selectable.c.get(column.key) is column:
self.result = True
-class AbstractClauseProcessor(visitors.NoColumnVisitor):
- """Traverse a clause and attempt to convert the contents of container elements
- to a converted element.
-
- The conversion operation is defined by subclasses.
+class AbstractClauseProcessor(object):
+ """Traverse and copy a ClauseElement, replacing selected elements based on rules.
+
+ This class implements its own visit-and-copy strategy but maintains the
+ same public interface as visitors.ClauseVisitor.
"""
-
+
+ __traverse_options__ = {'column_collections':False}
+
def convert_element(self, elem):
"""Define the *conversion* method for this ``AbstractClauseProcessor``."""
raise NotImplementedError()
- def copy_and_process(self, list_):
- """Copy the container elements in the given list to a new list and
- process the new list.
- """
-
+ def chain(self, visitor):
+ # chaining AbstractClauseProcessor and other ClauseVisitor
+ # objects separately. All the ACP objects are chained on
+ # their convert_element() method whereas regular visitors
+ # chain on their visit_XXX methods.
+ if isinstance(visitor, AbstractClauseProcessor):
+ attr = '_next_acp'
+ else:
+ attr = '_next'
+
+ tail = self
+ while getattr(tail, attr, None) is not None:
+ tail = getattr(tail, attr)
+ setattr(tail, attr, visitor)
+ return self
+
+ def copy_and_process(self, list_, stop_on=None):
+ """Copy the given list to a new list, with each element traversed individually."""
+
list_ = list(list_)
- self.process_list(list_)
+ stop_on = util.Set()
+ for i in range(0, len(list_)):
+ list_[i] = self.traverse(list_[i], stop_on=stop_on)
return list_
- def process_list(self, list_):
- """Process all elements of the given list in-place."""
-
- for i in range(0, len(list_)):
- elem = self.convert_element(list_[i])
- if elem is not None:
- list_[i] = elem
- else:
- list_[i] = self.traverse(list_[i], clone=True)
-
- def visit_grouping(self, grouping):
- elem = self.convert_element(grouping.elem)
- if elem is not None:
- grouping.elem = elem
+ def _convert_element(self, elem, stop_on):
+ v = self
+ while v is not None:
+ newelem = v.convert_element(elem)
+ if newelem:
+ stop_on.add(newelem)
+ return newelem
+ v = getattr(v, '_next_acp', None)
+ return elem._clone()
+
+ def traverse(self, elem, clone=True, stop_on=None, _clone_toplevel=True):
+ if not clone:
+ raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
- def visit_clauselist(self, clist):
- for i in range(0, len(clist.clauses)):
- n = self.convert_element(clist.clauses[i])
- if n is not None:
- clist.clauses[i] = n
-
- def visit_unary(self, unary):
- elem = self.convert_element(unary.element)
- if elem is not None:
- unary.element = elem
+ if stop_on is None:
+ stop_on = util.Set()
- def visit_binary(self, binary):
- elem = self.convert_element(binary.left)
- if elem is not None:
- binary.left = elem
- elem = self.convert_element(binary.right)
- if elem is not None:
- binary.right = elem
-
- def visit_join(self, join):
- elem = self.convert_element(join.left)
- if elem is not None:
- join.left = elem
- elem = self.convert_element(join.right)
- if elem is not None:
- join.right = elem
- join._init_primary_key()
+ if elem in stop_on:
+ return elem
+
+ if _clone_toplevel:
+ elem = self._convert_element(elem, stop_on)
+ if elem in stop_on:
+ return elem
- def visit_select(self, select):
- fr = util.OrderedSet()
- for elem in select._froms:
- n = self.convert_element(elem)
- if n is not None:
- fr.add((elem, n))
- select._recorrelate_froms(fr)
-
- col = []
- for elem in select._raw_columns:
- n = self.convert_element(elem)
- if n is None:
- col.append(elem)
- else:
- col.append(n)
- select._raw_columns = col
-
+ def clone(element):
+ return self._convert_element(element, stop_on)
+ elem._copy_internals(clone=clone)
+
+ v = getattr(self, '_next', None)
+ while v is not None:
+ meth = getattr(v, "visit_%s" % elem.__visit_name__, None)
+ if meth:
+ meth(elem)
+ v = getattr(v, '_next', None)
+
+ for e in elem.get_children(**self.__traverse_options__):
+ if e not in stop_on:
+ self.traverse(e, stop_on=stop_on, _clone_toplevel=False)
+ return elem
+
class ClauseAdapter(AbstractClauseProcessor):
"""Given a clause (like as in a WHERE criterion), locate columns
which are embedded within a given selectable, and changes those
newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False)
if newcol:
return newcol
- #if newcol is None:
- # self.traverse(col)
- # return col
return newcol
class ClauseVisitor(object):
- """A class that knows how to traverse and visit
- ``ClauseElements``.
+ """Traverses and visits ``ClauseElement`` structures.
- Calls visit_XXX() methods dynamically generated for each particualr
+ Calls visit_XXX() methods dynamically generated for each particular
``ClauseElement`` subclass encountered. Traversal of a
hierarchy of ``ClauseElements`` is achieved via the
``traverse()`` method, which is passed the lead
traversal.insert(0, t)
for c in t.get_children(**self.__traverse_options__):
stack.append(c)
-
+
def traverse(self, obj, stop_on=None, clone=False):
if clone:
obj = obj._clone()
return self
class NoColumnVisitor(ClauseVisitor):
- """a ClauseVisitor that will not traverse the exported Column
- collections on Table, Alias, Select, and CompoundSelect objects
- (i.e. their 'columns' or 'c' attribute).
+ """ClauseVisitor with 'column_collections' set to False; will not
+ traverse the front-facing Column collections on Table, Alias, Select,
+ and CompoundSelect objects.
- this is useful because most traversals don't need those columns, or
- in the case of DefaultCompiler it traverses them explicitly; so
- skipping their traversal here greatly cuts down on method call overhead.
"""
__traverse_options__ = {'column_collections':False}
t1.update().compile()
# TODO: this is alittle high
- @profiling.profiled('ctest_select', call_range=(190, 210), always=True)
+ @profiling.profiled('ctest_select', call_range=(170, 200), always=True)
def test_select(self):
s = select([t1], t1.c.c2==t2.c.c1)
s.compile()
tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8)
@testing.supported('postgres')
- @profiling.profiled('properties', call_range=(3030, 3430), always=True)
+ @profiling.profiled('properties', call_range=(2900, 3330), always=True)
def test_3_properties(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']
ticks = fullobject(Animal.select(Animal.c.Species=='Tick'))
@testing.supported('postgres')
- @profiling.profiled('expressions', call_range=(11350, 13200), always=True)
+ @profiling.profiled('expressions', call_range=(10350, 12200), always=True)
def test_4_expressions(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']
assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1
@testing.supported('postgres')
- @profiling.profiled('aggregates', call_range=(1000, 1270), always=True)
+ @profiling.profiled('aggregates', call_range=(960, 1170), always=True)
def test_5_aggregates(self):
Animal = metadata.tables['Animal']
Zoo = metadata.tables['Zoo']
legs.sort()
@testing.supported('postgres')
- @profiling.profiled('editing', call_range=(1280, 1390), always=True)
+ @profiling.profiled('editing', call_range=(1200, 1290), always=True)
def test_6_editing(self):
Zoo = metadata.tables['Zoo']
assert SDZ['Founded'] == datetime.date(1935, 9, 13)
@testing.supported('postgres')
- @profiling.profiled('multiview', call_range=(2820, 3155), always=True)
+ @profiling.profiled('multiview', call_range=(2720, 3055), always=True)
def test_7_multiview(self):
Zoo = metadata.tables['Zoo']
Animal = metadata.tables['Animal']
from testlib import *
from sqlalchemy.sql.visitors import *
from sqlalchemy import util
+from sqlalchemy.sql import util as sql_util
+
class TraversalTest(AssertMixin):
"""test ClauseVisitor's traversal, particularly its ability to copy and modify
s3 = vis2.traverse(struct, clone=True)
assert struct != s3
assert struct3 == s3
-
+
+
class ClauseTest(SQLCompileTest):
"""test copy-in-place behavior of various ClauseElements."""
self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2")
def test_clause_adapter(self):
- from sqlalchemy.sql import util as sql_util
t1alias = t1.alias('t1alias')
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+
+ def test_joins(self):
+ """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after
+ replacing."""
+
+ metadata = MetaData()
+ a = Table('a', metadata,
+ Column('id', Integer, primary_key=True))
+ b = Table('b', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('aid', Integer, ForeignKey('a.id')),
+ )
+ c = Table('c', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('bid', Integer, ForeignKey('b.id')),
+ )
+
+ d = Table('d', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('aid', Integer, ForeignKey('a.id')),
+ )
+
+ j1 = a.outerjoin(b)
+ j2 = select([j1], use_labels=True)
+
+ j3 = c.join(j2, j2.c.b_id==c.c.bid)
+
+ j4 = j3.outerjoin(d)
+ self.assert_compile(j4, "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) "
+ "ON b_id = c.bid"
+ " LEFT OUTER JOIN d ON a_id = d.aid")
+ j5 = j3.alias('foo')
+ j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0]
+ # this statement takes c join(a join b), wraps it inside an aliased "select * from c join(a join b) AS foo".
+ # the outermost right side "left outer join d" stays the same, except "d" joins against foo.a_id instead
+ # of plain "a_id"
+ self.assert_compile(j6, "(SELECT c.id AS c_id, c.bid AS c_bid, a_id AS a_id, b_id AS b_id, b_aid AS b_aid FROM "
+ "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) "
+ "ON b_id = c.bid) AS foo"
+ " LEFT OUTER JOIN d ON foo.a_id = d.aid")
class SelectTest(SQLCompileTest):
assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :b_x", str(j)
assert list(j.primary_key) == [a.c.id, b.c.x]
+class DerivedTest(AssertMixin):
+ def test_table(self):
+ meta = MetaData()
+ t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+ t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+
+ assert t1.is_derived_from(t1)
+ assert not t2.is_derived_from(t1)
+
+ def test_alias(self):
+ meta = MetaData()
+ t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+ t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+
+ assert t1.alias().is_derived_from(t1)
+ assert not t2.alias().is_derived_from(t1)
+ assert not t1.is_derived_from(t1.alias())
+ assert not t1.is_derived_from(t2.alias())
+
+ def test_select(self):
+ meta = MetaData()
+ t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+ t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30)))
+
+ assert t1.select().is_derived_from(t1)
+ assert not t2.select().is_derived_from(t1)
+
+ assert select([t1, t2]).is_derived_from(t1)
+
+ assert t1.select().alias('foo').is_derived_from(t1)
+ assert select([t1, t2]).alias('foo').is_derived_from(t1)
+ assert not t2.select().alias('foo').is_derived_from(t1)
+
if __name__ == "__main__":
testbase.main()