]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the "scale" argument of the Numeric() type is honored when
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Feb 2010 17:50:34 +0000 (17:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Feb 2010 17:50:34 +0000 (17:50 +0000)
coercing a returned floating point value into a string
on its way to Decimal - this allows accuracy to function
on SQLite, MySQL.  [ticket:1717]

CHANGES
lib/sqlalchemy/processors.py
lib/sqlalchemy/test/util.py
lib/sqlalchemy/types.py
test/dialect/test_postgresql.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 5e0d3662dc8cf3206c296a29168b47c18c4b1899..05d53ffa5f2714110089cfcac730ef9b25284299 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -171,6 +171,11 @@ CHANGES
     not new).  An error is now raised if a Column() has no type
     and no foreign keys.  [ticket:1705]
     
+  - the "scale" argument of the Numeric() type is honored when 
+    coercing a returned floating point value into a string 
+    on its way to Decimal - this allows accuracy to function
+    on SQLite, MySQL.  [ticket:1717]
+    
 - engines
   - Added an optional C extension to speed up the sql layer by
     reimplementing RowProxy and the most common result processors.
index 4cf6831bd557b2e1eea464bdc577747eba1979a2..04fa5054abee27605e522da47b01bb37428c18a4 100644 (file)
@@ -38,9 +38,10 @@ try:
             return UnicodeResultProcessor(encoding, errors).process
         else:
             return UnicodeResultProcessor(encoding).process
-
-    def to_decimal_processor_factory(target_class):
-        return DecimalResultProcessor(target_class).process
+    
+    # TODO: add scale argument
+    #def to_decimal_processor_factory(target_class):
+    #    return DecimalResultProcessor(target_class).process
 
 except ImportError:
     def to_unicode_processor_factory(encoding, errors=None):
@@ -57,13 +58,14 @@ except ImportError:
                 return decoder(value, errors)[0]
         return process
 
-    def to_decimal_processor_factory(target_class):
-        def process(value):
-            if value is None:
-                return None
-            else:
-                return target_class(str(value))
-        return process
+    # TODO: add scale argument
+    #def to_decimal_processor_factory(target_class):
+    #    def process(value):
+    #        if value is None:
+    #            return None
+    #        else:
+    #            return target_class(str(value))
+    #    return process
 
     def to_float(value):
         if value is None:
@@ -92,3 +94,13 @@ except ImportError:
     str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time)
     str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date)
 
+
+def to_decimal_processor_factory(target_class, scale=10):
+    fstring = "%%.%df" % scale
+    
+    def process(value):
+        if value is None:
+            return None
+        else:
+            return target_class(fstring % value)
+    return process
index 5be00f9068e8932d7c8de171551b92926289c77e..8a3a0e7452a25c18ac3d294c00acb44bd9728fb4 100644 (file)
@@ -39,4 +39,15 @@ def picklers():
     for pickle in picklers:
         for protocol in -1, 0, 1, 2:
             yield pickle.loads, lambda d:pickle.dumps(d, protocol)
+    
+    
+def round_decimal(value, prec):
+    if isinstance(value, float):
+        return round(value, prec)
+    
+    import decimal
+
+    # can also use shift() here but that is 2.6 only
+    return (value * decimal.Decimal("1" + "0" * prec)).to_integral(decimal.ROUND_FLOOR) / \
+                        pow(10, prec)
     
\ No newline at end of file
index d5f1d9f1455438ef2930c4ab3c9f05e298c74ab9..d7b8f9289106ce865e56d6f364fee9a23f08decd 100644 (file)
@@ -838,7 +838,7 @@ class Numeric(_DateAffinity, TypeEngine):
 #            try:
 #                from fastdec import mpd as Decimal
 #            except ImportError:
-            return processors.to_decimal_processor_factory(_python_Decimal)
+            return processors.to_decimal_processor_factory(_python_Decimal, self.scale)
         else:
             return None
 
@@ -877,6 +877,16 @@ class Float(Numeric):
     def adapt(self, impltype):
         return impltype(precision=self.precision, asdecimal=self.asdecimal)
 
+    def result_processor(self, dialect, coltype):
+        if self.asdecimal:
+            #XXX: use decimal from http://www.bytereef.org/libmpdec.html
+#            try:
+#                from fastdec import mpd as Decimal
+#            except ImportError:
+            return processors.to_decimal_processor_factory(_python_Decimal)
+        else:
+            return None
+
 
 class DateTime(_DateAffinity, TypeEngine):
     """A type for ``datetime.datetime()`` objects.
index 1a21ec11f17f7b7fa42a138ee70826c530ccbacc..b002e7f19fbbf9abf103641f25500d1f5432bc19 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import exc, schema, types
 from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.engine.strategies import MockEngineStrategy
 from sqlalchemy.test import *
+from sqlalchemy.test.util import round_decimal
 from sqlalchemy.sql import table, column
 from sqlalchemy.test.testing import eq_
 from test.engine._base import TablesTest
@@ -203,15 +204,6 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults):
             {'data':9},
         )
     
-    def _round(self, x):
-        if isinstance(x, float):
-            return round(x, 9)
-        elif isinstance(x, decimal.Decimal):
-            # really ?
-            # (can also use shift() here but that is 2.6 only)
-            x = (x * decimal.Decimal("1000000000")).to_integral() / pow(10, 9)
-        return x
-
     @testing.resolve_artifact_names
     def test_float_coercion(self):
         for type_, result in [
@@ -226,14 +218,14 @@ class FloatCoercionTest(TablesTest, AssertsExecutionResults):
                 ])
             ).scalar()
 
-            eq_(self._round(ret), result)
+            eq_(round_decimal(ret, 9), result)
 
             ret = testing.db.execute(
                 select([
                     cast(func.stddev_pop(data_table.c.data), type_)
                 ])
             ).scalar()
-            eq_(self._round(ret), result)
+            eq_(round_decimal(ret, 9), result)
     
     
         
index 4b2370afcbf45c6157c578e85dadeedfdefdf208..53f4d8d919e85f729e0df796825ae300ffea4ea6 100644 (file)
@@ -1024,6 +1024,34 @@ class NumericTest(TestBase, AssertsExecutionResults):
             (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")),
         ])
 
+    @testing.fails_if(_missing_decimal)
+    def test_precision_decimal(self):
+        from decimal import Decimal
+        from sqlalchemy.test.util import round_decimal
+            
+        t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12)))
+        t.create(testing.db)
+        try:
+            numbers = set(
+            [
+                decimal.Decimal("54.234246451650"),
+                decimal.Decimal("876734.594069654000"),
+                decimal.Decimal("0.004354"), 
+                decimal.Decimal("900.0"), 
+            ])
+
+            testing.db.execute(t.insert(), [{'x':x} for x in numbers])
+
+            ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()])
+            
+            numbers = set(round_decimal(n, 11) for n in numbers)
+            ret = set(round_decimal(n, 11) for n in ret)
+            
+            eq_(numbers, ret)
+        finally:
+            t.drop(testing.db)
+            
+
     def test_decimal_fallback(self):
         from decimal import Decimal