]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- eager loading with LIMIT/OFFSET applied no longer adds the primary
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Nov 2007 23:17:34 +0000 (23:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Nov 2007 23:17:34 +0000 (23:17 +0000)
table joined to a limited subquery of itself; the eager loads now
join directly to the subquery which also provides the primary table's
columns to the result set.  This eliminates a JOIN from all eager loads
with LIMIT/OFFSET.  [ticket:843]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/visitors.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index b877d99b31cbb6ad1404c863f15db4f63953b3a9..1d8a5301a0e2451ccb348cd64144ad29612de76c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -28,6 +28,11 @@ CHANGES
     for the column, for those DB's who provide it via cursor.lastrowid
 
 - orm
+  - eager loading with LIMIT/OFFSET applied no longer adds the primary 
+    table joined to a limited subquery of itself; the eager loads now
+    join directly to the subquery which also provides the primary table's
+    columns to the result set.  This eliminates a JOIN from all eager loads
+    with LIMIT/OFFSET.  [ticket:843]
 
   - Mapped classes may now define __eq__, __hash__, and __nonzero__ methods
     with arbitrary sementics.  The orm now handles all mapped instances on
index b6200fee52cecb142e4d1564845e34d409ff0ecf..ac0dc83ab61d8c809537f32aada6c10c57abd3a4 100644 (file)
@@ -48,6 +48,7 @@ class Query(object):
         self._eager_loaders = util.Set([x for x in self.mapper._eager_loaders])
         self._attributes = {}
         self._current_path = ()
+        self._primary_adapter=None
         
     def _clone(self):
         q = Query.__new__(Query)
@@ -686,7 +687,10 @@ class Query(object):
             result = util.UniqueAppender([])
                     
         for row in cursor.fetchall():
-            self.select_mapper._instance(context, row, result)
+            if self._primary_adapter:
+                self.select_mapper._instance(context, self._primary_adapter(row), result)
+            else:
+                self.select_mapper._instance(context, row, result)
             for proc in process:
                 proc[0](context, row)
 
@@ -836,9 +840,35 @@ class Query(object):
         if self.table not in alltables:
             from_obj.append(self.table)
 
+        context.from_clauses = from_obj
+        
+        # give all the attached properties a chance to modify the query
+        # TODO: doing this off the select_mapper.  if its the polymorphic mapper, then
+        # it has no relations() on it.  should we compile those too into the query ?  (i.e. eagerloads)
+        for value in self.select_mapper.iterate_properties:
+            context.exec_with_path(self.select_mapper, value.key, value.setup, context)
+
+        # additional entities/columns, add those to selection criterion
+        for tup in self._entities:
+            (m, alias, alias_id) = tup
+            clauses = self._get_entity_clauses(tup)
+            if isinstance(m, mapper.Mapper):
+                for value in m.iterate_properties:
+                    context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
+            elif isinstance(m, sql.ColumnElement):
+                if clauses is not None:
+                    m = clauses.adapt_clause(m)
+                context.secondary_columns.append(m)
+            
         if self._eager_loaders and self._nestable(**self._select_args()):
-            # if theres an order by, add those columns to the column list
-            # of the "rowcount" query we're going to make
+            # eager loaders are present, and the SELECT has limiting criterion
+            # produce a "wrapped" selectable.
+            
+            # ensure all 'order by' elements are ClauseElement instances
+            # (since they will potentially be aliased)
+            # locate all embedded Column clauses so they can be added to the
+            # "inner" select statement where they'll be available to the enclosing
+            # statement's "order by"
             if order_by:
                 order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
                 cf = sql_util.ColumnFinder()
@@ -847,20 +877,35 @@ class Query(object):
             else:
                 cf = []
 
-            s2 = sql.select(self.primary_key_columns + list(cf), whereclause, use_labels=True, from_obj=from_obj, correlate=False, **self._select_args())
+            s2 = sql.select(context.primary_columns + list(cf), whereclause, from_obj=context.from_clauses, use_labels=True, correlate=False, **self._select_args())
+
             if order_by:
-                s2 = s2.order_by(*util.to_list(order_by))
-            s3 = s2.alias('tbl_row_count')
-            crit = s3.primary_key==self.primary_key_columns
-            statement = sql.select([], crit, use_labels=True, for_update=for_update)
-            # now for the order by, convert the columns to their corresponding columns
-            # in the "rowcount" query, and tack that new order by onto the "rowcount" query
+                s2.append_order_by(*util.to_list(order_by))
+            
+            s3 = s2.alias('primary_tbl_limited')
+                
+            self._primary_adapter = mapperutil.create_row_adapter(s3, self.table)
+
+            statement = sql.select([s3] + context.secondary_columns, for_update=for_update, use_labels=True)
+
+            if context.eager_joins:
+                statement.append_from(sql_util.ClauseAdapter(s3).traverse(context.eager_joins), _copy_collection=False)
+
             if order_by:
-                statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
+                statement.append_order_by(*sql_util.ClauseAdapter(s3).copy_and_process(util.to_list(order_by)))
+
+            statement.append_order_by(*context.eager_order_by)
         else:
-            statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
+            statement = sql.select(context.primary_columns + context.secondary_columns, whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **self._select_args())
+
+            if context.eager_joins:
+                statement.append_from(context.eager_joins, _copy_collection=False)
+
             if order_by:
                 statement.append_order_by(*util.to_list(order_by))
+
+            if context.eager_order_by:
+                statement.append_order_by(*context.eager_order_by)
                 
             # for a DISTINCT query, you need the columns explicitly specified in order
             # to use it in "order_by".  ensure they are in the column criterion (particularly oid).
@@ -870,24 +915,6 @@ class Query(object):
                 [statement.append_column(c) for c in util.to_list(order_by)]
 
         context.statement = statement
-        
-        # give all the attached properties a chance to modify the query
-        # TODO: doing this off the select_mapper.  if its the polymorphic mapper, then
-        # it has no relations() on it.  should we compile those too into the query ?  (i.e. eagerloads)
-        for value in self.select_mapper.iterate_properties:
-            context.exec_with_path(self.select_mapper, value.key, value.setup, context)
-        
-        # additional entities/columns, add those to selection criterion
-        for tup in self._entities:
-            (m, alias, alias_id) = tup
-            clauses = self._get_entity_clauses(tup)
-            if isinstance(m, mapper.Mapper):
-                for value in m.iterate_properties:
-                    context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses)
-            elif isinstance(m, sql.ColumnElement):
-                if clauses is not None:
-                    m = clauses.adapt_clause(m)
-                statement.append_column(m)
                 
         return context
 
@@ -1164,7 +1191,10 @@ class QueryContext(object):
         self.version_check = query._version_check
         self.identity_map = {}
         self.path = ()
-
+        self.primary_columns = []
+        self.secondary_columns = []
+        self.eager_order_by = []
+        self.eager_joins = None
         self.options = query._with_options
         self.attributes = query._attributes.copy()
     
index b699bfee5341d520da1229cff52360b32bc9c5cd..be783fb3999f9c14847ecc2345fe8b4175e1e61c 100644 (file)
@@ -27,9 +27,9 @@ class ColumnLoader(LoaderStrategy):
     def setup_query(self, context, parentclauses=None, **kwargs):
         for c in self.columns:
             if parentclauses is not None:
-                context.statement.append_column(parentclauses.aliased_column(c))
+                context.secondary_columns.append(parentclauses.aliased_column(c))
             else:
-                context.statement.append_column(c)
+                context.primary_columns.append(c)
         
     def init_class_attribute(self):
         self.is_class_level = True
@@ -498,10 +498,8 @@ class EagerLoader(AbstractRelationLoader):
         else:
             localparent = parentmapper
         
-        statement = context.statement
-        
-        if hasattr(statement, '_outerjoin'):
-            towrap = statement._outerjoin
+        if context.eager_joins:
+            towrap = context.eager_joins
         elif isinstance(localparent.mapped_table, sql.Join):
             towrap = localparent.mapped_table
         else:
@@ -509,7 +507,8 @@ class EagerLoader(AbstractRelationLoader):
             # this will locate the selectable inside of any containers it may be a part of (such
             # as a join).  if its inside of a join, we want to outer join on that join, not the 
             # selectable.
-            for fromclause in statement.froms:
+            # TODO: slightly hacky way to get at all the froms
+            for fromclause in sql.select(from_obj=context.from_clauses).froms:
                 if fromclause is localparent.mapped_table:
                     towrap = fromclause
                     break
@@ -518,7 +517,7 @@ class EagerLoader(AbstractRelationLoader):
                         towrap = fromclause
                         break
             else:
-                raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table))
+                raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join onto, for mapped table %s" % (localparent.mapped_table))
         
         # create AliasedClauses object to build up the eager query.  this is cached after 1st creation.    
         try:
@@ -532,19 +531,17 @@ class EagerLoader(AbstractRelationLoader):
         context.attributes[("eager_row_processor", path)] = clauses.row_decorator
         
         if self.secondaryjoin is not None:
-            statement._outerjoin = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin)
+            context.eager_joins = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin)
             if self.order_by is False and self.secondary.default_order_by() is not None:
-                statement.append_order_by(*clauses.secondary.default_order_by())
+                context.eager_order_by.append(*clauses.secondary.default_order_by())
         else:
-            statement._outerjoin = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
+            context.eager_joins = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
             if self.order_by is False and clauses.alias.default_order_by() is not None:
-                statement.append_order_by(*clauses.alias.default_order_by())
+                context.eager_order_by.append(*clauses.alias.default_order_by())
 
         if clauses.order_by:
-            statement.append_order_by(*util.to_list(clauses.order_by))
+            context.eager_order_by.append(*util.to_list(clauses.order_by))
         
-        statement.append_from(statement._outerjoin)
-
         for value in self.select_mapper.iterate_properties:
             context.exec_with_path(self.select_mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.select_mapper)
         
index 08289ae8961913072e5d55e6536bc29bc76f1ec2..87b74c3c27b13b38cdf369c8611d186a54253167 100644 (file)
@@ -224,27 +224,29 @@ class AliasedClauses(object):
         of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form
         of the table.
         """
-        class AliasedRowAdapter(object):
-            def __init__(self, row):
-                self.row = row
-            def __contains__(self, key):
-                return key in map or key in self.row
-            def has_key(self, key):
-                return key in self
-            def __getitem__(self, key):
-                if key in map:
-                    key = map[key]
-                return self.row[key]
-            def keys(self):
-                return map.keys()
-        map = {}
-        for c in self.mapped_table.c:
-            map[c] = self.alias.corresponding_column(c)
-                
-        AliasedRowAdapter.map = map
-        return AliasedRowAdapter
+        return create_row_adapter(self.alias, self.mapped_table)
+
+def create_row_adapter(from_, to):
+    map = {}        
+    for c in to.c:
+        map[c] = from_.corresponding_column(c)
+
+    class AliasedRow(object):
+        def __init__(self, row):
+            self.row = row
+        def __contains__(self, key):
+            return key in map or key in self.row
+        def has_key(self, key):
+            return key in self
+        def __getitem__(self, key):
+            if key in map:
+                key = map[key]
+            return self.row[key]
+        def keys(self):
+            return map.keys()
+    AliasedRow.map = map
+    return AliasedRow
 
-    
 class PropertyAliasedClauses(AliasedClauses):
     """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
     
index bf15c2b7eec69e2ecd6c985a9a12b1f0dae06232..9bc5d2479fd2adafc1d9daf9847a7cb1453c0328 100644 (file)
@@ -29,6 +29,14 @@ class ClauseVisitor(object):
         if meth:
             return meth(obj, **kwargs)
 
+    def traverse_chained(self, obj, **kwargs):
+        v = self
+        while v is not None:
+            meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
+            if meth:
+                meth(obj, **kwargs)
+            v = getattr(v, '_next', None)
+        
     def iterate(self, obj, stop_on=None):
         stack = [obj]
         traversal = []
index 72f5f35d040a1e8b932bb850732db83d8b07d901..9fca3ee08ef09083ebdbbf7376d28a1dfc2744a3 100755 (executable)
@@ -166,7 +166,27 @@ class SelectableTest(AssertMixin):
         print str(criterion)
         print str(j.onclause)
         self.assert_(criterion.compare(j.onclause))
-        
+    
+    def testtablejoinedtoselectoftable(self):
+        metadata = MetaData()
+        a = Table('a', metadata,
+            Column('id', Integer, primary_key=True))
+        b = Table('b', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('aid', Integer, ForeignKey('a.id')),
+            )
+
+        j1 = a.outerjoin(b)
+        j2 = select([a.c.id.label('aid')]).alias('bar')
+
+        j3 = a.join(j2, j2.c.aid==a.c.id)
+
+        j4 = select([j3]).alias('foo')
+        print j4
+        print j4.corresponding_column(j2.c.aid)
+        print j4.c.aid
+        # TODO: this is the assertion case which fails
+#        assert j4.corresponding_column(j2.c.aid) is j4.c.aid
 
 class PrimaryKeyTest(AssertMixin):
     def test_join_pk_collapse_implicit(self):