]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- func.count() with no argument emits COUNT(*)
authorJason Kirtland <jek@discorporate.us>
Thu, 24 Jul 2008 21:36:16 +0000 (21:36 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 24 Jul 2008 21:36:16 +0000 (21:36 +0000)
CHANGES
lib/sqlalchemy/sql/functions.py
test/sql/functions.py

diff --git a/CHANGES b/CHANGES
index b10e0a1cd8511ba5a7234b852da6fce6d6dd9dd0..100626f1a2df52d1cf68cf66f5468b4189295270 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -50,7 +50,11 @@ CHANGES
       should reduce the probability of "Attribute x was 
       not replaced during compile" warnings. (this generally
       applies to SQLA hackers, like Elixir devs).
-      
+
+- sql
+    - func.count() with no arguments renders as COUNT(*),
+      equivalent to func.count(text('*')). 
 - ext
     - Class-bound attributes sent as arguments to 
       relation()'s remote_side and foreign_keys parameters 
index 7303bd0c614fe4f8a9276ca8726eb064bdeac1e7..7fce3b95b1797cddd552c642f9e8e333aeba016a 100644 (file)
@@ -1,6 +1,6 @@
 from sqlalchemy import types as sqltypes
 from sqlalchemy.sql.expression import (
-    ClauseList, _FigureVisitName, _Function, _literal_as_binds,
+    ClauseList, _FigureVisitName, _Function, _literal_as_binds, text
     )
 from sqlalchemy.sql import operators
 
@@ -61,6 +61,16 @@ class random(GenericFunction):
         kwargs.setdefault('type_', None)
         GenericFunction.__init__(self, args=args, **kwargs)
 
+class count(GenericFunction):
+    """The ANSI COUNT aggregate function.  With no arguments, emits COUNT *."""
+
+    __return_type__ = sqltypes.Integer
+
+    def __init__(self, expression=None, **kwargs):
+        if expression is None:
+            expression = text('*')
+        GenericFunction.__init__(self, args=(expression,), **kwargs)
+
 class current_date(AnsiFunction):
     __return_type__ = sqltypes.Date
 
index 6754d6d42836ac49eafeebe14145402a446c4a25..681d6a5575e8dc01fbee6e84e2f30ad4982bb5b4 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy.engine import default
 from sqlalchemy import types as sqltypes
 from testlib import *
 from sqlalchemy.sql.functions import GenericFunction
+from testlib.testing import eq_
 
 from sqlalchemy.databases import *
 # every dialect in databases.__all__ is expected to pass these tests.
@@ -68,6 +69,14 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)
 
+    def test_generic_count(self):
+        assert isinstance(func.count().type, sqltypes.Integer)
+
+        self.assert_compile(func.count(), 'count(*)')
+        self.assert_compile(func.count(1), 'count(:param_1)')
+        c = column('abc')
+        self.assert_compile(func.count(c), 'count(abc)')
+
     def test_constructor(self):
         try:
             func.current_timestamp('somearg')