]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added count func to mapper
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Jan 2006 07:24:33 +0000 (07:24 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Jan 2006 07:24:33 +0000 (07:24 +0000)
lib/sqlalchemy/mapping/mapper.py
test/mapper.py

index 746ddb8b92665f4a33f171c0218ef65f1238b559..a1a3ee5fb3160f31a8f8e29ad2da06be4cdfd838 100644 (file)
@@ -338,8 +338,10 @@ class Mapper(object):
             return None
             
     def select_by(self, *args, **params):
-        """returns an array of object instances based on the given key/value criterion. 
+        """returns an array of object instances based on the given clauses and key/value criterion. 
         
+        *args is a list of zero or more ClauseElements which will be connected by AND operators.
+        **params is a set of zero or more key/value parameters which are converted into ClauseElements.
         the keys are mapped to property or column names mapped by this mapper's Table, and the values
         are coerced into a WHERE clause separated by AND operators.  If the local property/column
         names dont contain the key, a search will be performed against this mapper's immediate
@@ -348,6 +350,14 @@ class Mapper(object):
         
         e.g.   result = usermapper.select_by(user_name = 'fred')
         """
+        return self.select_whereclause(self._by_clause(*args, **params))
+
+    def count_by(self, *args, **params):
+        """returns the count of instances based on the given clauses and key/value criterion.
+        The criterion is constructed in the same way as the select_by() method."""
+        return self.count(self._by_clause(*args, **params))
+        
+    def _by_clause(self, *args, **params):
         clause = None
         for arg in args:
             if clause is None:
@@ -364,7 +374,7 @@ class Mapper(object):
                 clause = c
             else:                
                 clause &= c
-        return self.select_whereclause(clause)
+        return clause
 
     def _get_criterion(self, key, value):
         """used by select_by to match a key/value pair against
@@ -427,6 +437,13 @@ class Mapper(object):
         else:
             return self.select_statement(statement)
 
+    def count(self, whereclause = None, params=None, **kwargs):
+        s = self.table.count(whereclause)
+        if params is not None:
+            return s.scalar(**params)
+        else:
+            return s.scalar()
+
     def select_statement(self, statement, **params):
         statement.use_labels = True
         return self.instances(statement.execute(**params))
index 412166b86229e271f4bde3ddce19ff3a776bb2b4..0f5902f10317967b5e59c2598f11999a79159a83 100644 (file)
@@ -130,7 +130,13 @@ class MapperTest(MapperSuperTest):
             print "User", u.user_id, u.user_name, u.concat, u.count
         #l[1].user_name='asdf'
         #objectstore.commit()
-        
+    
+    def testcount(self):
+        m = mapper(User, users)
+        self.assert_(m.count()==3)
+        self.assert_(m.count(users.c.user_id.in_(8,9))==2)
+        self.assert_(m.count_by(user_name='fred')==1)
+            
     def testmultitable(self):
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
         m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id])