]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- more fixes to the LIMIT/OFFSET aliasing applied with Query + eagerloads,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Dec 2007 18:32:25 +0000 (18:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 16 Dec 2007 18:32:25 +0000 (18:32 +0000)
in this case when mapped against a select statement [ticket:904]
- _hide_froms logic in expression totally localized to Join class, including search through previous clone sources
- removed "stop_on" from main visitors, not used
- "stop_on" in AbstractClauseProcessor part of constructor, ClauseAdapter sets it up based on given clause
- fixes to is_derived_from() to take previous clone sources into account, Alias takes self + cloned sources into account. this is ultimately what the #904 bug was.

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/orm/eager_relations.py
test/orm/query.py
test/sql/generative.py

diff --git a/CHANGES b/CHANGES
index ee5831d584ea26fff8498d8e7d95438d09a59048..f76a85901078e7cabca91aee25864014a3902ffd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -172,6 +172,9 @@ CHANGES
    - fixed endless loop issue when using lazy="dynamic" on both 
      sides of a bi-directional relationship [ticket:872]
 
+   - more fixes to the LIMIT/OFFSET aliasing applied with Query + eagerloads,
+     in this case when mapped against a select statement [ticket:904]
+     
    - fix to self-referential eager loading such that if the same mapped
      instance appears in two or more distinct sets of columns in the same
      result set, its eagerly loaded collection will be populated regardless
index d75a9b8b2738818ab06deb0ec78352a5b0a75045..902a4fd3be14f5f7b78fc828a94df37a36706e7a 100644 (file)
@@ -919,7 +919,7 @@ class Query(object):
         adapt_criterion = self.table not in self._get_joinable_tables()
 
         if not adapt_criterion and whereclause is not None and (self.mapper is not self.select_mapper):
-            whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause, stop_on=util.Set([from_obj]))
+            whereclause = sql_util.ClauseAdapter(from_obj).traverse(whereclause)
 
         # TODO: mappers added via add_entity(), adapt their queries also, 
         # if those mappers are polymorphic
index a4d8fa6a08adae4b868eb57ae285e42a434974b0..ddeaaf8addad5c0c589022663cd65ae9d0eda8d7 100644 (file)
@@ -28,6 +28,7 @@ to stay the same in future releases.
 import re
 import datetime
 import warnings
+from itertools import chain
 from sqlalchemy import util, exceptions
 from sqlalchemy.sql import operators, visitors
 from sqlalchemy import types as sqltypes
@@ -864,6 +865,13 @@ class ClauseElement(object):
         
         return c
 
+    def _cloned_set(self):
+        f = self
+        while f is not None:
+            yield f
+            f = getattr(f, '_is_clone_of', None)
+    _cloned_set = property(_cloned_set)
+
     def _get_from_objects(self, **modifiers):
         """Return objects represented in this ``ClauseElement`` that
         should be added to the ``FROM`` list of a query, when this
@@ -1543,7 +1551,8 @@ class FromClause(Selectable):
 
     __visit_name__ = 'fromclause'
     named_with_column=False
-
+    _hide_froms = []
+    
     def __init__(self):
         self.oid_column = None
         
@@ -1588,7 +1597,7 @@ class FromClause(Selectable):
 
         An example would be an Alias of a Table is derived from that Table.
         """
-        return fromclause is self
+        return fromclause in util.Set(self._cloned_set)
 
     def replace_selectable(self, old, alias):
       """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
@@ -1649,22 +1658,6 @@ class FromClause(Selectable):
         return getattr(self, 'name', self.__class__.__name__ + " object")
     description = property(description)
 
-    def _aggregate_hide_froms(self, **modifiers):
-        """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces, taking into account
-        the element which this element was cloned from (and so on until the orginal is reached).
-        """
-        
-        s = self
-        while s is not None:
-            for h in s._hide_froms(**modifiers):
-                yield h
-            s = getattr(s, '_is_clone_of', None)
-            
-    def _hide_froms(self, **modifiers):
-        """Return a list of ``FROM`` clause elements which this ``FromClause`` replaces."""
-
-        return []
-    
     def _clone_from_clause(self):
         # delete all the "generated" collections of columns for a
         # newly cloned FromClause, so that they will be re-derived
@@ -2230,6 +2223,7 @@ class Join(FromClause):
     def __init__(self, left, right, onclause=None, isouter = False):
         self.left = _selectable(left)
         self.right = _selectable(right).self_group()
+        
         self.oid_column = self.left.oid_column
         if onclause is None:
             self.onclause = self._match_primaries(self.left, self.right)
@@ -2303,7 +2297,7 @@ class Join(FromClause):
         self.right = clone(self.right)
         self.onclause = clone(self.onclause)
         self.__folded_equivalents = None
-
+        
     def get_children(self, **kwargs):
         return self.left, self.right, self.onclause
 
@@ -2409,9 +2403,10 @@ class Join(FromClause):
 
         return self.select(use_labels=True, correlate=False).alias(name)
 
-    def _hide_froms(self, **modifiers):
-        return self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
-
+    def _hide_froms(self):
+        return chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+    _hide_froms = property(_hide_froms)
+    
     def _get_from_objects(self, **modifiers):
         return [self] + self.onclause._get_from_objects(**modifiers) + self.left._get_from_objects(**modifiers) + self.right._get_from_objects(**modifiers)
 
@@ -2450,6 +2445,8 @@ class Alias(FromClause):
     description = property(description)
     
     def is_derived_from(self, fromclause):
+        if fromclause in util.Set(self._cloned_set):
+            return True
         return self.selectable.is_derived_from(fromclause)
 
     def supports_execution(self):
@@ -2527,13 +2524,11 @@ class _FromGrouping(FromClause):
         self.elem = elem
 
     columns = c = property(lambda s:s.elem.columns)
-
+    _hide_froms = property(lambda s:s.elem._hide_froms)
+    
     def get_children(self, **kwargs):
         return self.elem,
 
-    def _hide_froms(self, **modifiers):
-        return self.elem._hide_froms(**modifiers)
-
     def _copy_internals(self, clone=_clone):
         self.elem = clone(self.elem)
 
@@ -3066,7 +3061,6 @@ class Select(_SelectBaseMixin, FromClause):
         """
 
         froms = util.OrderedSet()
-        hide_froms = util.Set()
 
         for col in self._raw_columns:
             froms.update(col._get_from_objects())
@@ -3078,14 +3072,13 @@ class Select(_SelectBaseMixin, FromClause):
             froms.update(self._froms)
  
         for f in froms:
-            hide_froms.update(f._aggregate_hide_froms())
-        froms = froms.difference(hide_froms)
+            froms.difference_update(f._hide_froms)
         
         if len(froms) > 1:
             if self.__correlate:
-                froms = froms.difference(self.__correlate)
+                froms.difference_update(self.__correlate)
             if self._should_correlate and existing_froms is not None:
-                froms = froms.difference(existing_froms)
+                froms.difference_update(existing_froms)
                 
             if not froms:
                 raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
@@ -3129,6 +3122,9 @@ class Select(_SelectBaseMixin, FromClause):
     inner_columns = property(_get_inner_columns, doc="""a collection of all ColumnElement expressions which would be rendered into the columns clause of the resulting SELECT statement.""")
 
     def is_derived_from(self, fromclause):
+        if self in util.Set(fromclause._cloned_set):
+            return True
+        
         for f in self.locate_all_froms():
             if f.is_derived_from(fromclause):
                 return True
index 5aa985f4725ce998cbb9f3f7d2369c30c2adb606..d6b10a78a335e835eafcca7d548d4bb3b5e5255d 100644 (file)
@@ -71,6 +71,9 @@ class AbstractClauseProcessor(object):
     
     __traverse_options__ = {'column_collections':False}
     
+    def __init__(self, stop_on=None):
+        self.stop_on = stop_on
+    
     def convert_element(self, elem):
         """Define the *conversion* method for this ``AbstractClauseProcessor``."""
 
@@ -92,13 +95,14 @@ class AbstractClauseProcessor(object):
         setattr(tail, attr, visitor)
         return self
 
-    def copy_and_process(self, list_, stop_on=None):
+    def copy_and_process(self, list_):
         """Copy the given list to a new list, with each element traversed individually."""
         
         list_ = list(list_)
-        stop_on = util.Set()
+        stop_on = util.Set(self.stop_on or [])
+        cloned = {}
         for i in range(0, len(list_)):
-            list_[i] = self.traverse(list_[i], stop_on=stop_on)
+            list_[i] = self._traverse(list_[i], stop_on, cloned, _clone_toplevel=True)
         return list_
 
     def _convert_element(self, elem, stop_on, cloned):
@@ -116,13 +120,11 @@ class AbstractClauseProcessor(object):
             cloned[elem] = elem._clone()
         return cloned[elem]
         
-    def traverse(self, elem, clone=True, stop_on=None):
+    def traverse(self, elem, clone=True):
         if not clone:
             raise exceptions.ArgumentError("AbstractClauseProcessor 'clone' argument must be True")
         
-        if stop_on is None:
-            stop_on = util.Set()
-        return self._traverse(elem, stop_on, {}, _clone_toplevel=True)
+        return self._traverse(elem, util.Set(self.stop_on or []), {}, _clone_toplevel=True)
         
     def _traverse(self, elem, stop_on, cloned, _clone_toplevel=False):
         if elem in stop_on:
@@ -178,6 +180,7 @@ class ClauseAdapter(AbstractClauseProcessor):
     """
 
     def __init__(self, selectable, include=None, exclude=None, equivalents=None):
+        AbstractClauseProcessor.__init__(self, [selectable])
         self.selectable = selectable
         self.include = include
         self.exclude = exclude
index 150ee9cc7b497c101cedbb34df47dcb5b8fc29a4..bb63ab09c92d547e74c4e5b277782e25b860ad88 100644 (file)
@@ -37,18 +37,17 @@ class ClauseVisitor(object):
                 meth(obj, **kwargs)
             v = getattr(v, '_next', None)
         
-    def iterate(self, obj, stop_on=None):
+    def iterate(self, obj):
         stack = [obj]
         traversal = []
         while len(stack) > 0:
             t = stack.pop()
-            if stop_on is None or t not in stop_on:
-                yield t
-                traversal.insert(0, t)
-                for c in t.get_children(**self.__traverse_options__):
-                    stack.append(c)
+            yield t
+            traversal.insert(0, t)
+            for c in t.get_children(**self.__traverse_options__):
+                stack.append(c)
     
-    def traverse(self, obj, stop_on=None, clone=False):
+    def traverse(self, obj, clone=False):
         
         if clone:
             cloned = {}
@@ -60,17 +59,15 @@ class ClauseVisitor(object):
                 return cloned[obj]
             
             obj = do_clone(obj)
-            
         stack = [obj]
         traversal = []
         while len(stack) > 0:
             t = stack.pop()
-            if stop_on is None or t not in stop_on:
-                traversal.insert(0, t)
-                if clone:
-                    t._copy_internals(clone=do_clone)
-                for c in t.get_children(**self.__traverse_options__):
-                    stack.append(c)
+            traversal.insert(0, t)
+            if clone:
+                t._copy_internals(clone=do_clone)
+            for c in t.get_children(**self.__traverse_options__):
+                stack.append(c)
         for target in traversal:
             v = self
             while v is not None:
index 7a822234ce4e3e79e35fd99fd8c3dc2bc8fa9146..bef4ecffc087961520c271c2920949910f55548c 100644 (file)
@@ -418,6 +418,23 @@ class EagerTest(FixtureTest):
             )
         ] == l.all()
 
+    def test_limit_4(self):
+        # tests the LIMIT/OFFSET aliasing on a mapper against a select.   original issue from ticket #904
+        sel = select([users, addresses.c.email_address], users.c.id==addresses.c.user_id).alias('useralias')
+        mapper(User, sel, properties={
+            'orders':relation(Order, primaryjoin=sel.c.id==orders.c.user_id, lazy=False)
+        })
+        mapper(Order, orders)
+        
+        sess = create_session()
+        self.assertEquals(sess.query(User).first(), 
+            User(name=u'jack',orders=[
+                Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), 
+                Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), 
+                Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5)],
+            email_address=u'jack@bean.com',id=7)
+        )
+        
     def test_one_to_many_scalar(self):
         mapper(User, users, properties = dict(
             address = relation(mapper(Address, addresses), lazy=False, uselist=False)
index 1dc50ebbb10e573979fb85995bdd016b4777dec9..9471f1012a2753b17fab6663ef5c9cad45973c96 100644 (file)
@@ -955,6 +955,10 @@ class SelectFromTest(QueryTest):
         self.assertEquals(sess.query(User).select_from(sel).order_by(asc(User.name)).all(), [
             User(name='ed',id=8), User(name='jack',id=7)
         ])
+        
+        self.assertEquals(sess.query(User).select_from(sel).options(eagerload('addresses')).first(), 
+            User(name='jack', addresses=[Address(id=1)])
+        )
 
     def test_join(self):
         mapper(User, users, properties = {
index 847443330f661cd4d434235ee7cdf99e9157d87e..41b4caebf7ab3fa12e29f6f82070312d781e713a 100644 (file)
@@ -269,7 +269,22 @@ class ClauseTest(SQLCompileTest):
                 
         self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :table1_col2_1")
 
-    def test_clause_adapter(self):
+class ClauseAdapterTest(SQLCompileTest):
+    def setUpAll(self):
+        global t1, t2
+        t1 = table("table1", 
+            column("col1"),
+            column("col2"),
+            column("col3"),
+            )
+        t2 = table("table2", 
+            column("col1"),
+            column("col2"),
+            column("col3"),
+            )
+            
+
+    def test_table_to_alias(self):
         
         t1alias = t1.alias('t1alias')
         
@@ -302,7 +317,7 @@ class ClauseTest(SQLCompileTest):
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
         self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
     
-    def test_selfreferential(self):
+    def test_include_exclude(self):
         m = MetaData()
         a=Table( 'a',m,
           Column( 'id',    Integer, primary_key=True),
@@ -319,10 +334,7 @@ class ClauseTest(SQLCompileTest):
         
         assert str(e) == "a_1.id = a.xxx_id"
 
-    def test_joins(self):
-        """test that ClauseAdapter can target a Join object, replace it, and not dig into the sub-joins after
-        replacing."""
-        
+    def test_join_to_alias(self):
         metadata = MetaData()
         a = Table('a', metadata,
             Column('id', Integer, primary_key=True))
@@ -359,6 +371,42 @@ class ClauseTest(SQLCompileTest):
                                 "c JOIN (SELECT a.id AS a_id, b.id AS b_id, b.aid AS b_aid FROM a LEFT OUTER JOIN b ON a.id = b.aid) "
                                 "ON b_id = c.bid) AS foo"
                                 " LEFT OUTER JOIN d ON foo.a_id = d.aid")
+    
+    def test_derived_from(self):
+        assert select([t1]).is_derived_from(t1)
+        assert not select([t2]).is_derived_from(t1)
+        assert not t1.is_derived_from(select([t1]))
+        assert t1.alias().is_derived_from(t1)
+        
+        
+        s1 = select([t1, t2]).alias('foo')
+        s2 = select([s1]).limit(5).offset(10).alias()
+        assert s2.is_derived_from(s1)
+        s2 = s2._clone()
+        assert s2.is_derived_from(s1)
+        
+    def test_aliasedselect_to_aliasedselect(self):
+        # original issue from ticket #904
+        s1 = select([t1]).alias('foo')
+        s2 = select([s1]).limit(5).offset(10).alias()
+
+        self.assert_compile(sql_util.ClauseAdapter(s2).traverse(s1), 
+            "SELECT foo.col1, foo.col2, foo.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo  LIMIT 5 OFFSET 10")
+        
+        j = s1.outerjoin(t2, s1.c.col1==t2.c.col1)
+        self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(), 
+            "SELECT anon_1.col1, anon_1.col2, anon_1.col3, table2.col1, table2.col2, table2.col3 FROM "\
+            "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\
+            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo  LIMIT 5 OFFSET 10) AS anon_1 "\
+            "LEFT OUTER JOIN table2 ON anon_1.col1 = table2.col1")
+
+        talias = t1.alias('bar')
+        j = s1.outerjoin(talias, s1.c.col1==talias.c.col1)
+        self.assert_compile(sql_util.ClauseAdapter(s2).traverse(j).select(), 
+            "SELECT anon_1.col1, anon_1.col2, anon_1.col3, bar.col1, bar.col2, bar.col3 FROM "\
+            "(SELECT foo.col1 AS col1, foo.col2 AS col2, foo.col3 AS col3 FROM "\
+            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1) AS foo  LIMIT 5 OFFSET 10) AS anon_1 "\
+            "LEFT OUTER JOIN table1 AS bar ON anon_1.col1 = bar.col1")
         
         
 class SelectTest(SQLCompileTest):