]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query has add_entity() and add_column() generative methods. these
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Mar 2007 03:41:55 +0000 (03:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Mar 2007 03:41:55 +0000 (03:41 +0000)
will add the given mapper/class or ColumnElement to the query at compile
time, and apply them to the instances method.  the user is responsible
for constructing reasonable join conditions (otherwise you can get
full cartesian products).  result set is the list of tuples, non-uniqued.
- fixed multi-mapper instances() to pad out shorter results with None so
zip() gets everything

CHANGES
examples/association/proxied_association.py
lib/sqlalchemy/orm/query.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 35ffc81d64769e5b175ad11c58770e8efcc015c1..799690921799c531172356fe1ba97401d896bd36 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       as a list of tuples.  this corresponds to the documented behavior.
       So that instances match up properly, the "uniquing" is disabled when 
       this feature is used.
+    - Query has add_entity() and add_column() generative methods.  these
+      will add the given mapper/class or ColumnElement to the query at compile
+      time, and apply them to the instances method.  the user is responsible
+      for constructing reasonable join conditions (otherwise you can get
+      full cartesian products).  result set is the list of tuples, non-uniqued.
     - strings and columns can also be sent to the *args of instances() where
       those exact result columns will be part of the result tuples.
     - a full select() construct can be passed to query.select() (which
index e6180ad8a9775dbf12e08c11cc56ba919515650c..31a64ce7a0d4b44b57664f3626af9afce579677e 100644 (file)
@@ -97,11 +97,11 @@ print [(item.item.description, item.price) for item in order.itemassociations]
 print [(item.description, item.price) for item in order.items]
 
 # print customers who bought 'MySQL Crowbar' on sale
-result = SelectResults(session.query(Order)).join_to('item').select(and_(items.c.description=='MySQL Crowbar', items.c.price>orderitems.c.price))
+result = session.query(Order).join('item').filter(and_(items.c.description=='MySQL Crowbar', items.c.price>orderitems.c.price))
 print [order.customer_name for order in result]
 
 # print customers who got the special T-shirt discount
-result = SelectResults(session.query(Order)).join_to('item').select(and_(items.c.description=='SA T-Shirt', items.c.price>orderitems.c.price))
+result = session.query(Order).join('item').filter(and_(items.c.description=='SA T-Shirt', items.c.price>orderitems.c.price))
 print [order.customer_name for order in result]
 
 
index 6650954e1bbd5feaeb11236bba32cc090530c5b0..a1c8b6af51aa39b147e242014f93bcc4c8a3df7e 100644 (file)
@@ -33,9 +33,12 @@ class Query(object):
             for primary_key in self.primary_key_columns:
                 _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
             self.mapper._get_clause = _get_clause
+            
+        self._entities = []
         self._get_clause = self.mapper._get_clause
 
         self._order_by = kwargs.pop('order_by', False)
+        self._group_by = kwargs.pop('group_by', False)
         self._distinct = kwargs.pop('distinct', False)
         self._offset = kwargs.pop('offset', None)
         self._limit = kwargs.pop('limit', None)
@@ -52,6 +55,7 @@ class Query(object):
         q.select_mapper = self.select_mapper
         q._order_by = self._order_by
         q._distinct = self._distinct
+        q._entities = list(self._entities)
         q.always_refresh = self.always_refresh
         q.with_options = list(self.with_options)
         q._session = self.session
@@ -62,6 +66,7 @@ class Query(object):
             q.extension.append(ext)
         q._offset = self._offset
         q._limit = self._limit
+        q._group_by = self._group_by
         q._get_clause = self._get_clause
         q._from_obj = list(self._from_obj)
         q._joinpoint = self._joinpoint
@@ -340,6 +345,16 @@ class Query(object):
         t = sql.text(text)
         return self.execute(t, params=params)
 
+    def add_entity(self, entity):
+        q = self._clone()
+        q._entities.append(entity)
+        return q
+        
+    def add_column(self, column):
+        q = self._clone()
+        q._entities.append(column)
+        return q
+        
     def options(self, *args, **kwargs):
         """Return a new Query object, applying the given list of
         MapperOptions.
@@ -497,6 +512,16 @@ class Query(object):
         else:
             q._order_by.extend(util.to_list(criterion))
         return q
+
+    def group_by(self, criterion):
+        """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
+
+        q = self._clone()
+        if q._group_by is False:    
+            q._group_by = util.to_list(criterion)
+        else:
+            q._group_by.extend(util.to_list(criterion))
+        return q
     
     def join(self, prop):
         """create a join of this ``Query`` object's criterion
@@ -651,6 +676,7 @@ class Query(object):
         context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
 
         process = []
+        mappers_or_columns = tuple(self._entities) + mappers_or_columns
         if mappers_or_columns:
             for m in mappers_or_columns:
                 if isinstance(m, type):
@@ -658,7 +684,8 @@ class Query(object):
                 if isinstance(m, mapper.Mapper):
                     appender = []
                     def proc(context, row):
-                        m._instance(context, row, appender)
+                        if not m._instance(context, row, appender):
+                            appender.append(None)
                     process.append((proc, appender))
                 elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring):
                     res = []
@@ -745,14 +772,26 @@ class Query(object):
             whereclause = sql.and_(self._criterion, whereclause)
 
         if whereclause is not None and self.is_polymorphic:
-            # adapt the given WHERECLAUSE to adjust instances of this query's mapped table to be that of our select_table,
+            # adapt the given WHERECLAUSE to adjust instances of this query's mapped 
+            # table to be that of our select_table,
             # which may be the "polymorphic" selectable used by our mapper.
             whereclause.accept_visitor(sql_util.ClauseAdapter(self.table))
-            
+
+            # if extra entities, adapt the criterion to those as well
+            for m in self._entities:
+                if isinstance(m, type):
+                    m = mapper.class_mapper(m)
+                if isinstance(m, mapper.Mapper):
+                    table = m.select_table
+                    whereclause.accept_visitor(sql_util.ClauseAdapter(m.select_table))
+        
+        # get/create query context.  get the ultimate compile arguments
+        # from there
         context = kwargs.pop('query_context', None)
         if context is None:
             context = QueryContext(self, kwargs)
         order_by = context.order_by
+        group_by = context.group_by
         from_obj = context.from_obj
         lockmode = context.lockmode
         distinct = context.distinct
@@ -769,6 +808,8 @@ class Query(object):
         except KeyError:
             raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode)
 
+        # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
+        # that we only load the appropriate types
         if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
             whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()]))
 
@@ -811,12 +852,23 @@ class Query(object):
                 [statement.append_column(c) for c in util.to_list(order_by)]
 
         context.statement = statement
+        
         # give all the attached properties a chance to modify the query
         # TODO: doing this off the select_mapper.  if its the polymorphic mapper, then
         # it has no relations() on it.  should we compile those too into the query ?  (i.e. eagerloads)
         for value in self.select_mapper.props.values():
             value.setup(context)
 
+        # additional entities/columns, add those to selection criterion
+        for m in self._entities:
+            if isinstance(m, type):
+                m = mapper.class_mapper(m)
+            if isinstance(m, mapper.Mapper):
+                for value in m.props.values():
+                    value.setup(context)
+            elif isinstance(m, sql.ColumnElement):
+                statement.append_column(m)
+                
         return statement
 
     def __log_debug(self, msg):
@@ -833,6 +885,7 @@ class QueryContext(OperationContext):
     def __init__(self, query, kwargs):
         self.query = query
         self.order_by = kwargs.pop('order_by', query._order_by)
+        self.group_by = kwargs.pop('group_by', query._group_by)
         self.from_obj = kwargs.pop('from_obj', query._from_obj)
         self.lockmode = kwargs.pop('lockmode', query.lockmode)
         self.distinct = kwargs.pop('distinct', query._distinct)
@@ -847,7 +900,7 @@ class QueryContext(OperationContext):
         ``QueryContext`` that can be applied to a ``sql.Select``
         statement.
         """
-        return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct}
+        return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct, 'group_by':self.group_by}
 
     def accept_option(self, opt):
         """Accept a ``MapperOption`` which will process (modify) the
index 261fcc1163ba0bd8ad3a292841451459e8facf23..c46825dc29bab799b4991427e6d744662491528b 100644 (file)
@@ -1460,24 +1460,76 @@ class InstancesTest(MapperSuperTest):
     
     def testmultiplemappers(self):
         mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
+            'addresses':relation(Address, lazy=True)
         })
         mapper(Address, addresses)
 
+        sess = create_session()
+        
+        (user7, user8, user9) = sess.query(User).select()
+        (address1, address2, address3, address4) = sess.query(Address).select()
+        
         selectquery = users.outerjoin(addresses).select(use_labels=True)
-        q = create_session().query(User)
+        q = sess.query(User)
         l = q.instances(selectquery.execute(), Address)
         # note the result is a cartesian product
-        assert repr(l) == "[(User(user_id=7,user_name=u'jack'), Address(address_id=1,user_id=7,email_address=u'jack@bean.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=2,user_id=8,email_address=u'ed@wood.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=3,user_id=8,email_address=u'ed@bettyboop.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=4,user_id=8,email_address=u'ed@lala.com'))]"
+        assert l == [
+            (user7, address1),
+            (user8, address2),
+            (user8, address3),
+            (user8, address4),
+            (user9, None)
+        ]
+    
+    def testmultipleonquery(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=True)
+        })
+        mapper(Address, addresses)
+        sess = create_session()
+        (user7, user8, user9) = sess.query(User).select()
+        (address1, address2, address3, address4) = sess.query(Address).select()
+        q = sess.query(User)
+        q = q.add_entity(Address).outerjoin('addresses')
+        l = q.list()
+        assert l == [
+            (user7, address1),
+            (user8, address2),
+            (user8, address3),
+            (user8, address4),
+            (user9, None)
+        ]
+
+    def testcolumnonquery(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=True)
+        })
+        mapper(Address, addresses)
         
-        # check identity map still in effect even though dupe results
-        assert l[1][0] is l[2][0]
+        sess = create_session()
+        (user7, user8, user9) = sess.query(User).select()
+        q = sess.query(User)
+        q = q.group_by([c for c in users.c]).outerjoin('addresses').add_column(func.count(addresses.c.address_id).label('count'))
+        l = q.list()
+        assert l == [
+            (user7, 1),
+            (user8, 3),
+            (user9, 0)
+        ]
         
     def testmapperspluscolumn(self):
         mapper(User, users)
         s = select([users, func.count(addresses.c.address_id).label('count')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c])
-        q = create_session().query(User)
+        sess = create_session()
+        (user7, user8, user9) = sess.query(User).select()
+        q = sess.query(User)
         l = q.instances(s.execute(), "count")
-        assert repr(l) == "[(User(user_id=7,user_name=u'jack'), 1), (User(user_id=8,user_name=u'ed'), 3), (User(user_id=9,user_name=u'fred'), 0)]"
+        assert l == [
+            (user7, 1),
+            (user8, 3),
+            (user9, 0)
+        ]
+
+
 if __name__ == "__main__":    
     testbase.main()