git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Query.count() has been enhanced to do the "right
author: Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Nov 2008 16:06:05 +0000 (16:06 +0000)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Nov 2008 16:06:05 +0000 (16:06 +0000)
- Query.count() has been enhanced to do the "right
thing" in a wider variety of cases. It can now
count multiple-entity queries, as well as
column-based queries. Note that this means if you
say query(A, B).count() without any joining
criterion, it's going to count the cartesian
product of A*B. Any query which is against
column-based entities will automatically issue
"SELECT count(1) FROM (SELECT...)" so that the
real rowcount is returned, meaning a query such as
query(func.count(A.name)).count() will return a value of
one, since that query would return one row.

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 41f509c5f45711852b67b0c17345c2db9ef7a3bc..c59eb826ae76ee59b79a7aa5e3b76974f4fe5cf3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,15 +6,30 @@ CHANGES
 =======
 0.5.0rc4
 ========
+- features
+- orm
+    - Query.count() has been enhanced to do the "right
+      thing" in a wider variety of cases. It can now
+      count multiple-entity queries, as well as
+      column-based queries. Note that this means if you
+      say query(A, B).count() without any joining
+      criterion, it's going to count the cartesian
+      product of A*B. Any query which is against
+      column-based entities will automatically issue
+      "SELECT count(1) FROM (SELECT...)" so that the
+      real rowcount is returned, meaning a query such as
+      query(func.count(A.name)).count() will return a value of
+      one, since that query would return one row.
+      
 - bugfixes and behavioral changes
 - general:
     - global "propigate"->"propagate" change.
 
 - orm
-    - Query.count() and Query.get() return a more informative
+    - Query.get() returns a more informative
       error message when executed against multiple entities.
       [ticket:1220]
-
+      
 - access
     - Added support for Currency type.
 
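diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 39e3db43c6336050ced196e3e6862d795146283f..81250706bea20e927774440790eb070c46ac1ef4 100644 (file)
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py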
@@ -1246,28 +1246,54 @@ class Query(object):
                 kwargs.get('distinct', False))
 
     def count(self):
-        """Apply this query's criterion to a SELECT COUNT statement."""
-
+        """Apply this query's criterion to a SELECT COUNT statement.
+        
+        If column expressions or LIMIT/OFFSET/DISTINCT are present,
+        the query "SELECT count(1) FROM (SELECT ...)" is issued, 
+        so that the result matches the total number of rows
+        this query would return.  For mapped entities,
+        the primary key columns of each are written to the 
+        columns clause of the nested SELECT statement.
+        
+        For a Query which is only against mapped entities,
+        a simpler "SELECT count(1) FROM table1, table2, ... 
+        WHERE criterion" is issued.  
+        
+        """
+        should_nest = [self._should_nest_selectable]
+        def ent_cols(ent):
+            if isinstance(ent, _MapperEntity):
+                return ent.mapper.primary_key
+            else:
+                should_nest[0] = True
+                return [ent.column]
+                
         return self._col_aggregate(sql.literal_column('1'), sql.func.count, 
-            nested_cols=list(self._only_mapper_zero(
-                "Can't issue count() for multiple types of objects or columns. "
-                " Construct the Query against a single element as the thing to be counted, "
-                "or for an actual row count use Query(func.count(somecolumn)) or "
-                "query.values(func.count(somecolumn)) instead.").primary_key))
+            nested_cols=chain(*[ent_cols(ent) for ent in self._entities]),
+            should_nest = should_nest[0]
+        )
 
-    def _col_aggregate(self, col, func, nested_cols=None):
+    def _col_aggregate(self, col, func, nested_cols=None, should_nest=False):
         context = QueryContext(self)
 
+        for entity in self._entities:
+            entity.setup_context(self, context)
+
+        if context.from_clause:
+            from_obj = [context.from_clause]
+        else:
+            from_obj = context.froms
+
         self._adjust_for_single_inheritance(context)
 
         whereclause  = context.whereclause
 
-        from_obj = self.__mapper_zero_from_obj()
-
-        if self._should_nest_selectable:
+        if should_nest:
             if not nested_cols:
                 nested_cols = [col]
-            s = sql.select(nested_cols, whereclause, from_obj=from_obj, **self._select_args)
+            else:
+                nested_cols = list(nested_cols)
+            s = sql.select(nested_cols, whereclause, from_obj=from_obj, use_labels=True, **self._select_args)
             s = s.alias()
             s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s)
         else:
diff --git a/test/orm/query.py b/test/orm/query.py
index c90707342cf5605cf001f7570b4744f8676f4a46..3e2f327c35cb0248569728b8728abaacd2773942 100644 (file)
--- a/test/orm/query.py
+++ b/test/orm/query.py
@@ -240,8 +240,6 @@ class InvalidGenerationsTest(QueryTest):
         s = create_session()
         
         q = s.query(User, Address)
-        self.assertRaises(sa_exc.InvalidRequestError, q.count)
-
         self.assertRaises(sa_exc.InvalidRequestError, q.get, 5)
         
     def test_from_statement(self):
@@ -779,10 +777,53 @@ class AggregateTest(QueryTest):
 
 class CountTest(QueryTest):
     def test_basic(self):
-        assert 4 == create_session().query(User).count()
+        s = create_session()
+        
+        eq_(s.query(User).count(), 4)
+
+        eq_(s.query(User).filter(users.c.name.endswith('ed')).count(), 2)
+
+    def test_multiple_entity(self):
+        s = create_session()
+        q = s.query(User, Address)
+        eq_(q.count(), 20)  # cartesian product
+        
+        q = s.query(User, Address).join(User.addresses)
+        eq_(q.count(), 5)
+    
+    def test_nested(self):
+        s = create_session()
+        q = s.query(User, Address).limit(2)
+        eq_(q.count(), 2)
 
-        assert 2 == create_session().query(User).filter(users.c.name.endswith('ed')).count()
+        q = s.query(User, Address).limit(100)
+        eq_(q.count(), 20)
 
+        q = s.query(User, Address).join(User.addresses).limit(100)
+        eq_(q.count(), 5)
+    
+    def test_cols(self):
+        """test that column-based queries always nest."""
+        
+        s = create_session()
+        
+        q = s.query(func.count(distinct(User.name)))
+        eq_(q.count(), 1)
+
+        q = s.query(func.count(distinct(User.name))).distinct()
+        eq_(q.count(), 1)
+
+        q = s.query(User.name)
+        eq_(q.count(), 4)
+
+        q = s.query(User.name, Address)
+        eq_(q.count(), 20)
+
+        q = s.query(Address.user_id)
+        eq_(q.count(), 5)
+        eq_(q.distinct().count(), 3)
+        
+        
 class DistinctTest(QueryTest):
     def test_basic(self):
         assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).distinct().all()