]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added func.min(), func.max(), func.sum() as "generic functions",
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Sep 2008 03:51:47 +0000 (03:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Sep 2008 03:51:47 +0000 (03:51 +0000)
which basically allows for their return type to be determined
automatically.  Helps with dates on SQLite, decimal types,
others. [ticket:1160]

- added decimal.Decimal as an "auto-detect" type; bind parameters
and generic functions will set their type to Numeric when a
Decimal is used.

CHANGES
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/types.py
test/sql/functions.py

diff --git a/CHANGES b/CHANGES
index 31dd9e71164b0358d2717543344b02e002e8a915..fc7bc65da31c4ea74d13014f0d54c398c8c54d81 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -149,6 +149,15 @@ CHANGES
       [ticket:1068].  This feature is on hold pending further
       development.
 
+    - Added func.min(), func.max(), func.sum() as "generic functions",
+      which basically allows for their return type to be determined
+      automatically.  Helps with dates on SQLite, decimal types, 
+      others. [ticket:1160]
+    
+    - added decimal.Decimal as an "auto-detect" type; bind parameters
+      and generic functions will set their type to Numeric when a 
+      Decimal is used.
+      
 - mysql
     - The 'length' argument to MSInteger, MSBigInteger, MSTinyInteger,
       MSSmallInteger and MSYear has been renamed to 'display_width'.
index 7fce3b95b1797cddd552c642f9e8e333aeba016a..c7a0f142d0b9e4996597e40796e97e4821a3a5d8 100644 (file)
@@ -36,12 +36,25 @@ class AnsiFunction(GenericFunction):
     def __init__(self, **kwargs):
         GenericFunction.__init__(self, **kwargs)
 
-
-class coalesce(GenericFunction):
+class ReturnTypeFromArgs(GenericFunction):
+    """Define a function whose return type is the same as its arguments."""
+    
     def __init__(self, *args, **kwargs):
         kwargs.setdefault('type_', _type_from_args(args))
         GenericFunction.__init__(self, args=args, **kwargs)
 
+class coalesce(ReturnTypeFromArgs):
+    pass
+
+class max(ReturnTypeFromArgs):
+    pass
+
+class min(ReturnTypeFromArgs):
+    pass
+
+class sum(ReturnTypeFromArgs):
+    pass
+
 class now(GenericFunction):
     __return_type__ = sqltypes.DateTime
 
index 3690ed3ca243d2fb2a05a3adbfe259e2e409996d..4958e4812a21f4bc0502be9e43c09ce3432f94ef 100644 (file)
@@ -625,6 +625,7 @@ type_map = {
     unicode : NCHAR,
     int : Integer,
     float : Numeric,
+    _python_Decimal : Numeric,
     dt.date : Date,
     dt.datetime : DateTime,
     dt.time : Time,
index 27e87ecebc22708b8cabb62c02933719b949e85d..ac9b7e3292a831943acbbe3010752b02024e3678 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import types as sqltypes
 from testlib import *
 from sqlalchemy.sql.functions import GenericFunction
 from testlib.testing import eq_
+from decimal import Decimal as _python_Decimal
 
 from sqlalchemy.databases import *
 
@@ -90,13 +91,21 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         except TypeError:
             assert True
 
-    def test_typing(self):
-        assert isinstance(func.coalesce(datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)).type, sqltypes.Date)
-
-        assert isinstance(func.coalesce(None, datetime.date(2005, 10, 15)).type, sqltypes.Date)
-
+    def test_return_type_detection(self):
+        
+        for fn in [func.coalesce, func.max, func.min, func.sum]:
+            for args, type_ in [
+                            ((datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), sqltypes.Date),
+                            ((3, 5), sqltypes.Integer),
+                            ((_python_Decimal(3), _python_Decimal(5)), sqltypes.Numeric),
+                            (("foo", "bar"), sqltypes.String),
+                            ((datetime.datetime(2007, 10, 5, 8, 3, 34), datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime)
+                        ]:
+                assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_)
+        
         assert isinstance(func.concat("foo", "bar").type, sqltypes.String)
 
+
     def test_assorted(self):
         table1 = table('mytable',
             column('myid', Integer),