then re-used. the FROM calculation of a Select normalizes the list of hide_froms against all
previous incarnations of each FROM clause, using a tag attached from cloned clause to
previous.
"""
c = self.__class__.__new__(self.__class__)
c.__dict__ = self.__dict__.copy()
+
+ # this is a marker that helps to "equate" clauses to each other
+ # when a Select returns its list of FROM clauses. the cloning
+ # process leaves around a lot of remnants of the previous clause
+ # typically in the form of column expressions still attached to the
+ # old table.
+ c._is_clone_of = self
+
return c
def _get_from_objects(self, **modifiers):
self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
def description(self):
- return "Join object on %s and %s" % (self.left.description, self.right.description)
+ return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right))
description = property(description)
primary_key = property(lambda s:s.__primary_key)
#return self.selectable._exportable_columns()
return self.selectable.columns
- def _clone(self):
- # TODO: need test coverage to assert ClauseAdapter behavior
- # here; must identify non-ORM failure cases when a. _clone() returns 'self' in all
- # cases and b. when _clone() does an actual _clone() in all cases.
- if isinstance(self.selectable, TableClause):
- return self
- else:
- return super(Alias, self)._clone()
-
def _copy_internals(self, clone=_clone):
self._clone_from_clause()
self.selectable = _clone(self.selectable)
for col in self._raw_columns:
for f in col._hide_froms():
hide_froms.add(f)
+ while hasattr(f, '_is_clone_of'):
+ hide_froms.add(f._is_clone_of)
+ f = f._is_clone_of
for f in col._get_from_objects():
froms.add(f)
froms.add(elem)
for f in elem._get_from_objects():
froms.add(f)
-
+
for elem in froms:
for f in elem._hide_froms():
hide_froms.add(f)
-
+ while hasattr(f, '_is_clone_of'):
+ hide_froms.add(f._is_clone_of)
+ f = f._is_clone_of
+
froms = froms.difference(hide_froms)
-
+
if len(froms) > 1:
corr = self.__correlate
if self._should_correlate and existing_froms is not None:
corr = existing_froms.union(corr)
+
+ for f in list(corr):
+ while hasattr(f, '_is_clone_of'):
+ corr.add(f._is_clone_of)
+ f = f._is_clone_of
+
f = froms.difference(corr)
if len(f) == 0:
raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
def _copy_internals(self, clone=_clone):
self._clone_from_clause()
- self._raw_columns = [clone(c) for c in self._raw_columns]
self._recorrelate_froms([(f, clone(f)) for f in self._froms])
+ self._raw_columns = [clone(c) for c in self._raw_columns]
for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr)))
list_[i] = self.traverse(list_[i], stop_on=stop_on)
return list_
- def _convert_element(self, elem, stop_on):
+ def _convert_element(self, elem, stop_on, cloned):
v = self
while v is not None:
newelem = v.convert_element(elem)
stop_on.add(newelem)
return newelem
v = getattr(v, '_next_acp', None)
- return elem._clone()
- def traverse(self, elem, clone=True, stop_on=None, _clone_toplevel=True):
+ if elem not in cloned:
+ # the full traversal will only make a clone of a particular element
+ # once.
+ cloned[elem] = elem._clone()
+ return cloned[elem]
+
+ def traverse(self, elem, clone=True, stop_on=None):
if not clone:
raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
-
+
if stop_on is None:
stop_on = util.Set()
-
+ return self._traverse(elem, stop_on, {}, _clone_toplevel=True)
+
+ def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
if elem in stop_on:
return elem
if _clone_toplevel:
- elem = self._convert_element(elem, stop_on)
+ elem = self._convert_element(elem, stop_on, cloned)
if elem in stop_on:
return elem
def clone(element):
- return self._convert_element(element, stop_on)
+ return self._convert_element(element, stop_on, cloned)
elem._copy_internals(clone=clone)
v = getattr(self, '_next', None)
for e in elem.get_children(**self.__traverse_options__):
if e not in stop_on:
- self.traverse(e, stop_on=stop_on, _clone_toplevel=False)
+ self._traverse(e, stop_on, cloned)
return elem
class ClauseAdapter(AbstractClauseProcessor):
traversal.insert(0, t)
for c in t.get_children(**self.__traverse_options__):
stack.append(c)
-
+
def traverse(self, obj, stop_on=None, clone=False):
+
if clone:
- obj = obj._clone()
+ cloned = {}
+ def do_clone(obj):
+ # the full traversal will only make a clone of a particular element
+ # once.
+ if obj not in cloned:
+ cloned[obj] = obj._clone()
+ return cloned[obj]
+
+ obj = do_clone(obj)
stack = [obj]
traversal = []
if stop_on is None or t not in stop_on:
traversal.insert(0, t)
if clone:
- t._copy_internals()
+ t._copy_internals(clone=do_clone)
for c in t.get_children(**self.__traverse_options__):
stack.append(c)
for target in traversal:
import testbase
from sqlalchemy import *
from sqlalchemy.sql import table, column, ClauseElement
+from sqlalchemy.sql.expression import _clone
from testlib import *
from sqlalchemy.sql.visitors import *
from sqlalchemy import util
return True
return False
- def _copy_internals(self):
- self.items = [i._clone() for i in self.items]
+ def _copy_internals(self, clone=_clone):
+ self.items = [clone(i) for i in self.items]
def get_children(self, **kwargs):
return self.items
print str(s5)
assert str(s5) == s5_assert
assert str(s4) == s4_assert
-
+
+ def test_alias(self):
+ subq = t2.select().alias('subq')
+ s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)])
+ orig = str(s)
+ s2 = ClauseVisitor().traverse(s, clone=True)
+ assert orig == str(s) == str(s2)
+
+ s4 = ClauseVisitor().traverse(s2, clone=True)
+ assert orig == str(s) == str(s2) == str(s4)
+
+ s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True)
+ assert orig == str(s) == str(s3)
+
+ s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True)
+ assert orig == str(s) == str(s3) == str(s4)
+
def test_correlated_select(self):
s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
class Vis(ClauseVisitor):