From: Mike Bayer Date: Thu, 8 Nov 2007 18:06:21 +0000 (+0000) Subject: more changes to traverse-and-clone; a particular element will only be cloned once... X-Git-Tag: rel_0_4_1~40 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2cbb133567befca7e92f8e3bbc0aaae96b1781a8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git more changes to traverse-and-clone; a particular element will only be cloned once and is then re-used. the FROM calculation of a Select normalizes the list of hide_froms against all previous incarnations of each FROM clause, using a tag attached from cloned clause to previous. --- diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 6276f33bd9..e066632afb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -844,6 +844,14 @@ class ClauseElement(object): """ c = self.__class__.__new__(self.__class__) c.__dict__ = self.__dict__.copy() + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. + c._is_clone_of = self + return c def _get_from_objects(self, **modifiers): @@ -2212,7 +2220,7 @@ class Join(FromClause): self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit]) def description(self): - return "Join object on %s and %s" % (self.left.description, self.right.description) + return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right)) description = property(description) primary_key = property(lambda s:s.__primary_key) @@ -2394,15 +2402,6 @@ class Alias(FromClause): #return self.selectable._exportable_columns() return self.selectable.columns - def _clone(self): - # TODO: need test coverage to assert ClauseAdapter behavior - # here; must identify non-ORM failure cases when a. _clone() returns 'self' in all - # cases and b. when _clone() does an actual _clone() in all cases. - if isinstance(self.selectable, TableClause): - return self - else: - return super(Alias, self)._clone() - def _copy_internals(self, clone=_clone): self._clone_from_clause() self.selectable = _clone(self.selectable) @@ -2996,6 +2995,9 @@ class Select(_SelectBaseMixin, FromClause): for col in self._raw_columns: for f in col._hide_froms(): hide_froms.add(f) + while hasattr(f, '_is_clone_of'): + hide_froms.add(f._is_clone_of) + f = f._is_clone_of for f in col._get_from_objects(): froms.add(f) @@ -3007,17 +3009,26 @@ class Select(_SelectBaseMixin, FromClause): froms.add(elem) for f in elem._get_from_objects(): froms.add(f) - + for elem in froms: for f in elem._hide_froms(): hide_froms.add(f) - + while hasattr(f, '_is_clone_of'): + hide_froms.add(f._is_clone_of) + f = f._is_clone_of + froms = froms.difference(hide_froms) - + if len(froms) > 1: corr = self.__correlate if self._should_correlate and existing_froms is not None: corr = existing_froms.union(corr) + + for f in list(corr): + while hasattr(f, '_is_clone_of'): + corr.add(f._is_clone_of) + f = f._is_clone_of + f = froms.difference(corr) if len(f) == 0: raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate())) @@ -3070,8 +3081,8 @@ class Select(_SelectBaseMixin, FromClause): def _copy_internals(self, clone=_clone): self._clone_from_clause() - self._raw_columns = [clone(c) for c in self._raw_columns] self._recorrelate_froms([(f, clone(f)) for f in self._froms]) + self._raw_columns = [clone(c) for c in self._raw_columns] for attr in ('_whereclause', '_having', '_order_by_clause', '_group_by_clause'): if getattr(self, attr) is not None: setattr(self, attr, clone(getattr(self, attr))) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index ecf4f3c163..81d28ac7ed 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -148,7 +148,7 @@ class AbstractClauseProcessor(object): list_[i] = self.traverse(list_[i], stop_on=stop_on) return list_ - def _convert_element(self, elem, stop_on): + def _convert_element(self, elem, stop_on, cloned): v = self while v is not None: newelem = v.convert_element(elem) @@ -156,25 +156,32 @@ class AbstractClauseProcessor(object): stop_on.add(newelem) return newelem v = getattr(v, '_next_acp', None) - return elem._clone() - def traverse(self, elem, clone=True, stop_on=None, _clone_toplevel=True): + if elem not in cloned: + # the full traversal will only make a clone of a particular element + # once. + cloned[elem] = elem._clone() + return cloned[elem] + + def traverse(self, elem, clone=True, stop_on=None): if not clone: raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True") - + if stop_on is None: stop_on = util.Set() - + return self._traverse(elem, stop_on, {}, _clone_toplevel=True) + + def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False): if elem in stop_on: return elem if _clone_toplevel: - elem = self._convert_element(elem, stop_on) + elem = self._convert_element(elem, stop_on, cloned) if elem in stop_on: return elem def clone(element): - return self._convert_element(element, stop_on) + return self._convert_element(element, stop_on, cloned) elem._copy_internals(clone=clone) v = getattr(self, '_next', None) @@ -186,7 +193,7 @@ class AbstractClauseProcessor(object): for e in elem.get_children(**self.__traverse_options__): if e not in stop_on: - self.traverse(e, stop_on=stop_on, _clone_toplevel=False) + self._traverse(e, stop_on, cloned) return elem class ClauseAdapter(AbstractClauseProcessor): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 9bc5d2479f..1a0629a17d 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -47,10 +47,19 @@ class ClauseVisitor(object): traversal.insert(0, t) for c in t.get_children(**self.__traverse_options__): stack.append(c) - + def traverse(self, obj, stop_on=None, clone=False): + if clone: - obj = obj._clone() + cloned = {} + def do_clone(obj): + # the full traversal will only make a clone of a particular element + # once. + if obj not in cloned: + cloned[obj] = obj._clone() + return cloned[obj] + + obj = do_clone(obj) stack = [obj] traversal = [] @@ -59,7 +68,7 @@ class ClauseVisitor(object): if stop_on is None or t not in stop_on: traversal.insert(0, t) if clone: - t._copy_internals() + t._copy_internals(clone=do_clone) for c in t.get_children(**self.__traverse_options__): stack.append(c) for target in traversal: diff --git a/test/sql/generative.py b/test/sql/generative.py index 2d1f3ccf91..1497ecde3d 100644 --- a/test/sql/generative.py +++ b/test/sql/generative.py @@ -1,6 +1,7 @@ import testbase from sqlalchemy import * from sqlalchemy.sql import table, column, ClauseElement +from sqlalchemy.sql.expression import _clone from testlib import * from sqlalchemy.sql.visitors import * from sqlalchemy import util @@ -56,8 +57,8 @@ class TraversalTest(AssertMixin): return True return False - def _copy_internals(self): - self.items = [i._clone() for i in self.items] + def _copy_internals(self, clone=_clone): + self.items = [clone(i) for i in self.items] def get_children(self, **kwargs): return self.items @@ -223,7 +224,23 @@ class ClauseTest(SQLCompileTest): print str(s5) assert str(s5) == s5_assert assert str(s4) == s4_assert - + + def test_alias(self): + subq = t2.select().alias('subq') + s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)]) + orig = str(s) + s2 = ClauseVisitor().traverse(s, clone=True) + assert orig == str(s) == str(s2) + + s4 = ClauseVisitor().traverse(s2, clone=True) + assert orig == str(s) == str(s2) == str(s4) + + s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True) + assert orig == str(s) == str(s3) + + s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True) + assert orig == str(s) == str(s3) == str(s4) + def test_correlated_select(self): s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2) class Vis(ClauseVisitor):