]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- improved support for complex queries embedded into "where" criterion
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Feb 2007 01:47:54 +0000 (01:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Feb 2007 01:47:54 +0000 (01:47 +0000)
 for query.select() [ticket:449]
- contains_eager('foo') automatically implies eagerload('foo')
- query.options() can take a combiantion MapperOptions and tuples of MapperOptions,
so that functions can return groups
- refactoring to Aliasizer and ClauseAdapter so that they share a common base methodology,
which addresses all sql.ColumnElements instead of just schema.Column.  common list-processing
methods added.
- query.compile and eagerloader._aliasize_orderby make usage of improved list processing on
above.
- query.compile, within the "nested select generate" step processes the order_by clause using
the ClauseAdapter instead of Aliasizer since there is only one "target"

CHANGES
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql_util.py
lib/sqlalchemy/util.py
test/orm/eagertest3.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 38474189e36a66331696f0b0992091fbe86a5a55..4ff29d3b008d5fac821f21c46be0d6687ea3a3ec 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -15,6 +15,9 @@
   the relationship.
   - eager loading is slightly more strict about detecting "self-referential"
   relationships, specifically between polymorphic mappers.
+  - improved support for complex queries embedded into "where" criterion
+  for query.select() [ticket:449]
+  - contains_eager('foo') automatically implies eagerload('foo')
   - fixed bug where cascade operations incorrectly included deleted collection
   items in the cascade [ticket:445]
   - fix to deferred so that load operation doesnt mistakenly occur when only
index 1e1a75b63159613530236ab80e6ca155ee3d7c55..a933cb1b27c4349f12687a0429bdaf30cb8ae01e 100644 (file)
@@ -124,8 +124,8 @@ def contains_eager(key, decorator=None):
     a custom row decorator.  
     
     used when feeding SQL result sets directly into
-    query.instances()."""
-    return strategies.RowDecorateOption(key, decorator=decorator)
+    query.instances().  Also bundles an EagerLazyOption to turn on eager loading in case it isnt already."""
+    return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, decorator=decorator))
     
 def defer(name):
     """return a MapperOption that will convert the column property of the given 
index 959c9c36b31904c8108daaa954e7005f1c9ac051..f40856b41a27df25682278c5bbfd5d31a296db4d 100644 (file)
@@ -101,7 +101,7 @@ class OperationContext(object):
         self.options = options
         self.attributes = {}
         self.recursion_stack = util.Set()
-        for opt in options:
+        for opt in util.flatten_iterator(options):
             self.accept_option(opt)
     def accept_option(self, opt):
         pass
index bc047ff504609ff05b6879e1a9695cbeee4fc31d..adec69116b600610c8189429d9fb0e06ce6c9d05 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
 from sqlalchemy.orm import mapper, class_mapper
 from sqlalchemy.orm.interfaces import OperationContext
 
@@ -34,7 +34,7 @@ class Query(object):
                 _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type))
             self.mapper._get_clause = _get_clause
         self._get_clause = self.mapper._get_clause
-        for opt in self.with_options:
+        for opt in util.flatten_iterator(self.with_options):
             opt.process_query(self)
     
     def _insert_extension(self, ext):
@@ -440,7 +440,8 @@ class Query(object):
             if order_by:
                 order_by = util.to_list(order_by) or []
                 cf = sql_util.ColumnFinder()
-                [o.accept_visitor(cf) for o in order_by]
+                for o in order_by:
+                    o.accept_visitor(cf)
             else:
                 cf = []
                 
@@ -449,17 +450,11 @@ class Query(object):
                 s2.order_by(*util.to_list(order_by))
             s3 = s2.alias('tbl_row_count')
             crit = s3.primary_key==self.table.primary_key
-            statement = sql.select([], crit, from_obj=[self.table], use_labels=True, for_update=for_update)
+            statement = sql.select([], crit, use_labels=True, for_update=for_update)
             # now for the order by, convert the columns to their corresponding columns
             # in the "rowcount" query, and tack that new order by onto the "rowcount" query
             if order_by:
-                class Aliasizer(sql_util.Aliasizer):
-                    def get_alias(self, table):
-                        return s3
-                order_by = [o.copy_container() for o in order_by]
-                aliasizer = Aliasizer(*[t for t in sql_util.TableFinder(s3)])
-                [o.accept_visitor(aliasizer) for  o in order_by]
-                statement.order_by(*util.to_list(order_by))
+                statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
         else:
             statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
             if order_by:
index 50a2ea27d2eecf7c4833d5f3b5f6e40ecbb86ce8..29b60a8f614763e3b2901a15fc5a29bebd901adb 100644 (file)
@@ -419,15 +419,11 @@ class EagerLoader(AbstractRelationLoader):
 
         def _aliasize_orderby(self, orderby, copy=True):
             if copy:
-                orderby = [o.copy_container() for o in util.to_list(orderby)]
+                return self.aliasizer.copy_and_process(util.to_list(orderby))
             else:
                 orderby = util.to_list(orderby)
-            for i in range(0, len(orderby)):
-                if isinstance(orderby[i], schema.Column):
-                    orderby[i] = self.eagertarget.corresponding_column(orderby[i])
-                else:
-                    orderby[i].accept_visitor(self.aliasizer)
-            return orderby
+                self.aliasizer.process_list(orderby)
+                return orderby
 
         def _create_decorator_row(self):
             class EagerRowAdapter(object):
index 10d4495d931e64264373af5c884dc5bdbe250021..e672feb1a9ec0e6282dc78e1c6eba9a8de43cafa 100644 (file)
@@ -78,8 +78,53 @@ class ColumnFinder(sql.ClauseVisitor):
         self.columns.add(c)
     def __iter__(self):
         return iter(self.columns)
-            
-class Aliasizer(sql.ClauseVisitor):
+
+class ColumnsInClause(sql.ClauseVisitor):
+    """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
+    def __init__(self, selectable):
+        self.selectable = selectable
+        self.result = False
+    def visit_column(self, column):
+        if self.selectable.c.get(column.key) is column:
+            self.result = True
+
+class AbstractClauseProcessor(sql.ClauseVisitor):
+    """traverses a clause and attempts to convert the contents of container elements
+    to a converted element.  the conversion operation is defined by subclasses."""
+    def convert_element(self, elem):
+        """define the 'conversion' method for this AbstractClauseProcessor"""
+        raise NotImplementedError()
+    def copy_and_process(self, list_):
+        """copy the container elements in the given list to a new list and
+        process the new list."""
+        list_ = [o.copy_container() for o in list_]
+        self.process_list(list_)
+        return list_
+
+    def process_list(self, list_):
+        """process all elements of the given list in-place"""
+        for i in range(0, len(list_)):
+            elem = self.convert_element(list_[i])
+            if elem is not None:
+                list_[i] = elem
+            else:
+                list_[i].accept_visitor(self)
+    def visit_compound(self, compound):
+        self.visit_clauselist(compound)
+    def visit_clauselist(self, clist):
+        for i in range(0, len(clist.clauses)):
+            n = self.convert_element(clist.clauses[i])
+            if n is not None:
+                clist.clauses[i] = n
+    def visit_binary(self, binary):
+        elem = self.convert_element(binary.left)
+        if elem is not None:
+            binary.left = elem
+        elem = self.convert_element(binary.right)
+        if elem is not None:
+            binary.right = elem
+                
+class Aliasizer(AbstractClauseProcessor):
     """converts a table instance within an expression to be an alias of that table."""
     def __init__(self, *tables, **kwargs):
         self.tables = {}
@@ -95,21 +140,13 @@ class Aliasizer(sql.ClauseVisitor):
         self.binary = None
     def get_alias(self, table):
         return self.aliases[table]
-    def visit_compound(self, compound):
-        self.visit_clauselist(compound)
-    def visit_clauselist(self, clist):
-        for i in range(0, len(clist.clauses)):
-            if isinstance(clist.clauses[i], schema.Column) and self.tables.has_key(clist.clauses[i].table):
-                orig = clist.clauses[i]
-                clist.clauses[i] = self.get_alias(clist.clauses[i].table).corresponding_column(clist.clauses[i])
-    def visit_binary(self, binary):
-        if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table):
-            binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left)
-        if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table):
-            binary.right = self.get_alias(binary.right.table).corresponding_column(binary.right)
-
+    def convert_element(self, elem):
+        if isinstance(elem, sql.ColumnElement) and hasattr(elem, 'table') and self.tables.has_key(elem.table):
+            return self.get_alias(elem.table).corresponding_column(elem)
+        else:
+            return None
 
-class ClauseAdapter(sql.ClauseVisitor):
+class ClauseAdapter(AbstractClauseProcessor):
     """given a clause (like as in a WHERE criterion), locates columns which 'correspond' to a given selectable, 
     and changes those columns to be that of the selectable.
     
@@ -140,7 +177,8 @@ class ClauseAdapter(sql.ClauseVisitor):
         self.include = include
         self.exclude = exclude
         self.equivalents = equivalents
-    def include_col(self, col):
+        
+    def convert_element(self, col):
         if not isinstance(col, sql.ColumnElement):
             return None
         if self.include is not None:
@@ -153,19 +191,4 @@ class ClauseAdapter(sql.ClauseVisitor):
         if newcol is None and self.equivalents is not None and col in self.equivalents:
             newcol = self.selectable.corresponding_column(self.equivalents[col], raiseerr=False, keys_ok=False)
         return newcol
-    def visit_binary(self, binary):
-        col = self.include_col(binary.left)
-        if col is not None:
-            binary.left = col
-        col = self.include_col(binary.right)
-        if col is not None:
-            binary.right = col
-
-class ColumnsInClause(sql.ClauseVisitor):
-    """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
-    def __init__(self, selectable):
-        self.selectable = selectable
-        self.result = False
-    def visit_column(self, column):
-        if self.selectable.c.get(column.key) is column:
-            self.result = True
+    
index 54b1afa9fbd33e1fff1f96129320b943061241cd..c2e0dbc45ff1655ca676c5aab354a3dbfee7b0af 100644 (file)
@@ -35,6 +35,16 @@ def to_set(x):
     else:
         return x
 
+def flatten_iterator(x):
+    """given an iterator of which further sub-elements may also be iterators,
+    flatten the sub-elements into a single iterator."""
+    for elem in x:
+        if hasattr(elem, '__iter__'):
+            for y in flatten_iterator(elem):
+                yield y
+        else:
+            yield elem
+            
 def reversed(seq):
     try:
         return __builtin__.reversed(seq)
index 10f8fce0a3dc6a5e93ef4073a43e7b403428cc99..e33ce439449b6baf3b4fa9122a38d3b6dcbf23bf 100644 (file)
@@ -2,6 +2,7 @@ from testbase import PersistTest, AssertMixin
 import testbase
 from sqlalchemy import *
 from sqlalchemy.ext.selectresults import SelectResults
+import random
 
 class EagerTest(AssertMixin):
     def setUpAll(self):
@@ -197,7 +198,74 @@ class EagerTest2(AssertMixin):
         session.clear()
         obj = session.query(Left).get_by(tag='tag1')
         print obj.middle.right[0]
+
+class EagerTest3(testbase.ORMTest):
+    """test eager loading combined with nested SELECT statements, functions, and aggregates"""
+    def define_tables(self, metadata):
+        global datas, foo, stats
+        datas=Table( 'datas',metadata,
+         Column ( 'id', Integer, primary_key=True,nullable=False ),
+         Column ( 'a', Integer , nullable=False ) )
+
+        foo=Table('foo',metadata,
+         Column ( 'data_id', Integer, ForeignKey('datas.id'),nullable=False,primary_key=True ),
+         Column ( 'bar', Integer ) )
+
+        stats=Table('stats',metadata,
+        Column ( 'id', Integer, primary_key=True, nullable=False ),
+        Column ( 'data_id', Integer, ForeignKey('datas.id')),
+        Column ( 'somedata', Integer, nullable=False ))
+        
+    def test_nesting_with_functions(self):
+        class Data(object): pass
+        class Foo(object):pass
+        class Stat(object): pass
+
+        Data.mapper=mapper(Data,datas)
+        Foo.mapper=mapper(Foo,foo,properties={'data':relation(Data,backref=backref('foo',uselist=False))})
+        Stat.mapper=mapper(Stat,stats,properties={'data':relation(Data)})
+
+        s=create_session()
+        data = []
+        for x in range(5):
+            d=Data()
+            d.a=x
+            s.save(d)
+            data.append(d)
+            
+        for x in range(10):
+            rid=random.randint(0,len(data) - 1)
+            somedata=random.randint(1,50000)
+            stat=Stat()
+            stat.data = data[rid]
+            stat.somedata=somedata
+            s.save(stat)
+
+        s.flush()
+
+        arb_data=select(
+            [stats.c.data_id,func.max(stats.c.somedata).label('max')],
+            stats.c.data_id<=25,
+            group_by=[stats.c.data_id]).alias('arb')
+        
+        arb_result = arb_data.execute().fetchall()
+        # order the result list descending based on 'max'
+        arb_result.sort(lambda a, b:cmp(b['max'],a['max']))
+        # extract just the "data_id" from it
+        arb_result = [row['data_id'] for row in arb_result]
+        
+        # now query for Data objects using that above select, adding the 
+        # "order by max desc" separately
+        q=s.query(Data).options(eagerload('foo')).select(
+            from_obj=[datas.join(arb_data,arb_data.c.data_id==datas.c.id)],
+            order_by=[desc(arb_data.c.max)],limit=10)
+        
+        # extract "data_id" from the list of result objects
+        verify_result = [d.id for d in q]
         
+        # assert equality including ordering (may break if the DB "ORDER BY" and python's sort() used differing
+        # algorithms and there are repeated 'somedata' values in the list)
+        assert verify_result == arb_result
         
 if __name__ == "__main__":    
     testbase.main()
index 479d484539c97207984015e96e2744376863f9f0..46854b2ceac5c9805b8e9e6be8545ab1937a583f 100644 (file)
@@ -1028,7 +1028,9 @@ class EagerTest(MapperSuperTest):
 
     def testcustomeagerquery(self):
         mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
+            # setting lazy=True - the contains_eager() option below
+            # should imply eagerload()
+            'addresses':relation(Address, lazy=True)
         })
         mapper(Address, addresses)