From: Mike Bayer Date: Sun, 8 Jan 2006 07:24:33 +0000 (+0000) Subject: added count func to mapper X-Git-Tag: rel_0_1_0~151 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4c50fd22ed79425d52f5a05667757c4f83ad5304;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added count func to mapper --- diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 746ddb8b92..a1a3ee5fb3 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -338,8 +338,10 @@ class Mapper(object): return None def select_by(self, *args, **params): - """returns an array of object instances based on the given key/value criterion. + """returns an array of object instances based on the given clauses and key/value criterion. + *args is a list of zero or more ClauseElements which will be connected by AND operators. + **params is a set of zero or more key/value parameters which are converted into ClauseElements. the keys are mapped to property or column names mapped by this mapper's Table, and the values are coerced into a WHERE clause separated by AND operators. If the local property/column names dont contain the key, a search will be performed against this mapper's immediate @@ -348,6 +350,14 @@ class Mapper(object): e.g. result = usermapper.select_by(user_name = 'fred') """ + return self.select_whereclause(self._by_clause(*args, **params)) + + def count_by(self, *args, **params): + """returns the count of instances based on the given clauses and key/value criterion. + The criterion is constructed in the same way as the select_by() method.""" + return self.count(self._by_clause(*args, **params)) + + def _by_clause(self, *args, **params): clause = None for arg in args: if clause is None: @@ -364,7 +374,7 @@ class Mapper(object): clause = c else: clause &= c - return self.select_whereclause(clause) + return clause def _get_criterion(self, key, value): """used by select_by to match a key/value pair against @@ -427,6 +437,13 @@ class Mapper(object): else: return self.select_statement(statement) + def count(self, whereclause = None, params=None, **kwargs): + s = self.table.count(whereclause) + if params is not None: + return s.scalar(**params) + else: + return s.scalar() + def select_statement(self, statement, **params): statement.use_labels = True return self.instances(statement.execute(**params)) diff --git a/test/mapper.py b/test/mapper.py index 412166b862..0f5902f103 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -130,7 +130,13 @@ class MapperTest(MapperSuperTest): print "User", u.user_id, u.user_name, u.concat, u.count #l[1].user_name='asdf' #objectstore.commit() - + + def testcount(self): + m = mapper(User, users) + self.assert_(m.count()==3) + self.assert_(m.count(users.c.user_id.in_(8,9))==2) + self.assert_(m.count_by(user_name='fred')==1) + def testmultitable(self): usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id])