From aeb8c429bf5f74e35e4c297dfbb82588beb8ade8 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 2 Nov 2006 21:31:56 +0000 Subject: [PATCH] - implemented from_obj argument for query.count, improves count function on selectresults [ticket:325] --- CHANGES | 2 ++ lib/sqlalchemy/orm/query.py | 13 +++++++++++-- test/ext/selectresults.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/CHANGES b/CHANGES index fb0334cfdc..e4abe15650 100644 --- a/CHANGES +++ b/CHANGES @@ -17,6 +17,8 @@ contained a cyclical many-to-one relationship to object B, and object B was just attached to object A, *but* object B itself wasnt changed, the many-to-one synchronize of B's primary key attribute to A's foreign key attribute wouldnt occur. [ticket:360] +- implemented from_obj argument for query.count, improves count function +on selectresults [ticket:325] 0.3.0 - General: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 9c76c6a176..e257a1cbe1 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -255,10 +255,19 @@ class Query(object): def count(self, whereclause=None, params=None, **kwargs): """given a WHERE criterion, create a SELECT COUNT statement, execute and return the resulting count value.""" + + from_obj = kwargs.pop('from_obj', []) + alltables = [] + for l in [sql_util.TableFinder(x) for x in from_obj]: + alltables += l + + if self.table not in alltables: + from_obj.append(self.table) + if self._nestable(**kwargs): - s = self.table.select(whereclause, **kwargs).alias('getcount').count() + s = sql.select([self.table], whereclause, **kwargs).alias('getcount').count() else: - s = self.table.count(whereclause) + s = sql.select([sql.func.count(list(self.table.primary_key)[0])], whereclause, from_obj=from_obj, **kwargs) return self.session.scalar(self.mapper, s, params=params) def select_statement(self, statement, **params): diff --git a/test/ext/selectresults.py b/test/ext/selectresults.py index d82ad96fb6..ebb0e69e2d 100644 --- a/test/ext/selectresults.py +++ b/test/ext/selectresults.py @@ -155,6 +155,17 @@ class RelationsTest(AssertMixin): x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) print x.compile() self.assert_result(list(x), tables.User, *tables.user_result[1:3]) + def test_outerjointo_count(self): + """test the join_to and outerjoin_to functions on SelectResults""" + mapper(tables.User, tables.users, properties={ + 'orders':relation(mapper(tables.Order, tables.orders, properties={ + 'items':relation(mapper(tables.Item, tables.orderitems)) + })) + }) + session = create_session() + query = SelectResults(session.query(tables.User)) + x = query.outerjoin_to('orders').outerjoin_to('items').select(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count() + assert x==2 def test_from(self): mapper(tables.User, tables.users, properties={ 'orders':relation(mapper(tables.Order, tables.orders, properties={ -- 2.47.2