From 7252ccd7c988d2fe2f218401a0a81738e19fa239 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 27 Aug 2006 21:22:28 +0000 Subject: [PATCH] - fix to using query.count() with distinct, **kwargs with SelectResults count() [ticket:287] --- CHANGES | 2 ++ lib/sqlalchemy/ext/selectresults.py | 2 +- lib/sqlalchemy/orm/query.py | 13 ++++++--- test/orm/selectresults.py | 42 +++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 4 deletions(-) diff --git a/CHANGES b/CHANGES index 085a5979ed..9ae530ace4 100644 --- a/CHANGES +++ b/CHANGES @@ -49,6 +49,8 @@ so far will convert this to "TIME[STAMP] (WITH|WITHOUT) TIME ZONE", so that control over timezone presence is more controllable (psycopg2 returns datetimes with tzinfo's if available, which can create confusion against datetimes that dont). +- fix to using query.count() with distinct, **kwargs with SelectResults +count() [ticket:287] 0.2.7 - quoting facilities set up so that database-specific quoting can be diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py index 79d56ec675..a35cdfa7ee 100644 --- a/lib/sqlalchemy/ext/selectresults.py +++ b/lib/sqlalchemy/ext/selectresults.py @@ -28,7 +28,7 @@ class SelectResults(object): def count(self): """executes the SQL count() function against the SelectResults criterion.""" - return self._query.count(self._clause) + return self._query.count(self._clause, **self._ops) def _col_aggregate(self, col, func): """executes func() function against the given column diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 1e9d40c753..29cc56761d 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -232,7 +232,10 @@ class Query(object): return self._select_statement(statement, params=params) def count(self, whereclause=None, params=None, **kwargs): - s = self.table.count(whereclause) + if self._nestable(**kwargs): + s = self.table.select(whereclause, **kwargs).alias('getcount').count() + else: + s = self.table.count(whereclause) return self.session.scalar(self.mapper, s, params=params) def select_statement(self, statement, **params): @@ -302,14 +305,18 @@ class Query(object): return self.instances(statement, params=params, **kwargs) def _should_nest(self, **kwargs): - """returns True if the given statement options indicate that we should "nest" the + """return True if the given statement options indicate that we should "nest" the generated query as a subquery inside of a larger eager-loading query. this is used with keywords like distinct, limit and offset and the mapper defines eager loads.""" return ( self.mapper.has_eager() - and (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False)) + and self._nestable(**kwargs) ) + def _nestable(self, **kwargs): + """return true if the given statement options imply it should be nested.""" + return (kwargs.has_key('limit') or kwargs.has_key('offset') or kwargs.get('distinct', False)) + def compile(self, whereclause = None, **kwargs): order_by = kwargs.pop('order_by', False) from_obj = kwargs.pop('from_obj', []) diff --git a/test/orm/selectresults.py b/test/orm/selectresults.py index 3f5bcff923..c4b1d6a56e 100644 --- a/test/orm/selectresults.py +++ b/test/orm/selectresults.py @@ -79,6 +79,48 @@ class SelectResultsTest(PersistTest): def test_offset(self): assert len(list(self.res.limit(10))) == 10 +class Obj1(object): + pass +class Obj2(object): + pass + +class SelectResultsTest2(PersistTest): + def setUpAll(self): + self.install_threadlocal() + global metadata, table1, table2 + metadata = BoundMetaData(testbase.db) + table1 = Table('Table1', metadata, + Column('id', Integer, primary_key=True), + ) + table2 = Table('Table2', metadata, + Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True), + Column('num', Integer, primary_key=True), + ) + assign_mapper(Obj1, table1, extension=SelectResultsExt()) + assign_mapper(Obj2, table2, extension=SelectResultsExt()) + metadata.create_all() + table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4}) + table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\ +{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3}) + + def setUp(self): + self.query = Obj1.mapper.query() + #self.orig = self.query.select_whereclause() + #self.res = self.query.select() + + def tearDownAll(self): + metadata.drop_all() + self.uninstall_threadlocal() + + def test_distinctcount(self): + res = self.query.select() + assert res.count() == 4 + res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)) + assert res.count() == 3 + res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True) + self.assertEqual(res.count(), 1) + + if __name__ == "__main__": testbase.main() -- 2.47.2