From: Mike Bayer Date: Tue, 24 Jul 2007 16:36:14 +0000 (+0000) Subject: - Numeric and Float types now have an "asdecimal" flag; defaults to X-Git-Tag: rel_0_4_6~33 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=782e578a41385b9997cb10e9e88e224e83d1dec0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Numeric and Float types now have an "asdecimal" flag; defaults to True for Numeric, False for Float. when True, values are returned as decimal.Decimal objects; when False, values are returned as float(). the defaults of True/False are already the behavior for PG and MySQL's DBAPI modules. [ticket:646] --- diff --git a/CHANGES b/CHANGES index f47e7be66f..60cf22b226 100644 --- a/CHANGES +++ b/CHANGES @@ -154,6 +154,11 @@ - MetaData: - DynamicMetaData has been renamed to ThreadLocalMetaData - BoundMetaData has been removed- regular MetaData is equivalent + - Numeric and Float types now have an "asdecimal" flag; defaults to + True for Numeric, False for Float. when True, values are returned as + decimal.Decimal objects; when False, values are returned as float(). + the defaults of True/False are already the behavior for PG and MySQL's + DBAPI modules. [ticket:646] - new SQL operator implementation which removes all hardcoded operators from expression structures and moves them into compilation; allows greater flexibility of operator compilation; for example, "+" @@ -191,11 +196,12 @@ - better quoting of identifiers when manipulating schemas - standardized the behavior for table reflection where types can't be located; NullType is substituted instead, warning is raised. + - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary + semantics for "__contains__" [ticket:606] + - engines - Connections gain a .properties collection, with contents scoped to the lifetime of the underlying DBAPI connection - - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary - semantics for "__contains__" [ticket:606] - extensions - proxyengine is temporarily removed, pending an actually working replacement. diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 6e5616c0bd..f8b6e9bd79 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -12,6 +12,7 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import sqlalchemy.util as util from array import array as _array +from decimal import Decimal try: from threading import Lock @@ -135,7 +136,7 @@ class _StringType(object): class MSNumeric(sqltypes.Numeric, _NumericType): """MySQL NUMERIC type""" - def __init__(self, precision = 10, length = 2, **kw): + def __init__(self, precision = 10, length = 2, asdecimal=True, **kw): """Construct a NUMERIC. precision @@ -155,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType): """ _NumericType.__init__(self, **kw) - sqltypes.Numeric.__init__(self, precision, length) - + sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal) + def get_col_spec(self): if self.precision is None: return self._extend("NUMERIC") else: return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}) + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class MSDecimal(MSNumeric): """MySQL DECIMAL type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DECIMAL. precision @@ -185,7 +195,7 @@ class MSDecimal(MSNumeric): underlying database API, which continue to be numeric. """ - super(MSDecimal, self).__init__(precision, length, **kw) + super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is None: @@ -198,7 +208,7 @@ class MSDecimal(MSNumeric): class MSDouble(MSNumeric): """MySQL DOUBLE type""" - def __init__(self, precision=10, length=2, **kw): + def __init__(self, precision=10, length=2, asdecimal=True, **kw): """Construct a DOUBLE. precision @@ -220,7 +230,7 @@ class MSDouble(MSNumeric): if ((precision is None and length is not None) or (precision is not None and length is None)): raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.") - super(MSDouble, self).__init__(precision, length, **kw) + super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw) def get_col_spec(self): if self.precision is not None and self.length is not None: @@ -233,7 +243,7 @@ class MSDouble(MSNumeric): class MSFloat(sqltypes.Float, _NumericType): """MySQL FLOAT type""" - def __init__(self, precision=10, length=None, **kw): + def __init__(self, precision=10, length=None, asdecimal=False, **kw): """Construct a FLOAT. precision @@ -255,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType): if length is not None: self.length=length _NumericType.__init__(self, **kw) - sqltypes.Float.__init__(self, precision) + sqltypes.Float.__init__(self, precision, asdecimal=asdecimal) def get_col_spec(self): if hasattr(self, 'length') and self.length is not None: @@ -265,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType): else: return self._extend("FLOAT") + def convert_bind_param(self, value, dialect): + return value + + class MSInteger(sqltypes.Integer, _NumericType): """MySQL INTEGER type""" diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index d8f467358f..0561012797 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -10,6 +10,7 @@ from sqlalchemy import sql, schema, ansisql, exceptions from sqlalchemy.engine import base, default import sqlalchemy.types as sqltypes from sqlalchemy.databases import information_schema as ischema +from decimal import Decimal try: import mx.DateTime.DateTime as mxDateTime @@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric): else: return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} + def convert_bind_param(self, value, dialect): + return value + + def convert_result_value(self, value, dialect): + if not self.asdecimal and isinstance(value, Decimal): + return float(value) + else: + return value + class PGFloat(sqltypes.Float): def get_col_spec(self): if not self.precision: @@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float): else: return "FLOAT(%(precision)s)" % {'precision': self.precision} + class PGInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 06720fd661..4292e9dcc9 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -13,6 +13,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine', import inspect import datetime as dt +from decimal import Decimal try: import cPickle as pickle except: @@ -246,22 +247,36 @@ class SmallInteger(Integer): Smallinteger = SmallInteger class Numeric(TypeEngine): - def __init__(self, precision = 10, length = 2): + def __init__(self, precision = 10, length = 2, asdecimal=True): self.precision = precision self.length = length + self.asdecimal = asdecimal def adapt(self, impltype): - return impltype(precision=self.precision, length=self.length) + return impltype(precision=self.precision, length=self.length, asdecimal=self.asdecimal) def get_dbapi_type(self, dbapi): return dbapi.NUMBER + def convert_bind_param(self, value, dialect): + if value is not None: + return float(value) + else: + return value + + def convert_result_value(self, value, dialect): + if value is not None and self.asdecimal: + return Decimal(str(value)) + else: + return value + class Float(Numeric): - def __init__(self, precision = 10): + def __init__(self, precision = 10, asdecimal=False, **kwargs): + super(Float, self).__init__(asdecimal=asdecimal, **kwargs) self.precision = precision def adapt(self, impltype): - return impltype(precision=self.precision) + return impltype(precision=self.precision, asdecimal=self.asdecimal) class DateTime(TypeEngine): """Implement a type for ``datetime.datetime()`` objects.""" diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 8dbeda19af..d0ec06caa8 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -355,6 +355,36 @@ class DateTest(AssertMixin): finally: t.drop(checkfirst=True) +class NumericTest(AssertMixin): + def setUpAll(self): + global numeric_table, metadata + metadata = MetaData(testbase.db) + numeric_table = Table('numeric_table', metadata, + Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True), + Column('numericcol', Numeric(asdecimal=False)), + Column('floatcol', Float), + Column('ncasdec', Numeric), + Column('fcasdec', Float(asdecimal=True)) + ) + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + + def tearDown(self): + numeric_table.delete().execute() + + def test_decimal(self): + from decimal import Decimal + numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78) + numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78")) + print numeric_table.select().execute().fetchall() + assert numeric_table.select().execute().fetchall() == [ + (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")), + (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")), + ] + + class IntervalTest(AssertMixin): def setUpAll(self): global interval_table, metadata