]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added generative versions of aggregates, i.e. sum(), avg(), etc.
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 Apr 2007 17:49:26 +0000 (17:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 Apr 2007 17:49:26 +0000 (17:49 +0000)
to query. used via query.apply_max(), apply_sum(), etc.
#552

CHANGES
lib/sqlalchemy/orm/query.py
test/orm/generative.py

diff --git a/CHANGES b/CHANGES
index 450ef327992700f929367aa530da11dc48621228..fea91b7f51eed99f9ea77829512a2ac49f3ec6bd 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -80,6 +80,9 @@
       takes optional string "property" to isolate the desired relation.
       also adds static Query.query_from_parent(instance, property)
       version. [ticket:541]
+    - added generative versions of aggregates, i.e. sum(), avg(), etc.
+      to query. used via query.apply_max(), apply_sum(), etc. 
+      #552
     - corresponding to label/bindparam name generataion, eager loaders 
       generate deterministic names for the aliases they create using 
       md5 hashes.
index 9eec1bc0eccd39c6e8c5d8308fb5747d80463372..c43b9a94608262e58d42eaf0a24291f90aa90348 100644 (file)
@@ -43,6 +43,8 @@ class Query(object):
         self._offset = kwargs.pop('offset', None)
         self._limit = kwargs.pop('limit', None)
         self._criterion = None
+        self._col = None
+        self._func = None
         self._joinpoint = self.mapper
         self._from_obj = [self.table]
 
@@ -71,6 +73,8 @@ class Query(object):
         q._from_obj = list(self._from_obj)
         q._joinpoint = self._joinpoint
         q._criterion = self._criterion
+        q._col = self._col
+        q._func = self._func
         return q
     
     def _get_session(self):
@@ -318,7 +322,6 @@ class Query(object):
         """Given a ``WHERE`` criterion, create a ``SELECT`` statement,
         execute and return the resulting instances.
         """
-
         statement = self.compile(whereclause, **kwargs)
         return self._select_statement(statement, params=params)
 
@@ -611,6 +614,41 @@ class Query(object):
             raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
         return [keys, p]
 
+    def _generative_col_aggregate(self, col, func):
+        """apply the given aggregate function to the query and return the newly
+        resulting ``Query``.
+        """
+        if self._col is not None or self._func is not None:
+            raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
+        q = self._clone()
+        q._col = col
+        q._func = func
+        return q
+
+    def apply_min(self, col):
+        """apply the SQL ``min()`` function against the given column to the
+        query and return the newly resulting ``Query``.
+        """
+        return self._generative_col_aggregate(col, sql.func.min)
+
+    def apply_max(self, col):
+        """apply the SQL ``max()`` function against the given column to the
+        query and return the newly resulting ``Query``.
+        """
+        return self._generative_col_aggregate(col, sql.func.max)
+
+    def apply_sum(self, col):
+        """apply the SQL ``sum()`` function against the given column to the
+        query and return the newly resulting ``Query``.
+        """
+        return self._generative_col_aggregate(col, sql.func.sum)
+
+    def apply_avg(self, col):
+        """apply the SQL ``avg()`` function against the given column to the
+        query and return the newly resulting ``Query``.
+        """
+        return self._generative_col_aggregate(col, sql.func.avg)
+
     def _col_aggregate(self, col, func):
         """Execute ``func()`` function against the given column.
 
@@ -767,6 +805,12 @@ class Query(object):
         """
 
         return list(self)
+
+    def scalar(self):
+        if self._col is None or self._func is None: 
+            return self[0]
+        else:
+            return self._col_aggregate(self._col, self._func)
     
     def __iter__(self):
         return iter(self.select_whereclause())
index 6cda219645635688cf5e1c11d93c29ed37c0e715..512f04ae9076ae09bbdcd5840cb66f801780227f 100644 (file)
@@ -59,6 +59,7 @@ class GenerativeQueryTest(PersistTest):
         assert self.query.count() == 100
         assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0
         assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29
+        assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).scalar() == 29
 
     @testbase.unsupported('mysql')
     def test_aggregate_1(self):
@@ -73,6 +74,9 @@ class GenerativeQueryTest(PersistTest):
     def test_aggregate_2_int(self):
         assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
 
+    def test_aggregate_3(self):
+        assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).scalar() == 14.5
+        
     def test_filter(self):
         assert self.query.count() == 100
         assert self.query.filter(Foo.c.bar < 30).count() == 30