]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved most Decimal bind/result handling into types.py, out of sqlite, mysql dialects.
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Mar 2010 15:48:24 +0000 (11:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Mar 2010 15:48:24 +0000 (11:48 -0400)
- added an explicit test for [ticket:1216]
- some questions remain about MSSQL - would like to simplify/remove bind handling for numerics

lib/sqlalchemy/connectors/mxodbc.py
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/oursql.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/types.py
test/sql/test_types.py

index 68b88019c248ab221acc8b3d314d96387fcafbaf..de9b18f1a464053058b328d071884b3355d83f45 100644 (file)
@@ -11,7 +11,8 @@ class MxODBCConnector(Connector):
     supports_sane_multi_rowcount = False
     supports_unicode_statements = False
     supports_unicode_binds = False
-    supports_native_decimal = False
+    
+    supports_native_decimal = True
     
     @classmethod
     def dbapi(cls):
index de13a2dce11c2970928ec85111829efa8af5f59d..a2da132dafb8915881dddb8a6ed66b97aa1bfaf4 100644 (file)
@@ -69,7 +69,7 @@ the SQLAlchemy ``returning()`` method, such as::
 
 """
 
-import datetime, decimal, re
+import datetime, re
 
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import exc, types as sqltypes, sql, util
index bd5af4e8e4acfcb15e08aa1dc081b2a9bb6bd345..070650d97f4bf2541813e06eb34402265cef66af 100644 (file)
@@ -277,27 +277,13 @@ RESERVED_WORDS = set(
 
 
 class _MSNumeric(sqltypes.Numeric):
-    def result_processor(self, dialect, coltype):
-        if self.asdecimal:
-            # TODO: factor this down into the sqltypes.Numeric class,
-            # use dialect flags
-            if getattr(self, 'scale', None) is None:
-                # we're a "float".  return a default decimal factory
-                return processors.to_decimal_processor_factory(decimal.Decimal)
-            elif dialect.supports_native_decimal:
-                # we're a "numeric", DBAPI will give us Decimal directly
-                return None
-            else:
-                # we're a "numeric", DBAPI returns floats, convert.
-                return processors.to_decimal_processor_factory(decimal.Decimal, self.scale)
-        else:
-            #XXX: if the DBAPI returns a float (this is likely, given the
-            # processor when asdecimal is True), this should be a None
-            # processor instead.
-            return processors.to_float
-            
+    
     def bind_processor(self, dialect):
         def process(value):
+            # TODO: this seems exceedingly complex. 
+            # need to know exactly what tests cover this, so far
+            # test_types.NumericTest.test_enotation_decimal
+            
             if isinstance(value, decimal.Decimal):
                 if value.adjusted() < 0:
                     result = "%s0.%s%s" % (
index 33c18f6d62946d0617c0320159366500dc90f9da..981e1e204b535cd410d75f5b3bcc0e5c42b76579 100644 (file)
@@ -8,7 +8,7 @@ import re
 
 from sqlalchemy.dialects.mysql.base import (MySQLDialect,
     MySQLExecutionContext, MySQLCompiler, MySQLIdentifierPreparer,
-    BIT, NUMERIC, _NumericType)
+    BIT)
 
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
@@ -28,15 +28,6 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler):
     def post_process_text(self, text):
         return text.replace('%', '%%')
 
-class _DecimalType(_NumericType):
-    def result_processor(self, dialect, coltype):
-        if self.asdecimal:
-            return None
-        return processors.to_float
-
-class _myconnpyNumeric(_DecimalType, NUMERIC):
-    pass
-
 class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
 
     def _escape_identifier(self, value):
@@ -56,6 +47,8 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
 
+    supports_native_decimal = True
+
     default_paramstyle = 'format'
     execution_ctx_cls = MySQLExecutionContext_mysqlconnector
     statement_compiler = MySQLCompiler_mysqlconnector
@@ -65,7 +58,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
     colspecs = util.update_copy(
         MySQLDialect.colspecs,
         {
-            sqltypes.Numeric: _myconnpyNumeric,
             BIT: _myconnpyBIT,
         }
     )
index 038e58a4c30d6c926b386cd63dc18856be1cf553..9d34939a1ff86193990436fc6f5b9ba13f5fd2fe 100644 (file)
@@ -20,11 +20,10 @@ strings, also pass ``use_unicode=0`` in the connection arguments::
   create_engine('mysql:///mydb?charset=utf8&use_unicode=0')
 """
 
-import decimal
 import re
 
-from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutionContext,
-                                            MySQLCompiler, MySQLIdentifierPreparer, NUMERIC, _NumericType)
+from sqlalchemy.dialects.mysql.base import (MySQLDialect, MySQLExecutionContext,
+                                            MySQLCompiler, MySQLIdentifierPreparer)
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
@@ -48,20 +47,6 @@ class MySQLCompiler_mysqldb(MySQLCompiler):
         return text.replace('%', '%%')
 
 
-class _DecimalType(_NumericType):
-    def result_processor(self, dialect, coltype):
-        if self.asdecimal:
-            return None
-        return processors.to_float
-
-
-class _MySQLdbNumeric(_DecimalType, NUMERIC):
-    pass
-
-
-class _MySQLdbDecimal(_DecimalType, DECIMAL):
-    pass
-
 class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer):
     
     def _escape_identifier(self, value):
@@ -74,6 +59,8 @@ class MySQLDialect_mysqldb(MySQLDialect):
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
 
+    supports_native_decimal = True
+
     default_paramstyle = 'format'
     execution_ctx_cls = MySQLExecutionContext_mysqldb
     statement_compiler = MySQLCompiler_mysqldb
@@ -82,8 +69,6 @@ class MySQLDialect_mysqldb(MySQLDialect):
     colspecs = util.update_copy(
         MySQLDialect.colspecs,
         {
-            sqltypes.Numeric: _MySQLdbNumeric,
-            DECIMAL: _MySQLdbDecimal
         }
     )
     
index 605b39760f475bd8577390844ca3bb4b8eebee16..f26bc4da2f82265d4785c9b574945d02e13c48b1 100644 (file)
@@ -21,23 +21,16 @@ defaults to, there is a separate parameter::
   create_engine('mysql+oursql:///mydb?charset=latin1')
 """
 
-import decimal
 import re
 
 from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionContext,
-                                            MySQLCompiler, MySQLIdentifierPreparer, NUMERIC, _NumericType)
+                                            MySQLCompiler, MySQLIdentifierPreparer)
 from sqlalchemy.engine import base as engine_base, default
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import exc, log, schema, sql, types as sqltypes, util
 from sqlalchemy import processors
 
 
-class _oursqlNumeric(NUMERIC):
-    def result_processor(self, dialect, coltype):
-        if self.asdecimal:
-            return None
-        return processors.to_float
-
 
 class _oursqlBIT(BIT):
     def result_processor(self, dialect, coltype):
@@ -60,7 +53,9 @@ class MySQLDialect_oursql(MySQLDialect):
     supports_unicode_binds = True
     supports_unicode_statements = True
 # end Py2K
-
+    
+    supports_native_decimal = True
+    
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
     execution_ctx_cls = MySQLExecutionContext_oursql
@@ -69,7 +64,6 @@ class MySQLDialect_oursql(MySQLDialect):
         MySQLDialect.colspecs,
         {
             sqltypes.Time: sqltypes.Time,
-            sqltypes.Numeric: _oursqlNumeric,
             BIT: _oursqlBIT,
         }
     )
index 98df8d0cb4b6d2236f16c3375649fa3cd1deda2a..d7637e71b45c6f2fc2ff3291a4cc5af4dcf4a360 100644 (file)
@@ -61,19 +61,6 @@ from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\
                             TIMESTAMP, VARCHAR
                             
 
-class _NumericMixin(object):
-    def bind_processor(self, dialect):
-        if self.asdecimal:
-            return processors.to_str
-        else:
-            return processors.to_float
-
-class _SLNumeric(_NumericMixin, sqltypes.Numeric):
-    pass
-
-class _SLFloat(_NumericMixin, sqltypes.Float):
-    pass
-
 class _DateTimeMixin(object):
     _reg = None
     _storage_format = None
@@ -163,8 +150,6 @@ class TIME(_DateTimeMixin, sqltypes.Time):
 colspecs = {
     sqltypes.Date: DATE,
     sqltypes.DateTime: DATETIME,
-    sqltypes.Float: _SLFloat,
-    sqltypes.Numeric: _SLNumeric,
     sqltypes.Time: TIME,
 }
 
index 53f32fb2e9f448d81fa9296adfacebb74ef7e3ad..3feac8f4fc9bc77e181ef310e5230b7bd1b28cdc 100644 (file)
@@ -928,20 +928,24 @@ class Numeric(_DateAffinity, TypeEngine):
         return dbapi.NUMBER
 
     def bind_processor(self, dialect):
-        return processors.to_float
+        if dialect.supports_native_decimal:
+            return None
+        else:
+            return processors.to_float
 
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            #XXX: use decimal from http://www.bytereef.org/libmpdec.html
-#            try:
-#                from fastdec import mpd as Decimal
-#            except ImportError:
-            if self.scale is not None:
-                return processors.to_decimal_processor_factory(_python_Decimal, self.scale)
+            if dialect.supports_native_decimal:
+                # we're a "numeric", DBAPI will give us Decimal directly
+                return None
             else:
-                return processors.to_decimal_processor_factory(_python_Decimal)
+                # we're a "numeric", DBAPI returns floats, convert.
+                return processors.to_decimal_processor_factory(_python_Decimal, self.scale)
         else:
-            return None
+            if dialect.supports_native_decimal:
+                return processors.to_float
+            else:
+                return None
 
     @util.memoized_property
     def _expression_adaptations(self):
@@ -980,10 +984,6 @@ class Float(Numeric):
 
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
-            #XXX: use decimal from http://www.bytereef.org/libmpdec.html
-#            try:
-#                from fastdec import mpd as Decimal
-#            except ImportError:
             return processors.to_decimal_processor_factory(_python_Decimal)
         else:
             return None
index 6404783a5a97f3cec38bcf6602867b2a066fdcb7..507ca5a091b6447b7080fa20d59f0de1a9f84aa0 100644 (file)
@@ -11,6 +11,8 @@ from sqlalchemy.databases import *
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.test import *
 from sqlalchemy.test.util import picklers
+from decimal import Decimal
+from sqlalchemy.test.util import round_decimal
 
 
 class AdaptTest(TestBase):
@@ -1084,14 +1086,6 @@ class StringTest(TestBase, AssertsExecutionResults):
         foo.create()
         foo.drop()
 
-def _missing_decimal():
-    """Python implementation supports decimals"""
-    try:
-        import decimal
-        return False
-    except ImportError:
-        return True
-
 class NumericTest(TestBase, AssertsExecutionResults):
     @classmethod
     def setup_class(cls):
@@ -1114,7 +1108,6 @@ class NumericTest(TestBase, AssertsExecutionResults):
     def teardown(self):
         numeric_table.delete().execute()
 
-    @testing.fails_if(_missing_decimal)
     def test_decimal(self):
         from decimal import Decimal
         numeric_table.insert().execute(
@@ -1134,10 +1127,7 @@ class NumericTest(TestBase, AssertsExecutionResults):
             (2, 3.5, 5.6, Decimal("12.4"), Decimal("15.75")),
         ])
 
-    @testing.fails_if(_missing_decimal)
     def test_precision_decimal(self):
-        from decimal import Decimal
-        from sqlalchemy.test.util import round_decimal
             
         t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12)))
         t.create(testing.db)
@@ -1154,13 +1144,50 @@ class NumericTest(TestBase, AssertsExecutionResults):
 
             ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()])
             
-            numbers = set(round_decimal(n, 11) for n in numbers)
-            ret = set(round_decimal(n, 11) for n in ret)
+            if testing.against('sqlite'):
+                numbers = set(round_decimal(n, 11) for n in numbers)
+                ret = set(round_decimal(n, 11) for n in ret)
+            else:
+                numbers = set(n for n in numbers)
+                ret = set(n for n in ret)
             
             eq_(numbers, ret)
         finally:
             t.drop(testing.db)
+
+    @testing.fails_on('sybase', "Driver doesn't appear to handle E notation, won't accept strings")
+    def test_enotation_decimal(self):
+        """test exceedingly small decimals.
+        
+        Decimal reports values with E notation when the exponent 
+        is greater than 6.
+        
+        """
+
+        t = Table('t', MetaData(), Column('x', Numeric(precision=18, scale=12)))
+        t.create(testing.db)
+        try:
+            numbers = set([
+                decimal.Decimal('1E-2'),
+                decimal.Decimal('1E-3'),
+                decimal.Decimal('1E-4'),
+                decimal.Decimal('1E-5'),
+                decimal.Decimal('1E-6'),
+                decimal.Decimal('1E-7'),
+                decimal.Decimal('1E-8'),
+            ])
+
+            testing.db.execute(t.insert(), [{'x':x} for x in numbers])
+
+            ret = set([row[0] for row in testing.db.execute(t.select()).fetchall()])
             
+            numbers = set(n for n in numbers)
+            ret = set(n for n in ret)
+            
+            eq_(numbers, ret)
+        finally:
+            t.drop(testing.db)
+        
 
     def test_decimal_fallback(self):
         from decimal import Decimal