From: Mike Bayer Date: Sun, 15 Nov 2009 19:20:22 +0000 (+0000) Subject: - pg8000 + postgresql dialects now check for float/numeric return X-Git-Tag: rel_0_6beta1~168 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5f6ed1a3f8bb0b2a724c7f07b98936433a3ef053;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - pg8000 + postgresql dialects now check for float/numeric return types to more intelligently determine float() vs. Decimal(), [ticket:1567] - since result processing is a hot issue of late, the DBAPI type returned from cursor.description is certainly useful in cases like these to determine an efficient result processor. There's likely other result processors that can make use of it. But, backwards incompat change to result_processor(). Happy major version number.. --- diff --git a/CHANGES b/CHANGES index 41fcbb8a35..9e09763d74 100644 --- a/CHANGES +++ b/CHANGES @@ -409,7 +409,17 @@ CHANGES - cached TypeEngine classes are cached per-dialect class instead of per-dialect. - + + - new UserDefinedType should be used as a base class for + new types, which preserves the 0.5 behavior of + get_col_spec(). + + - The result_processor() method of all type classes now + accepts a second argument "coltype", which is the DBAPI + type argument from cursor.description. This argument + can help some types decide on the most efficient processing + of result values. + - Deprecated Dialect.get_params() removed. - Dialect.get_rowcount() has been renamed to a descriptor diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py index 8e687d7f8c..d84648a952 100644 --- a/examples/postgis/postgis.py +++ b/examples/postgis/postgis.py @@ -101,7 +101,7 @@ class Geometry(TypeEngine): return value return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is not None: return PersistentGisElement(value) diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index ed8297137a..ee61190ff4 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -18,7 +18,7 @@ from sqlalchemy.engine import default, base class AcNumeric(types.Numeric): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): return None def bind_processor(self, dialect): @@ -86,7 +86,7 @@ class AcUnicode(types.Unicode): def bind_processor(self, dialect): return None - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): return None class AcChar(types.CHAR): @@ -101,7 +101,7 @@ class AcBoolean(types.Boolean): def get_col_spec(self): return "YESNO" - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 921c70e980..21fec6b51c 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -124,7 +124,7 @@ RESERVED_WORDS = set([ class _FBBoolean(sqltypes.Boolean): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/informix/base.py b/lib/sqlalchemy/dialects/informix/base.py index 5760488ae2..6565a812fe 100644 --- a/lib/sqlalchemy/dialects/informix/base.py +++ b/lib/sqlalchemy/dialects/informix/base.py @@ -41,7 +41,7 @@ class InfoTime(sqltypes.Time): return value return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if isinstance(value, datetime.datetime): return value.time() @@ -51,7 +51,7 @@ class InfoTime(sqltypes.Time): class InfoBoolean(sqltypes.Boolean): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index d1c0191ed6..d5f00dbdd1 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -83,7 +83,7 @@ class _StringType(sqltypes.String): return value return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): while True: if value is None: @@ -169,7 +169,7 @@ class MaxTimestamp(sqltypes.DateTime): dialect.datetimeformat,)) return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None @@ -209,7 +209,7 @@ class MaxDate(sqltypes.Date): dialect.datetimeformat,)) return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None @@ -243,7 +243,7 @@ class MaxTime(sqltypes.Time): dialect.datetimeformat,)) return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None @@ -270,7 +270,7 @@ class MaxBlob(sqltypes.Binary): return str(value) return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py index 10b8b33b30..6ca1879d61 100644 --- a/lib/sqlalchemy/dialects/mssql/adodbapi.py +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -3,7 +3,7 @@ from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect import sys class MSDateTime_adodbapi(MSDateTime): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): # adodbapi will return datetimes with empty time values as datetime.date() objects. # Promote them back to full datetime.datetime() diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 129125ca73..6c89377992 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -277,7 +277,7 @@ RESERVED_WORDS = set( class _MSNumeric(sqltypes.Numeric): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: def process(value): if value is not None: @@ -350,7 +350,7 @@ class _MSDate(sqltypes.Date): return process _reg = re.compile(r"(\d+)-(\d+)-(\d+)") - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if isinstance(value, datetime.datetime): return value.date() @@ -377,7 +377,7 @@ class TIME(sqltypes.TIME): return process _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if isinstance(value, datetime.datetime): return value.time() @@ -599,7 +599,7 @@ class BIT(sqltypes.TypeEngine): __visit_name__ = 'BIT' class _MSBoolean(sqltypes.Boolean): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 36a0425898..a9acc2a013 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -546,7 +546,7 @@ class BIT(sqltypes.TypeEngine): """ self.length = length - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): """Convert a MySQL's 64 bit, variable length binary string to a long.""" def process(value): if value is not None: @@ -562,7 +562,7 @@ class _MSTime(sqltypes.Time): __visit_name__ = 'TIME' - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): time = datetime.time def process(value): # convert from a timedelta value @@ -1042,7 +1042,7 @@ class SET(_StringType): length = max([len(v) for v in strip_values] + [0]) super(SET, self).__init__(length=length, **kw) - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): # The good news: # No ',' quoting issues- commas aren't allowed in SET values @@ -1085,7 +1085,7 @@ class _MSBoolean(sqltypes.Boolean): __visit_name__ = 'BOOLEAN' - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 49fa044a3e..846de6580c 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -48,7 +48,7 @@ class MySQL_mysqldbCompiler(MySQLCompiler): class _DecimalType(_NumericType): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: return def process(value): diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 37537483d7..5558c1a192 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -36,7 +36,7 @@ class _PlainQuery(unicode): class _oursqlNumeric(NUMERIC): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: return def process(value): @@ -48,7 +48,7 @@ class _oursqlNumeric(NUMERIC): class _oursqlBIT(BIT): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): """oursql already converts mysql bits, so.""" return None diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py index bf1267aed1..dcb46789ab 100644 --- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -23,7 +23,7 @@ from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector from sqlalchemy.dialects.mysql.base import BIT, MySQLDialect, MySQLExecutionContext class _ZxJDBCBit(BIT): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): """Converts boolean or byte arrays from MySQL Connector/J to longs.""" def process(value): if value is None: diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 22ba2ce934..d13e37d60a 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -159,7 +159,7 @@ class _OracleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMBER - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index eb5f2cb43b..2db37a4fc6 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -84,7 +84,7 @@ class _OracleDate(sqltypes.Date): def bind_processor(self, dialect): return None - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if not isinstance(value, datetime): return value @@ -93,7 +93,7 @@ class _OracleDate(sqltypes.Date): return process class _OracleDateTime(sqltypes.DateTime): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None or isinstance(value, datetime): return value @@ -110,7 +110,7 @@ class _OracleDateTime(sqltypes.DateTime): # only if cx_oracle contains TIMESTAMP class _OracleTimestamp(sqltypes.TIMESTAMP): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None or isinstance(value, datetime): return value @@ -121,13 +121,13 @@ class _OracleTimestamp(sqltypes.TIMESTAMP): return process class _LOBMixin(object): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if not dialect.auto_convert_lobs: # return the cx_oracle.LOB directly. # don't even call super.result_processor here. return None - super_process = super(_LOBMixin, self).result_processor(dialect) + super_process = super(_LOBMixin, self).result_processor(dialect, coltype) lob = dialect.dbapi.LOB if super_process: def process(value): @@ -148,11 +148,11 @@ class _OracleChar(sqltypes.CHAR): return dbapi.FIXED_CHAR class _OracleNVarChar(sqltypes.NVARCHAR): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if dialect._cx_oracle_native_nvarchar: return None else: - return sqltypes.NVARCHAR.result_processor(self, dialect) + return sqltypes.NVARCHAR.result_processor(self, dialect, coltype) class _OracleText(_LOBMixin, sqltypes.Text): def get_dbapi_type(self, dbapi): @@ -163,7 +163,7 @@ class _OracleUnicodeText(_LOBMixin, sqltypes.UnicodeText): return dbapi.NCLOB class _OracleInteger(sqltypes.Integer): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def to_int(val): if val is not None: val = int(val) @@ -259,11 +259,14 @@ class Oracle_cx_oracleExecutionContext(OracleExecutionContext): for bind, name in self.compiled.bind_names.iteritems(): if name in self.out_parameters: type = bind.type - result_processor = type.dialect_impl(self.dialect).\ - result_processor(self.dialect) + impl_type = type.dialect_impl(self.dialect) + dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi) + result_processor = impl_type.\ + result_processor(self.dialect, + dbapi_type) if result_processor is not None: out_parameters[name] = \ - result_processor(self.out_parameters[name].getvalue()) + result_processor(self.out_parameters[name].getvalue(), dbapi_type) else: out_parameters[name] = self.out_parameters[name].getvalue() else: diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index 6edef301c6..42c43d369a 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -20,7 +20,7 @@ SQLException = zxJDBC = None class _ZxJDBCDate(sqltypes.Date): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None @@ -31,7 +31,7 @@ class _ZxJDBCDate(sqltypes.Date): class _ZxJDBCNumeric(sqltypes.Numeric): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: def process(value): if isinstance(value, decimal.Decimal): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 97108b3cbb..9c6de36235 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -174,8 +174,8 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): return [convert_item(item) for item in value] return process - def result_processor(self, dialect): - item_proc = self.item_type.result_processor(dialect) + def result_processor(self, dialect, coltype): + item_proc = self.item_type.result_processor(dialect, coltype) def process(value): if value is None: return value diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 17fe86be66..e90bebb6b2 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -21,25 +21,47 @@ Passing data from/to the Interval type is not supported as of yet. """ from sqlalchemy.engine import default import decimal -from sqlalchemy import util +from sqlalchemy import util, exc from sqlalchemy import types as sqltypes -from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext +from sqlalchemy.dialects.postgresql.base import PGDialect, \ + PGCompiler, PGIdentifierPreparer, PGExecutionContext class _PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): - return None - - def result_processor(self, dialect): + def process(value): + if value is not None: + return float(value) + else: + return value + return process + + def result_processor(self, dialect, coltype): if self.asdecimal: - return None + if coltype in (700, 701): + def process(value): + if value is not None: + return decimal.Decimal(str(value)) + else: + return value + return process + elif coltype == 1700: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - + if coltype in (700, 701): + # pg8000 returns float natively for 701 + return None + elif coltype == 1700: + def process(value): + if value is not None: + return float(value) + else: + return value + return process + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) class PostgreSQL_pg8000ExecutionContext(PGExecutionContext): pass @@ -79,7 +101,6 @@ class PostgreSQL_pg8000(PGDialect): PGDialect.colspecs, { sqltypes.Numeric : _PGNumeric, - sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used } ) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index aa4e07bb3b..a46fdbddbf 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -51,17 +51,33 @@ class _PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): return None - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: - return None + if coltype in (700, 701): + def process(value): + if value is not None: + return decimal.Decimal(str(value)) + else: + return value + return process + elif coltype == 1700: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) else: - def process(value): - if isinstance(value, decimal.Decimal): - return float(value) - else: - return value - return process - + if coltype in (700, 701): + # pg8000 returns float natively for 701 + return None + elif coltype == 1700: + def process(value): + if value is not None: + return float(value) + else: + return value + return process + else: + raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) class _PGEnum(ENUM): def __init__(self, *arg, **kw): @@ -139,7 +155,6 @@ class PostgreSQL_psycopg2(PGDialect): PGDialect.colspecs, { sqltypes.Numeric : _PGNumeric, - sqltypes.Float: sqltypes.Float, # prevents _PGNumeric from being used ENUM : _PGEnum, # needs force_unicode sqltypes.Enum : _PGEnum, # needs force_unicode ARRAY : _PGArray, # needs force_unicode diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index 517d41aaf8..2c33b3eb5b 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -17,7 +17,7 @@ class PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): return None - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: return None else: diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 47c797c211..33feaeaaeb 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -96,7 +96,7 @@ class _SLDateTime(_DateTimeMixin, sqltypes.DateTime): ) _reg = re.compile(r"(\d+)-(\d+)-(\d+)(?: (\d+):(\d+):(\d+)(?:\.(\d+))?)?") - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): return self._result_processor(datetime.datetime, self._reg) class _SLDate(_DateTimeMixin, sqltypes.Date): @@ -107,7 +107,7 @@ class _SLDate(_DateTimeMixin, sqltypes.Date): ) _reg = re.compile(r"(\d+)-(\d+)-(\d+)") - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): return self._result_processor(datetime.date, self._reg) class _SLTime(_DateTimeMixin, sqltypes.Time): @@ -126,7 +126,7 @@ class _SLTime(_DateTimeMixin, sqltypes.Time): ) _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): return self._result_processor(datetime.time, self._reg) @@ -138,7 +138,7 @@ class _SLBoolean(sqltypes.Boolean): return value and 1 or 0 return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 6fc42c312e..cfdbd321ad 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -113,7 +113,7 @@ class SybaseUniqueIdentifier(sqltypes.TypeEngine): __visit_name__ = "UNIQUEIDENTIFIER" class SybaseBoolean(sqltypes.Boolean): - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): if value is None: return None diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b65843f48d..3ea52cd725 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1757,7 +1757,6 @@ class ResultProxy(object): typemap = self.dialect.dbapi_type_map for i, (colname, coltype) in enumerate(m[0:2] for m in metadata): - if self.dialect.description_encoding: colname = colname.decode(self.dialect.description_encoding) @@ -1779,7 +1778,7 @@ class ResultProxy(object): name, obj, type_ = (colname, None, typemap.get(coltype, types.NULLTYPE)) processor = type_.dialect_impl(self.dialect).\ - result_processor(self.dialect) + result_processor(self.dialect, coltype) if processor: def make_colfunc(processor, index): diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index a215f31ff0..3fa18b2c26 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -47,12 +47,22 @@ class AbstractType(Visitable): return value def bind_processor(self, dialect): - """Defines a bind parameter processing function.""" + """Defines a bind parameter processing function. + + :param dialect: Dialect instance in use. + + """ return None - def result_processor(self, dialect): - """Defines a result-column processing function.""" + def result_processor(self, dialect, coltype): + """Defines a result-column processing function. + + :param dialect: Dialect instance in use. + + :param coltype: DBAPI coltype argument received in cursor.description. + + """ return None @@ -126,7 +136,7 @@ class TypeEngine(AbstractType): """ return None - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): """Return a conversion function for processing result row values. Returns a callable which will receive a result row column @@ -162,7 +172,7 @@ class UserDefinedType(TypeEngine): return value return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): return value return process @@ -300,10 +310,10 @@ class TypeDecorator(AbstractType): else: return self.impl.bind_processor(dialect) - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.__class__.process_result_value.func_code is not TypeDecorator.process_result_value.func_code: process_value = self.process_result_value - impl_processor = self.impl.result_processor(dialect) + impl_processor = self.impl.result_processor(dialect, coltype) if impl_processor: def process(value): return process_value(impl_processor(value), dialect) @@ -312,7 +322,7 @@ class TypeDecorator(AbstractType): return process_value(value, dialect) return process else: - return self.impl.result_processor(dialect) + return self.impl.result_processor(dialect, coltype) def copy(self): instance = self.__class__.__new__(self.__class__) @@ -511,7 +521,7 @@ class String(Concatenable, TypeEngine): else: return None - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if (not dialect.returns_unicode_strings or self.convert_unicode == 'force') \ and (self.convert_unicode or dialect.convert_unicode): def process(value): @@ -666,7 +676,10 @@ class Numeric(TypeEngine): self.asdecimal = asdecimal def adapt(self, impltype): - return impltype(precision=self.precision, scale=self.scale, asdecimal=self.asdecimal) + return impltype( + precision=self.precision, + scale=self.scale, + asdecimal=self.asdecimal) def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -679,7 +692,7 @@ class Numeric(TypeEngine): return value return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if self.asdecimal: def process(value): if value is not None: @@ -790,7 +803,7 @@ class Binary(TypeEngine): return None return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): if util.jython: def process(value): if value is not None: @@ -1041,8 +1054,8 @@ class PickleType(MutableType, TypeDecorator): return value return process - def result_processor(self, dialect): - impl_processor = self.impl.result_processor(dialect) + def result_processor(self, dialect, coltype): + impl_processor = self.impl.result_processor(dialect, coltype) loads = self.pickler.loads if impl_processor: def process(value): @@ -1111,8 +1124,8 @@ class Interval(TypeDecorator): return value return process - def result_processor(self, dialect): - impl_processor = self.impl.result_processor(dialect) + def result_processor(self, dialect, coltype): + impl_processor = self.impl.result_processor(dialect, coltype) epoch = self.epoch if impl_processor: def process(value): diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index aa2b99275a..152ca40dab 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -2,6 +2,7 @@ from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message from sqlalchemy.test import engines import datetime +import decimal from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy import exc, schema, types @@ -10,6 +11,7 @@ from sqlalchemy.engine.strategies import MockEngineStrategy from sqlalchemy.test import * from sqlalchemy.sql import table, column from sqlalchemy.test.testing import eq_ +from test.engine._base import TablesTest class SequenceTest(TestBase, AssertsCompiledSQL): def test_basic(self): @@ -105,6 +107,65 @@ class CompileTest(TestBase, AssertsCompiledSQL): "SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 " "FROM t" % field) +class FloatCoercionTest(TablesTest, AssertsExecutionResults): + __only_on__ = 'postgresql' + __dialect__ = postgresql.dialect() + + @classmethod + def define_tables(cls, metadata): + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer) + ) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + data_table.insert().execute( + {'data':3}, + {'data':5}, + {'data':7}, + {'data':2}, + {'data':15}, + {'data':12}, + {'data':6}, + {'data':478}, + {'data':52}, + {'data':9}, + ) + + def _round(self, x): + if isinstance(x, float): + return round(x, 9) + elif isinstance(x, decimal.Decimal): + # really ? + x = x.shift(decimal.Decimal(9)).to_integral() / pow(10, 9) + return x + @testing.resolve_artifact_names + def test_float_coercion(self): + for type_, result in [ + (Numeric, decimal.Decimal('140.381230939')), + (Float, 140.381230939), + (Float(asdecimal=True), decimal.Decimal('140.381230939')), + (Numeric(asdecimal=False), 140.381230939), + ]: + ret = testing.db.execute( + select([ + func.stddev_pop(data_table.c.data, type_=type_) + ]) + ).scalar() + + eq_(self._round(ret), result) + + ret = testing.db.execute( + select([ + cast(func.stddev_pop(data_table.c.data), type_) + ]) + ).scalar() + eq_(self._round(ret), result) + + + class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): __only_on__ = 'postgresql' __dialect__ = postgresql.dialect() diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 040397f4c3..6c6ad65e0b 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -43,7 +43,7 @@ class TestTypes(TestBase, AssertsExecutionResults): bp = sldt.bind_processor(None) eq_(bp(dt), '2008-06-27 12:00:00.000125') - rp = sldt.result_processor(None) + rp = sldt.result_processor(None, None) eq_(rp(bp(dt)), dt) sldt.__legacy_microseconds__ = True diff --git a/test/sql/test_types.py b/test/sql/test_types.py index c0b86c1e43..a3cb03022a 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -100,7 +100,7 @@ class UserDefinedTest(TestBase): def process(value): return "BIND_IN"+ value return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): return value + "BIND_OUT" return process @@ -114,8 +114,8 @@ class UserDefinedTest(TestBase): def process(value): return "BIND_IN"+ impl_processor(value) return process - def result_processor(self, dialect): - impl_processor = super(MyDecoratedType, self).result_processor(dialect) or (lambda value:value) + def result_processor(self, dialect, coltype): + impl_processor = super(MyDecoratedType, self).result_processor(dialect, coltype) or (lambda value:value) def process(value): return impl_processor(value) + "BIND_OUT" return process @@ -163,8 +163,8 @@ class UserDefinedTest(TestBase): return "BIND_IN"+ impl_processor(value) return process - def result_processor(self, dialect): - impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value) + def result_processor(self, dialect, coltype): + impl_processor = super(MyUnicodeType, self).result_processor(dialect, coltype) or (lambda value:value) def process(value): return impl_processor(value) + "BIND_OUT" return process @@ -528,7 +528,7 @@ class ExpressionTest(TestBase, AssertsExecutionResults): def process(value): return value * 10 return process - def result_processor(self, dialect): + def result_processor(self, dialect, coltype): def process(value): return value / 10 return process