From: Mike Bayer Date: Sat, 3 Nov 2007 22:13:17 +0000 (+0000) Subject: - rewritten ClauseAdapter merged from the eager_minus_join branch; this is a much... X-Git-Tag: rel_0_4_1~65 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=0af3f8f35b5e46f749d328e6fae90f6ff4915e97;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - rewritten ClauseAdapter merged from the eager_minus_join branch; this is a much simpler and "correct" version which will copy all elements exactly once, except for those which were replaced with target elements. It also can match a wider variety of target elements including joins and selects on identity alone. --- diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index f4f8aa689f..0197900a58 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -644,7 +644,6 @@ class OracleCompiler(compiler.DefaultCompiler): orderby = self.process(select._order_by_clause) if not orderby: orderby = select.oid_column - self.traverse(orderby) orderby = self.process(orderby) oldselect = select diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 660d546047..bd82f897d9 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -891,7 +891,7 @@ class Connection(Connectable): executors = { expression._Function : _execute_function, expression.ClauseElement : _execute_clauseelement, - visitors.ClauseVisitor : _execute_compiled, + Compiled : _execute_compiled, schema.SchemaItem:_execute_default, str.__mro__[-2] : _execute_text } diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 09a3a0f5b7..b6200fee52 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -800,7 +800,7 @@ class Query(object): # adapt the given WHERECLAUSE to adjust instances of this query's mapped # table to be that of our select_table, # which may be the "polymorphic" selectable used by our mapper. - sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table])) + whereclause = sql_util.ClauseAdapter(self.table).traverse(whereclause, stop_on=util.Set([self.table])) # if extra entities, adapt the criterion to those as well for m in self._entities: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fa4ac5a9f4..ef66ffd5a6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -87,10 +87,11 @@ OPERATORS = { operators.isnot : 'IS NOT' } -class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): +class DefaultCompiler(engine.Compiled): """Default implementation of Compiled. - Compiles ClauseElements into SQL strings. + Compiles ClauseElements into SQL strings. Uses a similar visit + paradigm as visitors.ClauseVisitor but implements its own traversal. """ __traverse_options__ = {'column_collections':False, 'entry':True} @@ -163,7 +164,9 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): if stack: self.stack.append(stack) try: - return self.traverse_single(obj, **kwargs) + meth = getattr(self, "visit_%s" % obj.__visit_name__, None) + if meth: + return meth(obj, **kwargs) finally: if stack: self.stack.pop(-1) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 49dbc143a8..67c1b727a9 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -775,7 +775,9 @@ func = _FunctionGenerator() # TODO: use UnaryExpression for this instead ? modifier = _FunctionGenerator(group=False) - +def _clone(element): + return element._clone() + def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) @@ -908,7 +910,7 @@ class ClauseElement(object): return self is other - def _copy_internals(self): + def _copy_internals(self, clone=_clone): """Reassign internal elements to be clones of themselves. Called during a copy-and-traverse operation on newly @@ -1580,8 +1582,7 @@ class FromClause(Selectable): An example would be an Alias of a Table is derived from that Table. """ - - return False + return fromclause is self def replace_selectable(self, old, alias): """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" @@ -1874,8 +1875,8 @@ class _TextClause(ClauseElement): columns = property(lambda s:[]) - def _copy_internals(self): - self.bindparams = dict([(b.key, b._clone()) for b in self.bindparams.values()]) + def _copy_internals(self, clone=_clone): + self.bindparams = dict([(b.key, clone(b)) for b in self.bindparams.values()]) def get_children(self, **kwargs): return self.bindparams.values() @@ -1933,8 +1934,8 @@ class ClauseList(ClauseElement): else: self.clauses.append(_literal_as_text(clause)) - def _copy_internals(self): - self.clauses = [clause._clone() for clause in self.clauses] + def _copy_internals(self, clone=_clone): + self.clauses = [clone(clause) for clause in self.clauses] def get_children(self, **kwargs): return self.clauses @@ -1989,8 +1990,8 @@ class _CalculatedClause(ColumnElement): key = property(lambda self:self.name or "_calc_") - def _copy_internals(self): - self.clause_expr = self.clause_expr._clone() + def _copy_internals(self, clone=_clone): + self.clause_expr = clone(self.clause_expr) def clauses(self): if isinstance(self.clause_expr, _Grouping): @@ -2038,8 +2039,8 @@ class _Function(_CalculatedClause, FromClause): key = property(lambda self:self.name) columns = property(lambda self:[self]) - def _copy_internals(self): - _CalculatedClause._copy_internals(self) + def _copy_internals(self, clone=_clone): + _CalculatedClause._copy_internals(self, clone=clone) self._clone_from_clause() def get_children(self, **kwargs): @@ -2059,9 +2060,9 @@ class _Cast(ColumnElement): self.typeclause = _TypeClause(self.type) self._distance = 0 - def _copy_internals(self): - self.clause = self.clause._clone() - self.typeclause = self.typeclause._clone() + def _copy_internals(self, clone=_clone): + self.clause = clone(self.clause) + self.typeclause = clone(self.typeclause) def get_children(self, **kwargs): return self.clause, self.typeclause @@ -2092,8 +2093,8 @@ class _UnaryExpression(ColumnElement): def _get_from_objects(self, **modifiers): return self.element._get_from_objects(**modifiers) - def _copy_internals(self): - self.element = self.element._clone() + def _copy_internals(self, clone=_clone): + self.element = clone(self.element) def get_children(self, **kwargs): return self.element, @@ -2134,9 +2135,9 @@ class _BinaryExpression(ColumnElement): def _get_from_objects(self, **modifiers): return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers) - def _copy_internals(self): - self.left = self.left._clone() - self.right = self.right._clone() + def _copy_internals(self, clone=_clone): + self.left = clone(self.left) + self.right = clone(self.right) def get_children(self, **kwargs): return self.left, self.right @@ -2265,11 +2266,11 @@ class Join(FromClause): self._foreign_keys.add(f) return column - def _copy_internals(self): + def _copy_internals(self, clone=_clone): self._clone_from_clause() - self.left = self.left._clone() - self.right = self.right._clone() - self.onclause = self.onclause._clone() + self.left = clone(self.left) + self.right = clone(self.right) + self.onclause = clone(self.onclause) self.__folded_equivalents = None self._init_primary_key() @@ -2414,15 +2415,7 @@ class Alias(FromClause): self.oid_column = None def is_derived_from(self, fromclause): - x = self.selectable - while True: - if x is fromclause: - return True - if isinstance(x, Alias): - x = x.selectable - else: - break - return False + return self.selectable.is_derived_from(fromclause) def supports_execution(self): return self.original.supports_execution() @@ -2437,13 +2430,12 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns - def _copy_internals(self): - self._clone_from_clause() - self.selectable = self.selectable._clone() - baseselectable = self.selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable - self.original = baseselectable + def _clone(self): + # Alias is immutable + return self + + def _copy_internals(self, clone=_clone): + pass def get_children(self, **kwargs): for c in self.c: @@ -2469,8 +2461,8 @@ class _ColumnElementAdapter(ColumnElement): key = property(lambda s: s.elem.key) _label = property(lambda s: s.elem._label) - def _copy_internals(self): - self.elem = self.elem._clone() + def _copy_internals(self, clone=_clone): + self.elem = clone(self.elem) def get_children(self, **kwargs): return self.elem, @@ -2503,8 +2495,8 @@ class _FromGrouping(FromClause): def _hide_froms(self, **modifiers): return self.elem._hide_froms(**modifiers) - def _copy_internals(self): - self.elem = self.elem._clone() + def _copy_internals(self, clone=_clone): + self.elem = clone(self.elem) def _get_from_objects(self, **modifiers): return self.elem._get_from_objects(**modifiers) @@ -2538,8 +2530,8 @@ class _Label(ColumnElement): def expression_element(self): return self.obj - def _copy_internals(self): - self.obj = self.obj._clone() + def _copy_internals(self, clone=_clone): + self.obj = clone(self.obj) def get_children(self, **kwargs): return self.obj, @@ -2935,13 +2927,13 @@ class CompoundSelect(_SelectBaseMixin, FromClause): col.orig_set = colset return col - def _copy_internals(self): + def _copy_internals(self, clone=_clone): self._clone_from_clause() self._col_map = {} - self.selects = [s._clone() for s in self.selects] + self.selects = [clone(s) for s in self.selects] for attr in ('_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: - setattr(self, attr, getattr(self, attr)._clone()) + setattr(self, attr, clone(getattr(self, attr))) def get_children(self, column_collections=True, **kwargs): return (column_collections and list(self.c) or []) + \ @@ -3091,13 +3083,19 @@ class Select(_SelectBaseMixin, FromClause): inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""") - def _copy_internals(self): + def is_derived_from(self, fromclause): + for f in self.locate_all_froms(): + if f.is_derived_from(fromclause): + return True + return False + + def _copy_internals(self, clone=_clone): self._clone_from_clause() - self._raw_columns = [c._clone() for c in self._raw_columns] - self._recorrelate_froms([(f, f._clone()) for f in self._froms]) + self._raw_columns = [clone(c) for c in self._raw_columns] + self._recorrelate_froms([(f, clone(f)) for f in self._froms]) for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: - setattr(self, attr, getattr(self, attr)._clone()) + setattr(self, attr, clone(getattr(self, attr))) def get_children(self, column_collections=True, **kwargs): """return child elements as per the ClauseElement specification.""" @@ -3394,7 +3392,7 @@ class Insert(_UpdateBase): else: return () - def _copy_internals(self): + def _copy_internals(self, clone=_clone): self.parameters = self.parameters.copy() def values(self, v): @@ -3423,8 +3421,8 @@ class Update(_UpdateBase): else: return () - def _copy_internals(self): - self._whereclause = self._whereclause._clone() + def _copy_internals(self, clone=_clone): + self._whereclause = clone(self._whereclause) self.parameters = self.parameters.copy() def values(self, v): @@ -3449,8 +3447,8 @@ class Delete(_UpdateBase): else: return () - def _copy_internals(self): - self._whereclause = self._whereclause._clone() + def _copy_internals(self, clone=_clone): + self._whereclause = clone(self._whereclause) class _IdentifiedClause(ClauseElement): def __init__(self, ident): diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 8876f42baa..eed06cfc32 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -3,7 +3,6 @@ from sqlalchemy.sql import expression, visitors """Utility functions that build upon SQL and Schema constructs.""" - class TableCollection(object): def __init__(self, tables=None): self.tables = tables or [] @@ -110,87 +109,86 @@ class ColumnsInClause(visitors.ClauseVisitor): if self.selectable.c.get(column.key) is column: self.result = True -class AbstractClauseProcessor(visitors.NoColumnVisitor): - """Traverse a clause and attempt to convert the contents of container elements - to a converted element. - - The conversion operation is defined by subclasses. +class AbstractClauseProcessor(object): + """Traverse and copy a ClauseElement, replacing selected elements based on rules. + + This class implements its own visit-and-copy strategy but maintains the + same public interface as visitors.ClauseVisitor. """ - + + __traverse_options__ = {'column_collections':False} + def convert_element(self, elem): """Define the *conversion* method for this ``AbstractClauseProcessor``.""" raise NotImplementedError() - def copy_and_process(self, list_): - """Copy the container elements in the given list to a new list and - process the new list. - """ - + def chain(self, visitor): + # chaining AbstractClauseProcessor and other ClauseVisitor + # objects separately. All the ACP objects are chained on + # their convert_element() method whereas regular visitors + # chain on their visit_XXX methods. + if isinstance(visitor, AbstractClauseProcessor): + attr = '_next_acp' + else: + attr = '_next' + + tail = self + while getattr(tail, attr, None) is not None: + tail = getattr(tail, attr) + setattr(tail, attr, visitor) + return self + + def copy_and_process(self, list_, stop_on=None): + """Copy the given list to a new list, with each element traversed individually.""" + list_ = list(list_) - self.process_list(list_) + stop_on = util.Set() + for i in range(0, len(list_)): + list_[i] = self.traverse(list_[i], stop_on=stop_on) return list_ - def process_list(self, list_): - """Process all elements of the given list in-place.""" - - for i in range(0, len(list_)): - elem = self.convert_element(list_[i]) - if elem is not None: - list_[i] = elem - else: - list_[i] = self.traverse(list_[i], clone=True) - - def visit_grouping(self, grouping): - elem = self.convert_element(grouping.elem) - if elem is not None: - grouping.elem = elem + def _convert_element(self, elem, stop_on): + v = self + while v is not None: + newelem = v.convert_element(elem) + if newelem: + stop_on.add(newelem) + return newelem + v = getattr(v, '_next_acp', None) + return elem._clone() + + def traverse(self, elem, clone=True, stop_on=None, _clone_toplevel=True): + if not clone: + raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True") - def visit_clauselist(self, clist): - for i in range(0, len(clist.clauses)): - n = self.convert_element(clist.clauses[i]) - if n is not None: - clist.clauses[i] = n - - def visit_unary(self, unary): - elem = self.convert_element(unary.element) - if elem is not None: - unary.element = elem + if stop_on is None: + stop_on = util.Set() - def visit_binary(self, binary): - elem = self.convert_element(binary.left) - if elem is not None: - binary.left = elem - elem = self.convert_element(binary.right) - if elem is not None: - binary.right = elem - - def visit_join(self, join): - elem = self.convert_element(join.left) - if elem is not None: - join.left = elem - elem = self.convert_element(join.right) - if elem is not None: - join.right = elem - join._init_primary_key() + if elem in stop_on: + return elem + + if _clone_toplevel: + elem = self._convert_element(elem, stop_on) + if elem in stop_on: + return elem - def visit_select(self, select): - fr = util.OrderedSet() - for elem in select._froms: - n = self.convert_element(elem) - if n is not None: - fr.add((elem, n)) - select._recorrelate_froms(fr) - - col = [] - for elem in select._raw_columns: - n = self.convert_element(elem) - if n is None: - col.append(elem) - else: - col.append(n) - select._raw_columns = col - + def clone(element): + return self._convert_element(element, stop_on) + elem._copy_internals(clone=clone) + + v = getattr(self, '_next', None) + while v is not None: + meth = getattr(v, "visit_%s" % elem.__visit_name__, None) + if meth: + meth(elem) + v = getattr(v, '_next', None) + + for e in elem.get_children(**self.__traverse_options__): + if e not in stop_on: + self.traverse(e, stop_on=stop_on, _clone_toplevel=False) + return elem + class ClauseAdapter(AbstractClauseProcessor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those @@ -243,9 +241,6 @@ class ClauseAdapter(AbstractClauseProcessor): newcol = self.selectable.corresponding_column(equiv, raiseerr=False, require_embedded=True, keys_ok=False) if newcol: return newcol - #if newcol is None: - # self.traverse(col) - # return col return newcol diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 98e4de6c33..bf15c2b7ee 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,8 +1,7 @@ class ClauseVisitor(object): - """A class that knows how to traverse and visit - ``ClauseElements``. + """Traverses and visits ``ClauseElement`` structures. - Calls visit_XXX() methods dynamically generated for each particualr + Calls visit_XXX() methods dynamically generated for each particular ``ClauseElement`` subclass encountered. Traversal of a hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` method, which is passed the lead @@ -40,7 +39,7 @@ class ClauseVisitor(object): traversal.insert(0, t) for c in t.get_children(**self.__traverse_options__): stack.append(c) - + def traverse(self, obj, stop_on=None, clone=False): if clone: obj = obj._clone() @@ -75,13 +74,10 @@ class ClauseVisitor(object): return self class NoColumnVisitor(ClauseVisitor): - """a ClauseVisitor that will not traverse the exported Column - collections on Table, Alias, Select, and CompoundSelect objects - (i.e. their 'columns' or 'c' attribute). + """ClauseVisitor with 'column_collections' set to False; will not + traverse the front-facing Column collections on Table, Alias, Select, + and CompoundSelect objects. - this is useful because most traversals don't need those columns, or - in the case of DefaultCompiler it traverses them explicitly; so - skipping their traversal here greatly cuts down on method call overhead. """ __traverse_options__ = {'column_collections':False} diff --git a/test/profiling/compiler.py b/test/profiling/compiler.py index 470b45cb89..29e17db778 100644 --- a/test/profiling/compiler.py +++ b/test/profiling/compiler.py @@ -24,7 +24,7 @@ class CompileTest(AssertMixin): t1.update().compile() # TODO: this is alittle high - @profiling.profiled('ctest_select', call_range=(190, 210), always=True) + @profiling.profiled('ctest_select', call_range=(170, 200), always=True) def test_select(self): s = select([t1], t1.c.c2==t2.c.c1) s.compile() diff --git a/test/profiling/zoomark.py b/test/profiling/zoomark.py index 6eb313b429..833aec46c6 100644 --- a/test/profiling/zoomark.py +++ b/test/profiling/zoomark.py @@ -126,7 +126,7 @@ class ZooMarkTest(testing.AssertMixin): tick = i.execute(Species='Tick', Name='Tick %d' % x, Legs=8) @testing.supported('postgres') - @profiling.profiled('properties', call_range=(3030, 3430), always=True) + @profiling.profiled('properties', call_range=(2900, 3330), always=True) def test_3_properties(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal'] @@ -149,7 +149,7 @@ class ZooMarkTest(testing.AssertMixin): ticks = fullobject(Animal.select(Animal.c.Species=='Tick')) @testing.supported('postgres') - @profiling.profiled('expressions', call_range=(11350, 13200), always=True) + @profiling.profiled('expressions', call_range=(10350, 12200), always=True) def test_4_expressions(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal'] @@ -203,7 +203,7 @@ class ZooMarkTest(testing.AssertMixin): assert len(fulltable(Animal.select(func.date_part('day', Animal.c.LastEscape) == 21))) == 1 @testing.supported('postgres') - @profiling.profiled('aggregates', call_range=(1000, 1270), always=True) + @profiling.profiled('aggregates', call_range=(960, 1170), always=True) def test_5_aggregates(self): Animal = metadata.tables['Animal'] Zoo = metadata.tables['Zoo'] @@ -245,7 +245,7 @@ class ZooMarkTest(testing.AssertMixin): legs.sort() @testing.supported('postgres') - @profiling.profiled('editing', call_range=(1280, 1390), always=True) + @profiling.profiled('editing', call_range=(1200, 1290), always=True) def test_6_editing(self): Zoo = metadata.tables['Zoo'] @@ -274,7 +274,7 @@ class ZooMarkTest(testing.AssertMixin): assert SDZ['Founded'] == datetime.date(1935, 9, 13) @testing.supported('postgres') - @profiling.profiled('multiview', call_range=(2820, 3155), always=True) + @profiling.profiled('multiview', call_range=(2720, 3055), always=True) def test_7_multiview(self): Zoo = metadata.tables['Zoo'] Animal = metadata.tables['Animal'] diff --git a/test/sql/generative.py b/test/sql/generative.py index 437f0874ef..7892732d6a 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -4,6 +4,8 @@ from sqlalchemy.sql import table, column, ClauseElement from testlib import * from sqlalchemy.sql.visitors import * from sqlalchemy import util +from sqlalchemy.sql import util as sql_util + class TraversalTest(AssertMixin): """test ClauseVisitor's traversal, particularly its ability to copy and modify @@ -133,7 +135,8 @@ class TraversalTest(AssertMixin): s3 = vis2.traverse(struct, clone=True) assert struct != s3 assert struct3 == s3 - + + class ClauseTest(SQLCompileTest): """test copy-in-place behavior of various ClauseElements.""" @@ -230,7 +233,6 @@ class ClauseTest(SQLCompileTest): self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2") def test_clause_adapter(self): - from sqlalchemy.sql import util as sql_util t1alias = t1.alias('t1alias') @@ -257,7 +259,47 @@ class ClauseTest(SQLCompileTest): self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2") self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2") + + def test_joins(self): + """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after + replacing.""" + + metadata = MetaData() + a = Table('a', metadata, + Column('id', Integer, primary_key=True)) + b = Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + c = Table('c', metadata, + Column('id', Integer, primary_key=True), + Column('bid', Integer, ForeignKey('b.id')), + ) + + d = Table('d', metadata, + Column('id', Integer, primary_key=True), + Column('aid', Integer, ForeignKey('a.id')), + ) + + j1 = a.outerjoin(b) + j2 = select([j1], use_labels=True) + + j3 = c.join(j2, j2.c.b_id==c.c.bid) + + j4 = j3.outerjoin(d) + self.assert_compile(j4, "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) " + "ON b_id = c.bid" + " LEFT OUTER JOIN d ON a_id = d.aid") + j5 = j3.alias('foo') + j6 = sql_util.ClauseAdapter(j5).copy_and_process([j4])[0] + # this statement takes c join(a join b), wraps it inside an aliased "select * from c join(a join b) AS foo". + # the outermost right side "left outer join d" stays the same, except "d" joins against foo.a_id instead + # of plain "a_id" + self.assert_compile(j6, "(SELECT c.id AS c_id, c.bid AS c_bid, a_id AS a_id, b_id AS b_id, b_aid AS b_aid FROM " + "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) " + "ON b_id = c.bid) AS foo" + " LEFT OUTER JOIN d ON foo.a_id = d.aid") class SelectTest(SQLCompileTest): diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 83203ad8e2..72f5f35d04 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -230,6 +230,39 @@ class PrimaryKeyTest(AssertMixin): assert str(j) == "a JOIN b ON a.id = b.id AND b.x = :b_x", str(j) assert list(j.primary_key) == [a.c.id, b.c.x] +class DerivedTest(AssertMixin): + def test_table(self): + meta = MetaData() + t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + + assert t1.is_derived_from(t1) + assert not t2.is_derived_from(t1) + + def test_alias(self): + meta = MetaData() + t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + + assert t1.alias().is_derived_from(t1) + assert not t2.alias().is_derived_from(t1) + assert not t1.is_derived_from(t1.alias()) + assert not t1.is_derived_from(t2.alias()) + + def test_select(self): + meta = MetaData() + t1 = Table('t1', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + t2 = Table('t2', meta, Column('c1', Integer, primary_key=True), Column('c2', String(30))) + + assert t1.select().is_derived_from(t1) + assert not t2.select().is_derived_from(t1) + + assert select([t1, t2]).is_derived_from(t1) + + assert t1.select().alias('foo').is_derived_from(t1) + assert select([t1, t2]).alias('foo').is_derived_from(t1) + assert not t2.select().alias('foo').is_derived_from(t1) + if __name__ == "__main__": testbase.main()