From 2f547222da46b38df07dde08e60bc5efbb0afd79 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 25 Apr 2007 17:49:26 +0000 Subject: [PATCH] - added generative versions of aggregates, i.e. sum(), avg(), etc. to query. used via query.apply_max(), apply_sum(), etc. #552 --- CHANGES | 3 +++ lib/sqlalchemy/orm/query.py | 46 ++++++++++++++++++++++++++++++++++++- test/orm/generative.py | 4 ++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/CHANGES b/CHANGES index 450ef32799..fea91b7f51 100644 --- a/CHANGES +++ b/CHANGES @@ -80,6 +80,9 @@ takes optional string "property" to isolate the desired relation. also adds static Query.query_from_parent(instance, property) version. [ticket:541] + - added generative versions of aggregates, i.e. sum(), avg(), etc. + to query. used via query.apply_max(), apply_sum(), etc. + #552 - corresponding to label/bindparam name generataion, eager loaders generate deterministic names for the aliases they create using md5 hashes. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 9eec1bc0ec..c43b9a9460 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -43,6 +43,8 @@ class Query(object): self._offset = kwargs.pop('offset', None) self._limit = kwargs.pop('limit', None) self._criterion = None + self._col = None + self._func = None self._joinpoint = self.mapper self._from_obj = [self.table] @@ -71,6 +73,8 @@ class Query(object): q._from_obj = list(self._from_obj) q._joinpoint = self._joinpoint q._criterion = self._criterion + q._col = self._col + q._func = self._func return q def _get_session(self): @@ -318,7 +322,6 @@ class Query(object): """Given a ``WHERE`` criterion, create a ``SELECT`` statement, execute and return the resulting instances. """ - statement = self.compile(whereclause, **kwargs) return self._select_statement(statement, params=params) @@ -611,6 +614,41 @@ class Query(object): raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key) return [keys, p] + def _generative_col_aggregate(self, col, func): + """apply the given aggregate function to the query and return the newly + resulting ``Query``. + """ + if self._col is not None or self._func is not None: + raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") + q = self._clone() + q._col = col + q._func = func + return q + + def apply_min(self, col): + """apply the SQL ``min()`` function against the given column to the + query and return the newly resulting ``Query``. + """ + return self._generative_col_aggregate(col, sql.func.min) + + def apply_max(self, col): + """apply the SQL ``max()`` function against the given column to the + query and return the newly resulting ``Query``. + """ + return self._generative_col_aggregate(col, sql.func.max) + + def apply_sum(self, col): + """apply the SQL ``sum()`` function against the given column to the + query and return the newly resulting ``Query``. + """ + return self._generative_col_aggregate(col, sql.func.sum) + + def apply_avg(self, col): + """apply the SQL ``avg()`` function against the given column to the + query and return the newly resulting ``Query``. + """ + return self._generative_col_aggregate(col, sql.func.avg) + def _col_aggregate(self, col, func): """Execute ``func()`` function against the given column. @@ -767,6 +805,12 @@ class Query(object): """ return list(self) + + def scalar(self): + if self._col is None or self._func is None: + return self[0] + else: + return self._col_aggregate(self._col, self._func) def __iter__(self): return iter(self.select_whereclause()) diff --git a/test/orm/generative.py b/test/orm/generative.py index 6cda219645..512f04ae90 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -59,6 +59,7 @@ class GenerativeQueryTest(PersistTest): assert self.query.count() == 100 assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0 assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29 + assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).scalar() == 29 @testbase.unsupported('mysql') def test_aggregate_1(self): @@ -73,6 +74,9 @@ class GenerativeQueryTest(PersistTest): def test_aggregate_2_int(self): assert int(self.res.filter(foo.c.bar<30).avg(foo.c.bar)) == 14 + def test_aggregate_3(self): + assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).scalar() == 14.5 + def test_filter(self): assert self.query.count() == 100 assert self.query.filter(Foo.c.bar < 30).count() == 30 -- 2.47.2