]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- get firebird on board
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2010 16:30:22 +0000 (12:30 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2010 16:30:22 +0000 (12:30 -0400)
- a lot of these drivers suck at decimals, not sure what to do

lib/sqlalchemy/dialects/firebird/kinterbasdb.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
test/sql/test_types.py

index 66d001e0c85ec5c710bc663e9c44534e17721900..9984d32a2895df6f72b5272249bd8308e999fd5a 100644 (file)
@@ -28,13 +28,32 @@ __ http://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurr
 """
 
 from sqlalchemy.dialects.firebird.base import FBDialect, FBCompiler
-
-
+from sqlalchemy import util, types as sqltypes
+
+class _FBNumeric_kinterbasdb(sqltypes.Numeric):
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                return str(value)
+            else:
+                return value
+        return process
+        
 class FBDialect_kinterbasdb(FBDialect):
     driver = 'kinterbasdb'
     supports_sane_rowcount = False
     supports_sane_multi_rowcount = False
-
+    
+    supports_native_decimal = True
+    
+    colspecs = util.update_copy(
+        FBDialect.colspecs,
+        {
+            sqltypes.Numeric:_FBNumeric_kinterbasdb
+        }
+        
+    )
+    
     def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
         super(FBDialect_kinterbasdb, self).__init__(**kwargs)
 
index eee2bb1babe36b450fb4a42bf34121679c6ae664..c6e9cea5dcd1f2af04fdd33258a19c4ec7b61753 100644 (file)
@@ -82,6 +82,12 @@ from sqlalchemy import types as sqltypes, util, exc
 from datetime import datetime
 import random
 
+class _OracleNumeric(sqltypes.Numeric):
+    # cx_oracle accepts Decimal objects, but returns
+    # floats
+    def bind_processor(self, dialect):
+        return None
+        
 class _OracleDate(sqltypes.Date):
     def bind_processor(self, dialect):
         return None
@@ -188,25 +194,6 @@ class _OracleInterval(oracle.INTERVAL):
 class _OracleRaw(oracle.RAW):
     pass
 
-colspecs = {
-    sqltypes.Date : _OracleDate, # generic type, assume datetime.date is desired
-    oracle.DATE: oracle.DATE,  # non generic type - passthru
-    sqltypes.LargeBinary : _OracleBinary,
-    sqltypes.Boolean : oracle._OracleBoolean,
-    sqltypes.Interval : _OracleInterval,
-    oracle.INTERVAL : _OracleInterval,
-    sqltypes.Text : _OracleText,
-    sqltypes.String : _OracleString,
-    sqltypes.UnicodeText : _OracleUnicodeText,
-    sqltypes.CHAR : _OracleChar,
-    sqltypes.Integer : _OracleInteger,  # this is only needed for OUT parameters.
-                                        # it would be nice if we could not use it otherwise.
-    oracle.NUMBER : oracle.NUMBER, # don't let this get converted
-    oracle.RAW: _OracleRaw,
-    sqltypes.Unicode: _OracleNVarChar,
-    sqltypes.NVARCHAR : _OracleNVarChar,
-}
-
 class OracleCompiler_cx_oracle(OracleCompiler):
     def bindparam_string(self, name):
         if self.preparer._bindparam_requires_quotes(name):
@@ -346,7 +333,27 @@ class OracleDialect_cx_oracle(OracleDialect):
     execution_ctx_cls = OracleExecutionContext_cx_oracle
     statement_compiler = OracleCompiler_cx_oracle
     driver = "cx_oracle"
-    colspecs = colspecs
+    
+    colspecs = colspecs = {
+        sqltypes.Numeric: _OracleNumeric,
+        sqltypes.Date : _OracleDate, # generic type, assume datetime.date is desired
+        oracle.DATE: oracle.DATE,  # non generic type - passthru
+        sqltypes.LargeBinary : _OracleBinary,
+        sqltypes.Boolean : oracle._OracleBoolean,
+        sqltypes.Interval : _OracleInterval,
+        oracle.INTERVAL : _OracleInterval,
+        sqltypes.Text : _OracleText,
+        sqltypes.String : _OracleString,
+        sqltypes.UnicodeText : _OracleUnicodeText,
+        sqltypes.CHAR : _OracleChar,
+        sqltypes.Integer : _OracleInteger,  # this is only needed for OUT parameters.
+                                            # it would be nice if we could not use it otherwise.
+        oracle.NUMBER : oracle.NUMBER, # don't let this get converted
+        oracle.RAW: _OracleRaw,
+        sqltypes.Unicode: _OracleNVarChar,
+        sqltypes.NVARCHAR : _OracleNVarChar,
+    }
+
     
     execute_sequence_format = list
     
index ba7b2aaeadfda440ba4f9bcbd0e4d61ccc0ec45d..6a58d180cc241c348fb5595d443d22c8a4353dca 100644 (file)
@@ -1110,29 +1110,29 @@ class NumericTest(TestBase):
     def test_numeric_as_decimal(self):
         self._do_test(
             Numeric(precision=8, scale=4),
-            [15.7563, Decimal("15.7563")],
-            [Decimal("15.7563")], 
+            [15.7563, Decimal("15.7563"), None],
+            [Decimal("15.7563"), None], 
         )
 
     def test_numeric_as_float(self):
         if testing.against("oracle+cx_oracle"):
-            filter_ = lambda n:round(n, 5)
+            filter_ = lambda n:n is not None and round(n, 5) or None
         else:
             filter_ = None
 
         self._do_test(
             Numeric(precision=8, scale=4, asdecimal=False),
-            [15.7563, Decimal("15.7563")],
-            [15.7563],
+            [15.7563, Decimal("15.7563"), None],
+            [15.7563, None],
             filter_ = filter_
         )
 
     def test_float_as_decimal(self):
         self._do_test(
             Float(precision=8, asdecimal=True),
-            [15.7563, Decimal("15.7563")],
-            [Decimal("15.7563")], 
-            filter_ = lambda n:round(n, 5)
+            [15.7563, Decimal("15.7563"), None],
+            [Decimal("15.7563"), None], 
+            filter_ = lambda n:n is not None and round(n, 5) or None
         )
 
     def test_float_as_float(self):
@@ -1140,26 +1140,20 @@ class NumericTest(TestBase):
             Float(precision=8),
             [15.7563, Decimal("15.7563")],
             [15.7563],
-            filter_ = lambda n:round(n, 5)
+            filter_ = lambda n:n is not None and round(n, 5) or None
         )
         
     def test_precision_decimal(self):
         numbers = set([
             decimal.Decimal("54.234246451650"),
-            decimal.Decimal("87673.594069654000"),
             decimal.Decimal("0.004354"), 
             decimal.Decimal("900.0"), 
         ])
-        if testing.against('sqlite', 'sybase+pysybase', 'oracle+cx_oracle'):
-            filter_ = lambda n:round_decimal(n, 11)
-        else:
-            filter_ = None
             
         self._do_test(
             Numeric(precision=18, scale=12),
             numbers,
             numbers,
-            filter_=filter_
         )
 
     def test_enotation_decimal(self):
@@ -1192,6 +1186,7 @@ class NumericTest(TestBase):
     
     @testing.fails_on("sybase+pyodbc", 
                         "Don't know how do get these values through FreeTDS + Sybase")
+    @testing.fails_on("firebird", "Precision must be from 1 to 18")
     def test_enotation_decimal_large(self):
         """test exceedingly large decimals.
 
@@ -1212,13 +1207,15 @@ class NumericTest(TestBase):
     @testing.fails_on('sqlite', 'TODO')
     @testing.fails_on('oracle', 'TODO')
     @testing.fails_on('postgresql+pg8000', 'TODO')
+    @testing.fails_on("firebird", "Precision must be from 1 to 18")
     def test_many_significant_digits(self):
         numbers = set([
             decimal.Decimal("31943874831932418390.01"),
             decimal.Decimal("319438950232418390.273596"),
+            decimal.Decimal("87673.594069654243"),
         ])
         self._do_test(
-            Numeric(precision=26, scale=6),
+            Numeric(precision=38, scale=12),
             numbers,
             numbers
         )