]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- pg8000 + postgresql dialects now check for float/numeric return
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Nov 2009 19:20:22 +0000 (19:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Nov 2009 19:20:22 +0000 (19:20 +0000)
types to more intelligently determine float() vs. Decimal(),
[ticket:1567]
- since result processing is a hot issue of late, the DBAPI type
returned from cursor.description is certainly useful in cases like
these to determine an efficient result processor.   There's likely
other result processors that can make use of it.  But, backwards
incompat change to result_processor().  Happy major version number..

26 files changed:
CHANGES
examples/postgis/postgis.py
lib/sqlalchemy/dialects/access/base.py
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/informix/base.py
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/mssql/adodbapi.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/oursql.py
lib/sqlalchemy/dialects/mysql/zxjdbc.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/oracle/zxjdbc.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/pypostgresql.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/types.py
test/dialect/test_postgresql.py
test/dialect/test_sqlite.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 41fcbb8a35bcae6855a068cfcafd922f4e2624b2..9e09763d74ea9fe222722b68b22a7172f1a128a5 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -409,7 +409,17 @@ CHANGES
 
     - cached TypeEngine classes are cached per-dialect class 
       instead of per-dialect.
-
+    
+    - new UserDefinedType should be used as a base class for
+      new types, which preserves the 0.5 behavior of 
+      get_col_spec().
+      
+    - The result_processor() method of all type classes now 
+      accepts a second argument "coltype", which is the DBAPI
+      type argument from cursor.description.  This argument
+      can help some types decide on the most efficient processing
+      of result values.
+      
     - Deprecated Dialect.get_params() removed.
 
     - Dialect.get_rowcount() has been renamed to a descriptor
index 8e687d7f8caa6990db7f664d880235c863fd93e1..d84648a952b2930589b46c73302d2ad6ed648b55 100644 (file)
@@ -101,7 +101,7 @@ class Geometry(TypeEngine):
                 return value
         return process
         
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is not None:
                 return PersistentGisElement(value)
index ed8297137a2ed3921f17baa74b0eddb940c0aeba..ee61190ff4024ea22b78302f29ddcc931b38f37b 100644 (file)
@@ -18,7 +18,7 @@ from sqlalchemy.engine import default, base
 
 
 class AcNumeric(types.Numeric):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         return None
 
     def bind_processor(self, dialect):
@@ -86,7 +86,7 @@ class AcUnicode(types.Unicode):
     def bind_processor(self, dialect):
         return None
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         return None
 
 class AcChar(types.CHAR):
@@ -101,7 +101,7 @@ class AcBoolean(types.Boolean):
     def get_col_spec(self):
         return "YESNO"
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index 921c70e980144c903db9a43f727d102c87a02cf3..21fec6b51c80d5d01384b1697f7ad78813f2f403 100644 (file)
@@ -124,7 +124,7 @@ RESERVED_WORDS = set([
 
 
 class _FBBoolean(sqltypes.Boolean):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index 5760488ae2ba5b94dc4cac7bc6d5899371ccd5e4..6565a812fed101097a02d4ce15d66fcb869a1587 100644 (file)
@@ -41,7 +41,7 @@ class InfoTime(sqltypes.Time):
             return value
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if isinstance(value, datetime.datetime):
                 return value.time()
@@ -51,7 +51,7 @@ class InfoTime(sqltypes.Time):
 
 
 class InfoBoolean(sqltypes.Boolean):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index d1c0191ed60a7933d76b860ba57900991e667223..d5f00dbdd131e26c91ce6ba2a44d4f67654d63fe 100644 (file)
@@ -83,7 +83,7 @@ class _StringType(sqltypes.String):
                     return value
             return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             while True:
                 if value is None:
@@ -169,7 +169,7 @@ class MaxTimestamp(sqltypes.DateTime):
                     dialect.datetimeformat,))
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
@@ -209,7 +209,7 @@ class MaxDate(sqltypes.Date):
                     dialect.datetimeformat,))
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
@@ -243,7 +243,7 @@ class MaxTime(sqltypes.Time):
                     dialect.datetimeformat,))
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
@@ -270,7 +270,7 @@ class MaxBlob(sqltypes.Binary):
                 return str(value)
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index 10b8b33b30796716efec3751ef088643c07c0d26..6ca1879d612674e6e368403ca8262468e1fd4c3d 100644 (file)
@@ -3,7 +3,7 @@ from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
 import sys
 
 class MSDateTime_adodbapi(MSDateTime):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             # adodbapi will return datetimes with empty time values as datetime.date() objects.
             # Promote them back to full datetime.datetime()
index 129125ca73a6155d7d4535aeb5caf8711627db97..6c89377992d5ecf8dbb09aa7642c400bed17b88f 100644 (file)
@@ -277,7 +277,7 @@ RESERVED_WORDS = set(
 
 
 class _MSNumeric(sqltypes.Numeric):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
             def process(value):
                 if value is not None:
@@ -350,7 +350,7 @@ class _MSDate(sqltypes.Date):
         return process
 
     _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if isinstance(value, datetime.datetime):
                 return value.date()
@@ -377,7 +377,7 @@ class TIME(sqltypes.TIME):
         return process
 
     _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if isinstance(value, datetime.datetime):
                 return value.time()
@@ -599,7 +599,7 @@ class BIT(sqltypes.TypeEngine):
     __visit_name__ = 'BIT'
     
 class _MSBoolean(sqltypes.Boolean):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index 36a0425898d081eeefb9ce209512db971ee2d0fc..a9acc2a013aba1b5970821fac669a926c1fd76ec 100644 (file)
@@ -546,7 +546,7 @@ class BIT(sqltypes.TypeEngine):
         """
         self.length = length
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         """Convert a MySQL's 64 bit, variable length binary string to a long."""
         def process(value):
             if value is not None:
@@ -562,7 +562,7 @@ class _MSTime(sqltypes.Time):
 
     __visit_name__ = 'TIME'
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         time = datetime.time
         def process(value):
             # convert from a timedelta value
@@ -1042,7 +1042,7 @@ class SET(_StringType):
         length = max([len(v) for v in strip_values] + [0])
         super(SET, self).__init__(length=length, **kw)
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             # The good news:
             #   No ',' quoting issues- commas aren't allowed in SET values
@@ -1085,7 +1085,7 @@ class _MSBoolean(sqltypes.Boolean):
 
     __visit_name__ = 'BOOLEAN'
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index 49fa044a3e038d4ee9260011c26614136f92f1a0..846de6580cb7bddcf5acf4e963b43de500d5b4fa 100644 (file)
@@ -48,7 +48,7 @@ class MySQL_mysqldbCompiler(MySQLCompiler):
 
 
 class _DecimalType(_NumericType):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
             return
         def process(value):
index 37537483d78233ed2cfcc5f864b2c25c5b82a9a3..5558c1a1928d6db950289f645d943cb6678b4661 100644 (file)
@@ -36,7 +36,7 @@ class _PlainQuery(unicode):
 
 
 class _oursqlNumeric(NUMERIC):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
             return
         def process(value):
@@ -48,7 +48,7 @@ class _oursqlNumeric(NUMERIC):
 
 
 class _oursqlBIT(BIT):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         """oursql already converts mysql bits, so."""
 
         return None
index bf1267aed1d7a121defdb25666f19149d00feead..dcb46789ab207370cfce30b9d5f38ad9eae46988 100644 (file)
@@ -23,7 +23,7 @@ from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
 from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext
 
 class _ZxJDBCBit(BIT):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         """Converts boolean or byte arrays from MySQL Connector/J to longs."""
         def process(value):
             if value is None:
index 22ba2ce934e821372f4911053672214d95944cdc..d13e37d60a108be2efafa3bdb997ef7f23e8f1d0 100644 (file)
@@ -159,7 +159,7 @@ class _OracleBoolean(sqltypes.Boolean):
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
     
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index eb5f2cb43bf56d239a20ffd1b426e10fbab254f7..2db37a4fc68a5f655dcf5b3ca78509313f7f4673 100644 (file)
@@ -84,7 +84,7 @@ class _OracleDate(sqltypes.Date):
     def bind_processor(self, dialect):
         return None
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if not isinstance(value, datetime):
                 return value
@@ -93,7 +93,7 @@ class _OracleDate(sqltypes.Date):
         return process
 
 class _OracleDateTime(sqltypes.DateTime):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None or isinstance(value, datetime):
                 return value
@@ -110,7 +110,7 @@ class _OracleDateTime(sqltypes.DateTime):
 
 # only if cx_oracle contains TIMESTAMP
 class _OracleTimestamp(sqltypes.TIMESTAMP):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None or isinstance(value, datetime):
                 return value
@@ -121,13 +121,13 @@ class _OracleTimestamp(sqltypes.TIMESTAMP):
         return process
 
 class _LOBMixin(object):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if not dialect.auto_convert_lobs:
             # return the cx_oracle.LOB directly.
             # don't even call super.result_processor here.
             return None
             
-        super_process = super(_LOBMixin, self).result_processor(dialect)
+        super_process = super(_LOBMixin, self).result_processor(dialect, coltype)
         lob = dialect.dbapi.LOB
         if super_process:
             def process(value):
@@ -148,11 +148,11 @@ class _OracleChar(sqltypes.CHAR):
         return dbapi.FIXED_CHAR
 
 class _OracleNVarChar(sqltypes.NVARCHAR):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if dialect._cx_oracle_native_nvarchar:
             return None
         else:
-            return sqltypes.NVARCHAR.result_processor(self, dialect)
+            return sqltypes.NVARCHAR.result_processor(self, dialect, coltype)
         
 class _OracleText(_LOBMixin, sqltypes.Text):
     def get_dbapi_type(self, dbapi):
@@ -163,7 +163,7 @@ class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText):
         return dbapi.NCLOB
 
 class _OracleInteger(sqltypes.Integer):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def to_int(val):
             if val is not None:
                 val = int(val)
@@ -259,11 +259,14 @@ class Oracle_cx_oracleExecutionContext(OracleExecutionContext):
                 for bind, name in self.compiled.bind_names.iteritems():
                     if name in self.out_parameters:
                         type = bind.type
-                        result_processor = type.dialect_impl(self.dialect).\
-                                                    result_processor(self.dialect)
+                        impl_type = type.dialect_impl(self.dialect)
+                        dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
+                        result_processor = impl_type.\
+                                                    result_processor(self.dialect, 
+                                                    dbapi_type)
                         if result_processor is not None:
                             out_parameters[name] = \
-                                    result_processor(self.out_parameters[name].getvalue())
+                                    result_processor(self.out_parameters[name].getvalue(), dbapi_type)
                         else:
                             out_parameters[name] = self.out_parameters[name].getvalue()
             else:
index 6edef301c6184ca6d8ec90b6f0d94e70b3a79f06..42c43d369adc371fd44935ca6eb6087b9317e53f 100644 (file)
@@ -20,7 +20,7 @@ SQLException = zxJDBC = None
 
 class _ZxJDBCDate(sqltypes.Date):
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
@@ -31,7 +31,7 @@ class _ZxJDBCDate(sqltypes.Date):
 
 class _ZxJDBCNumeric(sqltypes.Numeric):
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
             def process(value):
                 if isinstance(value, decimal.Decimal):
index 97108b3cbb9b04a00f809ce5908c2008969ca0af..9c6de362358d1e023325d96cfec407cdbc1cac94 100644 (file)
@@ -174,8 +174,8 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
             return [convert_item(item) for item in value]
         return process
 
-    def result_processor(self, dialect):
-        item_proc = self.item_type.result_processor(dialect)
+    def result_processor(self, dialect, coltype):
+        item_proc = self.item_type.result_processor(dialect, coltype)
         def process(value):
             if value is None:
                 return value
index 17fe86be66bf0de056346e96d5929ddd853e2ed4..e90bebb6b290954c1d163e7b02f87bcef22baddd 100644 (file)
@@ -21,25 +21,47 @@ Passing data from/to the Interval type is not supported as of yet.
 """
 from sqlalchemy.engine import default
 import decimal
-from sqlalchemy import util
+from sqlalchemy import util, exc
 from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext
+from sqlalchemy.dialects.postgresql.base import PGDialect, \
+                PGCompiler, PGIdentifierPreparer, PGExecutionContext
 
 class _PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
-        return None
-
-    def result_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                return float(value)
+            else:
+                return value
+        return process
+    
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            return None
+            if coltype in (700, 701):
+                def process(value):
+                    if value is not None:
+                        return decimal.Decimal(str(value))
+                    else:
+                        return value
+                return process
+            elif coltype == 1700:
+                # pg8000 returns Decimal natively for 1700
+                return None
+            else:
+                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
         else:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-
+            if coltype in (700, 701):
+                # pg8000 returns float natively for 701
+                return None
+            elif coltype == 1700:
+                def process(value):
+                    if value is not None:
+                        return float(value)
+                    else:
+                        return value
+                return process
+            else:
+                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
 
 class PostgreSQL_pg8000ExecutionContext(PGExecutionContext):
     pass
@@ -79,7 +101,6 @@ class PostgreSQL_pg8000(PGDialect):
         PGDialect.colspecs,
         {
             sqltypes.Numeric : _PGNumeric,
-            sqltypes.Float: sqltypes.Float,  # prevents _PGNumeric from being used
         }
     )
     
index aa4e07bb3bd2f0aa64540a5bd81861c72a71ec26..a46fdbddbfd7e5d6e0894c741f856fc157ec08fb 100644 (file)
@@ -51,17 +51,33 @@ class _PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
         return None
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            return None
+            if coltype in (700, 701):
+                def process(value):
+                    if value is not None:
+                        return decimal.Decimal(str(value))
+                    else:
+                        return value
+                return process
+            elif coltype == 1700:
+                # pg8000 returns Decimal natively for 1700
+                return None
+            else:
+                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
         else:
-            def process(value):
-                if isinstance(value, decimal.Decimal):
-                    return float(value)
-                else:
-                    return value
-            return process
-
+            if coltype in (700, 701):
+                # pg8000 returns float natively for 701
+                return None
+            elif coltype == 1700:
+                def process(value):
+                    if value is not None:
+                        return float(value)
+                    else:
+                        return value
+                return process
+            else:
+                raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype)
 
 class _PGEnum(ENUM):
     def __init__(self, *arg, **kw):
@@ -139,7 +155,6 @@ class PostgreSQL_psycopg2(PGDialect):
         PGDialect.colspecs,
         {
             sqltypes.Numeric : _PGNumeric,
-            sqltypes.Float: sqltypes.Float,  # prevents _PGNumeric from being used
             ENUM : _PGEnum, # needs force_unicode
             sqltypes.Enum : _PGEnum, # needs force_unicode
             ARRAY : _PGArray, # needs force_unicode
index 517d41aaf80062e53d9f1eade87cc9dfd71b5108..2c33b3eb5b1e9d45c212156749a71b003021ec52 100644 (file)
@@ -17,7 +17,7 @@ class PGNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
         return None
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
             return None
         else:
index 47c797c2118f61c484eed0fbdc37831f4ff5cbfb..33feaeaaeb5b47b9cd81971558744cd0520646d9 100644 (file)
@@ -96,7 +96,7 @@ class _SLDateTime(_DateTimeMixin, sqltypes.DateTime):
                         )
 
     _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?")
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         return self._result_processor(datetime.datetime, self._reg)
 
 class _SLDate(_DateTimeMixin, sqltypes.Date):
@@ -107,7 +107,7 @@ class _SLDate(_DateTimeMixin, sqltypes.Date):
                 )
 
     _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         return self._result_processor(datetime.date, self._reg)
 
 class _SLTime(_DateTimeMixin, sqltypes.Time):
@@ -126,7 +126,7 @@ class _SLTime(_DateTimeMixin, sqltypes.Time):
                     )
 
     _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         return self._result_processor(datetime.time, self._reg)
 
 
@@ -138,7 +138,7 @@ class _SLBoolean(sqltypes.Boolean):
             return value and 1 or 0
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index 6fc42c312ecbcba45d895210ce81e6e45fadc506..cfdbd321ad70be39ac2e8e23250d2ff24d42c8f4 100644 (file)
@@ -113,7 +113,7 @@ class SybaseUniqueIdentifier(sqltypes.TypeEngine):
     __visit_name__ = "UNIQUEIDENTIFIER"
     
 class SybaseBoolean(sqltypes.Boolean):
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         def process(value):
             if value is None:
                 return None
index b65843f48d771131a01e8ebd12e2948161699de2..3ea52cd725176cee4b57fbf4d3e972c1aed16747 100644 (file)
@@ -1757,7 +1757,6 @@ class ResultProxy(object):
         typemap = self.dialect.dbapi_type_map
 
         for i, (colname, coltype) in enumerate(m[0:2] for m in metadata):
-
             if self.dialect.description_encoding:
                 colname = colname.decode(self.dialect.description_encoding)
 
@@ -1779,7 +1778,7 @@ class ResultProxy(object):
                 name, obj, type_ = (colname, None, typemap.get(coltype, types.NULLTYPE))
 
             processor = type_.dialect_impl(self.dialect).\
-                            result_processor(self.dialect)
+                            result_processor(self.dialect, coltype)
             
             if processor:
                 def make_colfunc(processor, index):
index a215f31ff0395269f3c9ff433ecab9fb3587197f..3fa18b2c26d1d1a18a85416471b119dd2f17ceba 100644 (file)
@@ -47,12 +47,22 @@ class AbstractType(Visitable):
         return value
 
     def bind_processor(self, dialect):
-        """Defines a bind parameter processing function."""
+        """Defines a bind parameter processing function.
+        
+        :param dialect: Dialect instance in use.
+
+        """
 
         return None
 
-    def result_processor(self, dialect):
-        """Defines a result-column processing function."""
+    def result_processor(self, dialect, coltype):
+        """Defines a result-column processing function.
+        
+        :param dialect: Dialect instance in use.
+
+        :param coltype: DBAPI coltype argument received in cursor.description.
+        
+        """
 
         return None
 
@@ -126,7 +136,7 @@ class TypeEngine(AbstractType):
         """
         return None
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         """Return a conversion function for processing result row values.
 
         Returns a callable which will receive a result row column
@@ -162,7 +172,7 @@ class UserDefinedType(TypeEngine):
                   return value
               return process
 
-          def result_processor(self, dialect):
+          def result_processor(self, dialect, coltype):
               def process(value):
                   return value
               return process
@@ -300,10 +310,10 @@ class TypeDecorator(AbstractType):
         else:
             return self.impl.bind_processor(dialect)
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code:
             process_value = self.process_result_value
-            impl_processor = self.impl.result_processor(dialect)
+            impl_processor = self.impl.result_processor(dialect, coltype)
             if impl_processor:
                 def process(value):
                     return process_value(impl_processor(value), dialect)
@@ -312,7 +322,7 @@ class TypeDecorator(AbstractType):
                     return process_value(value, dialect)
             return process
         else:
-            return self.impl.result_processor(dialect)
+            return self.impl.result_processor(dialect, coltype)
 
     def copy(self):
         instance = self.__class__.__new__(self.__class__)
@@ -511,7 +521,7 @@ class String(Concatenable, TypeEngine):
         else:
             return None
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if (not dialect.returns_unicode_strings or self.convert_unicode == 'force') \
             and (self.convert_unicode or dialect.convert_unicode):
             def process(value):
@@ -666,7 +676,10 @@ class Numeric(TypeEngine):
         self.asdecimal = asdecimal
 
     def adapt(self, impltype):
-        return impltype(precision=self.precision, scale=self.scale, asdecimal=self.asdecimal)
+        return impltype(
+                precision=self.precision, 
+                scale=self.scale, 
+                asdecimal=self.asdecimal)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
@@ -679,7 +692,7 @@ class Numeric(TypeEngine):
                 return value
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if self.asdecimal:
             def process(value):
                 if value is not None:
@@ -790,7 +803,7 @@ class Binary(TypeEngine):
                 return None
         return process
 
-    def result_processor(self, dialect):
+    def result_processor(self, dialect, coltype):
         if util.jython:
             def process(value):
                 if value is not None:
@@ -1041,8 +1054,8 @@ class PickleType(MutableType, TypeDecorator):
                 return value
         return process
 
-    def result_processor(self, dialect):
-        impl_processor = self.impl.result_processor(dialect)
+    def result_processor(self, dialect, coltype):
+        impl_processor = self.impl.result_processor(dialect, coltype)
         loads = self.pickler.loads
         if impl_processor:
             def process(value):
@@ -1111,8 +1124,8 @@ class Interval(TypeDecorator):
                 return value
         return process
 
-    def result_processor(self, dialect):
-        impl_processor = self.impl.result_processor(dialect)
+    def result_processor(self, dialect, coltype):
+        impl_processor = self.impl.result_processor(dialect, coltype)
         epoch = self.epoch
         if impl_processor:
             def process(value):
index aa2b99275aeb33ac3cef914bebdabc429d370d15..152ca40dabc1dfcb253e89ce49328167a1e11747 100644 (file)
@@ -2,6 +2,7 @@
 from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
 from sqlalchemy.test import  engines
 import datetime
+import decimal
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from sqlalchemy import exc, schema, types
@@ -10,6 +11,7 @@ from sqlalchemy.engine.strategies import MockEngineStrategy
 from sqlalchemy.test import *
 from sqlalchemy.sql import table, column
 from sqlalchemy.test.testing import eq_
+from test.engine._base import TablesTest
 
 class SequenceTest(TestBase, AssertsCompiledSQL):
     def test_basic(self):
@@ -105,6 +107,65 @@ class CompileTest(TestBase, AssertsCompiledSQL):
                 "SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 "
                 "FROM t" % field)
 
+class FloatCoercionTest(TablesTest, AssertsExecutionResults):
+    __only_on__ = 'postgresql'
+    __dialect__ = postgresql.dialect()
+
+    @classmethod
+    def define_tables(cls, metadata):
+        data_table = Table('data_table', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', Integer)
+        )
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        data_table.insert().execute(
+            {'data':3},
+            {'data':5},
+            {'data':7},
+            {'data':2},
+            {'data':15},
+            {'data':12},
+            {'data':6},
+            {'data':478},
+            {'data':52},
+            {'data':9},
+        )
+    
+    def _round(self, x):
+        if isinstance(x, float):
+            return round(x, 9)
+        elif isinstance(x, decimal.Decimal):
+            # really ?
+            x = x.shift(decimal.Decimal(9)).to_integral() / pow(10, 9)
+        return x
+    @testing.resolve_artifact_names
+    def test_float_coercion(self):
+        for type_, result in [
+            (Numeric, decimal.Decimal('140.381230939')),
+            (Float, 140.381230939),
+            (Float(asdecimal=True), decimal.Decimal('140.381230939')),
+            (Numeric(asdecimal=False), 140.381230939),
+        ]:
+            ret = testing.db.execute(
+                select([
+                    func.stddev_pop(data_table.c.data, type_=type_)
+                ])
+            ).scalar()
+            
+            eq_(self._round(ret), result)
+
+            ret = testing.db.execute(
+                select([
+                    cast(func.stddev_pop(data_table.c.data), type_)
+                ])
+            ).scalar()
+            eq_(self._round(ret), result)
+    
+    
+        
 class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     __only_on__ = 'postgresql'
     __dialect__ = postgresql.dialect()
index 040397f4c3d98f436e2846f1d06b9fb221eea13a..6c6ad65e0b286cbcefa40dc8efb4e036d71d7d36 100644 (file)
@@ -43,7 +43,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
         bp = sldt.bind_processor(None)
         eq_(bp(dt), '2008-06-27 12:00:00.000125')
         
-        rp = sldt.result_processor(None)
+        rp = sldt.result_processor(None, None)
         eq_(rp(bp(dt)), dt)
         
         sldt.__legacy_microseconds__ = True
index c0b86c1e4307aa3dfd7ff8b377b8aafb23f7854c..a3cb03022a8da411b11f7f028553b277b4a51e6f 100644 (file)
@@ -100,7 +100,7 @@ class UserDefinedTest(TestBase):
                 def process(value):
                     return "BIND_IN"+ value
                 return process
-            def result_processor(self, dialect):
+            def result_processor(self, dialect, coltype):
                 def process(value):
                     return value + "BIND_OUT"
                 return process
@@ -114,8 +114,8 @@ class UserDefinedTest(TestBase):
                 def process(value):
                     return "BIND_IN"+ impl_processor(value)
                 return process
-            def result_processor(self, dialect):
-                impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value)
+            def result_processor(self, dialect, coltype):
+                impl_processor = super(MyDecoratedType, self).result_processor(dialect, coltype) or (lambda value:value)
                 def process(value):
                     return impl_processor(value) + "BIND_OUT"
                 return process
@@ -163,8 +163,8 @@ class UserDefinedTest(TestBase):
                     return "BIND_IN"+ impl_processor(value)
                 return process
 
-            def result_processor(self, dialect):
-                impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value)
+            def result_processor(self, dialect, coltype):
+                impl_processor = super(MyUnicodeType, self).result_processor(dialect, coltype) or (lambda value:value)
                 def process(value):
                     return impl_processor(value) + "BIND_OUT"
                 return process
@@ -528,7 +528,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults):
                 def process(value):
                     return value * 10
                 return process
-            def result_processor(self, dialect):
+            def result_processor(self, dialect, coltype):
                 def process(value):
                     return value / 10
                 return process