]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Numeric and Float types now have an "asdecimal" flag; defaults to
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jul 2007 16:36:14 +0000 (16:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jul 2007 16:36:14 +0000 (16:36 +0000)
True for Numeric, False for Float.  when True, values are returned as
decimal.Decimal objects; when False, values are returned as float().
the defaults of True/False are already the behavior for PG and MySQL's
DBAPI modules. [ticket:646]

CHANGES
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/types.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index f47e7be66f7a7866d439784c6a461bafac749f48..60cf22b226060694773202d643956962d7796bb2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
   - MetaData:
     - DynamicMetaData has been renamed to ThreadLocalMetaData
     - BoundMetaData has been removed- regular MetaData is equivalent
+  - Numeric and Float types now have an "asdecimal" flag; defaults to 
+    True for Numeric, False for Float.  when True, values are returned as
+    decimal.Decimal objects; when False, values are returned as float().
+    the defaults of True/False are already the behavior for PG and MySQL's
+    DBAPI modules. [ticket:646]
   - new SQL operator implementation which removes all hardcoded operators
     from expression structures and moves them into compilation; 
     allows greater flexibility of operator compilation; for example, "+" 
   - better quoting of identifiers when manipulating schemas
   - standardized the behavior for table reflection where types can't be located;
     NullType is substituted instead, warning is raised.
+  - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary
+    semantics for "__contains__" [ticket:606]
+    
 - engines
   - Connections gain a .properties collection, with contents scoped to the
     lifetime of the underlying DBAPI connection
-  - ColumnCollection (i.e. the 'c' attribute on tables) follows dictionary
-    semantics for "__contains__" [ticket:606]
 - extensions
   - proxyengine is temporarily removed, pending an actually working
     replacement.
index 6e5616c0bd699b8654ab0a7b428a3d5030fd620f..f8b6e9bd79b92415c7d0d4fcc8ece718efc4b29c 100644 (file)
@@ -12,6 +12,7 @@ import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 import sqlalchemy.util as util
 from array import array as _array
+from decimal import Decimal
 
 try:
     from threading import Lock
@@ -135,7 +136,7 @@ class _StringType(object):
 class MSNumeric(sqltypes.Numeric, _NumericType):
     """MySQL NUMERIC type"""
     
-    def __init__(self, precision = 10, length = 2, **kw):
+    def __init__(self, precision = 10, length = 2, asdecimal=True, **kw):
         """Construct a NUMERIC.
 
         precision
@@ -155,18 +156,27 @@ class MSNumeric(sqltypes.Numeric, _NumericType):
         """
 
         _NumericType.__init__(self, **kw)
-        sqltypes.Numeric.__init__(self, precision, length)
-
+        sqltypes.Numeric.__init__(self, precision, length, asdecimal=asdecimal)
+        
     def get_col_spec(self):
         if self.precision is None:
             return self._extend("NUMERIC")
         else:
             return self._extend("NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
 
+    def convert_bind_param(self, value, dialect):
+        return value
+
+    def convert_result_value(self, value, dialect):
+        if not self.asdecimal and isinstance(value, Decimal):
+            return float(value)
+        else:
+            return value
+
 class MSDecimal(MSNumeric):
     """MySQL DECIMAL type"""
 
-    def __init__(self, precision=10, length=2, **kw):
+    def __init__(self, precision=10, length=2, asdecimal=True, **kw):
         """Construct a DECIMAL.
 
         precision
@@ -185,7 +195,7 @@ class MSDecimal(MSNumeric):
           underlying database API, which continue to be numeric.
         """
 
-        super(MSDecimal, self).__init__(precision, length, **kw)
+        super(MSDecimal, self).__init__(precision, length, asdecimal=asdecimal, **kw)
     
     def get_col_spec(self):
         if self.precision is None:
@@ -198,7 +208,7 @@ class MSDecimal(MSNumeric):
 class MSDouble(MSNumeric):
     """MySQL DOUBLE type"""
 
-    def __init__(self, precision=10, length=2, **kw):
+    def __init__(self, precision=10, length=2, asdecimal=True, **kw):
         """Construct a DOUBLE.
 
         precision
@@ -220,7 +230,7 @@ class MSDouble(MSNumeric):
         if ((precision is None and length is not None) or
             (precision is not None and length is None)):
             raise exceptions.ArgumentError("You must specify both precision and length or omit both altogether.")
-        super(MSDouble, self).__init__(precision, length, **kw)
+        super(MSDouble, self).__init__(precision, length, asdecimal=asdecimal, **kw)
 
     def get_col_spec(self):
         if self.precision is not None and self.length is not None:
@@ -233,7 +243,7 @@ class MSDouble(MSNumeric):
 class MSFloat(sqltypes.Float, _NumericType):
     """MySQL FLOAT type"""
 
-    def __init__(self, precision=10, length=None, **kw):
+    def __init__(self, precision=10, length=None, asdecimal=False, **kw):
         """Construct a FLOAT.
           
         precision
@@ -255,7 +265,7 @@ class MSFloat(sqltypes.Float, _NumericType):
         if length is not None:
             self.length=length
         _NumericType.__init__(self, **kw)
-        sqltypes.Float.__init__(self, precision)
+        sqltypes.Float.__init__(self, precision, asdecimal=asdecimal)
 
     def get_col_spec(self):
         if hasattr(self, 'length') and self.length is not None:
@@ -265,6 +275,10 @@ class MSFloat(sqltypes.Float, _NumericType):
         else:
             return self._extend("FLOAT")
 
+    def convert_bind_param(self, value, dialect):
+        return value
+
+
 class MSInteger(sqltypes.Integer, _NumericType):
     """MySQL INTEGER type"""
 
index d8f467358ffa687ec7f588696c921bc1267d8c83..056101279724ea5dc0208586f0c3c2094b9fe477 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import sql, schema, ansisql, exceptions
 from sqlalchemy.engine import base, default
 import sqlalchemy.types as sqltypes
 from sqlalchemy.databases import information_schema as ischema
+from decimal import Decimal
 
 try:
     import mx.DateTime.DateTime as mxDateTime
@@ -28,6 +29,15 @@ class PGNumeric(sqltypes.Numeric):
         else:
             return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
 
+    def convert_bind_param(self, value, dialect):
+        return value
+
+    def convert_result_value(self, value, dialect):
+        if not self.asdecimal and isinstance(value, Decimal):
+            return float(value)
+        else:
+            return value
+        
 class PGFloat(sqltypes.Float):
     def get_col_spec(self):
         if not self.precision:
@@ -35,6 +45,7 @@ class PGFloat(sqltypes.Float):
         else:
             return "FLOAT(%(precision)s)" % {'precision': self.precision}
 
+
 class PGInteger(sqltypes.Integer):
     def get_col_spec(self):
         return "INTEGER"
index 06720fd661f8200f0b906b2911da03e78b7895c5..4292e9dcc9ba8713d52e041909f0146931c3f440 100644 (file)
@@ -13,6 +13,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
 
 import inspect
 import datetime as dt
+from decimal import Decimal
 try:
     import cPickle as pickle
 except:
@@ -246,22 +247,36 @@ class SmallInteger(Integer):
 Smallinteger = SmallInteger
 
 class Numeric(TypeEngine):
-    def __init__(self, precision = 10, length = 2):
+    def __init__(self, precision = 10, length = 2, asdecimal=True):
         self.precision = precision
         self.length = length
+        self.asdecimal = asdecimal
 
     def adapt(self, impltype):
-        return impltype(precision=self.precision, length=self.length)
+        return impltype(precision=self.precision, length=self.length, asdecimal=self.asdecimal)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
+    def convert_bind_param(self, value, dialect):
+        if value is not None:
+            return float(value)
+        else:
+            return value
+            
+    def convert_result_value(self, value, dialect):
+        if value is not None and self.asdecimal:
+            return Decimal(str(value))
+        else:
+            return value
+
 class Float(Numeric):
-    def __init__(self, precision = 10):
+    def __init__(self, precision = 10, asdecimal=False, **kwargs):
+        super(Float, self).__init__(asdecimal=asdecimal, **kwargs)
         self.precision = precision
 
     def adapt(self, impltype):
-        return impltype(precision=self.precision)
+        return impltype(precision=self.precision, asdecimal=self.asdecimal)
 
 class DateTime(TypeEngine):
     """Implement a type for ``datetime.datetime()`` objects."""
index 8dbeda19af4c6c2c2dc35bb39b9c754165fddaf3..d0ec06caa8998090e6723bd41b3c94410a083789 100644 (file)
@@ -355,6 +355,36 @@ class DateTest(AssertMixin):
         finally:
             t.drop(checkfirst=True)
 
+class NumericTest(AssertMixin):
+    def setUpAll(self):
+        global numeric_table, metadata
+        metadata = MetaData(testbase.db)
+        numeric_table = Table('numeric_table', metadata,
+            Column('id', Integer, Sequence('numeric_id_seq', optional=True), primary_key=True),
+            Column('numericcol', Numeric(asdecimal=False)),
+            Column('floatcol', Float),
+            Column('ncasdec', Numeric),
+            Column('fcasdec', Float(asdecimal=True))
+        )
+        metadata.create_all()
+        
+    def tearDownAll(self):
+        metadata.drop_all()
+        
+    def tearDown(self):
+        numeric_table.delete().execute()
+        
+    def test_decimal(self):
+        from decimal import Decimal
+        numeric_table.insert().execute(numericcol=3.5, floatcol=5.6, ncasdec=12.4, fcasdec=15.78)
+        numeric_table.insert().execute(numericcol=Decimal("3.5"), floatcol=Decimal("5.6"), ncasdec=Decimal("12.4"), fcasdec=Decimal("15.78"))
+        print numeric_table.select().execute().fetchall()
+        assert numeric_table.select().execute().fetchall() == [
+            (1, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
+            (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.78")),
+        ]
+        
+            
 class IntervalTest(AssertMixin):
     def setUpAll(self):
         global interval_table, metadata