]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The cx_oracle "decimal detection" logic, which takes place
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Nov 2010 17:19:31 +0000 (12:19 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Nov 2010 17:19:31 +0000 (12:19 -0500)
for for result set columns with ambiguous numeric characteristics,
now uses the decimal point character determined by the locale/
NLS_LANG setting, using an on-first-connect detection of
this character.  cx_oracle 5.0.3 or greater is also required
when using a non-period-decimal-point NLS_LANG setting.
[ticket:1953].

CHANGES
lib/sqlalchemy/dialects/oracle/cx_oracle.py
test/dialect/test_oracle.py

diff --git a/CHANGES b/CHANGES
index d179b011512016956417768b7364699dcbecf959..d29175e6b3f5054b72156b4fa11ab9cbc3ea02f6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -60,6 +60,15 @@ CHANGES
     than that of the parent table doesn't render at all,
     as cross-schema references do not appear to be supported.
 
+- oracle
+  - The cx_oracle "decimal detection" logic, which takes place
+    for for result set columns with ambiguous numeric characteristics,
+    now uses the decimal point character determined by the locale/
+    NLS_LANG setting, using an on-first-connect detection of 
+    this character.  cx_oracle 5.0.3 or greater is also required
+    when using a non-period-decimal-point NLS_LANG setting.
+    [ticket:1953].
+    
 - declarative
   - An error is raised if __table_args__ is not in tuple
     or dict format, and is not None.  [ticket:1972]
index eb25e614e6915ef37a6d20ad8a1065dbf6d044d2..87a84e514d2cce834ff59711b92b7a767f5620fc 100644 (file)
@@ -66,6 +66,52 @@ Two Phase Transaction Support
 Two Phase transactions are implemented using XA transactions.  Success has been reported 
 with this feature but it should be regarded as experimental.
 
+Precision Numerics
+------------------
+
+The SQLAlchemy dialect goes thorugh a lot of steps to ensure
+that decimal numbers are sent and received with full accuracy.
+An "outputtypehandler" callable is associated with each
+cx_oracle connection object which detects numeric types and
+receives them as string values, instead of receiving a Python
+``float`` directly, which is then passed to the Python
+``Decimal`` constructor.  The :class:`.Numeric` and
+:class:`.Float` types under the cx_oracle dialect are aware of
+this behavior, and will coerce the ``Decimal`` to ``float`` if
+the ``asdecimal`` flag is ``False`` (default on :class:`.Float`,
+optional on :class:`.Numeric`).
+
+The handler attempts to use the "precision" and "scale"
+attributes of the result set column to best determine if
+subsequent incoming values should be received as ``Decimal`` as
+opposed to int (in which case no processing is added). There are
+several scenarios where OCI_ does not provide unambiguous data
+as to the numeric type, including some situations where
+individual rows may return a combination of floating point and
+integer values. Certain values for "precision" and "scale" have
+been observed to determine this scenario.  When it occurs, the
+outputtypehandler receives as string and then passes off to a
+processing function which detects, for each returned value, if a
+decimal point is present, and if so converts to ``Decimal``,
+otherwise to int.  The intention is that simple int-based
+statements like "SELECT my_seq.nextval() FROM DUAL" continue to
+return ints and not ``Decimal`` objects, and that any kind of
+floating point value is received as a string so that there is no
+floating point loss of precision.
+
+The "decimal point is present" logic itself is also sensitive to
+locale.  Under OCI_, this is controlled by the NLS_LANG
+environment variable. Upon first connection, the dialect runs a
+test to determine the current "decimal" character, which can be
+a comma "," for european locales. From that point forward the
+outputtypehandler uses that character to represent a decimal
+point (this behavior is new in version 0.6.6). Note that
+cx_oracle 5.0.3 or greater is required when dealing with
+numerics with locale settings that don't use a period "." as the
+decimal character.
+
+.. _OCI: http://www.oracle.com/technetwork/database/features/oci/index.html
+
 """
 
 from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, \
@@ -76,6 +122,7 @@ from sqlalchemy import types as sqltypes, util, exc, processors
 from datetime import datetime
 import random
 from decimal import Decimal
+import re
 
 class _OracleNumeric(sqltypes.Numeric):
     def bind_processor(self, dialect):
@@ -473,37 +520,80 @@ class OracleDialect_cx_oracle(OracleDialect):
                 self.dbapi.BLOB: oracle.BLOB(),
                 self.dbapi.BINARY: oracle.RAW(),
             }
+    @classmethod
+    def dbapi(cls):
+        import cx_Oracle
+        return cx_Oracle
 
     def initialize(self, connection):
         super(OracleDialect_cx_oracle, self).initialize(connection)
         if self._is_oracle_8:
             self.supports_unicode_binds = False
+        self._detect_decimal_char(connection)
+    
+    def _detect_decimal_char(self, connection):
+        """detect if the decimal separator character is not '.', as 
+        is the case with european locale settings for NLS_LANG.
+        
+        cx_oracle itself uses similar logic when it formats Python
+        Decimal objects to strings on the bind side (as of 5.0.3), 
+        as Oracle sends/receives string numerics only in the 
+        current locale.
+        
+        """
+        if self.cx_oracle_ver < (5,):
+            # no output type handlers before version 5
+            return
+        
+        cx_Oracle = self.dbapi
+        conn = connection.connection
+        
+        # override the output_type_handler that's 
+        # on the cx_oracle connection with a plain 
+        # one on the cursor
+        
+        def output_type_handler(cursor, name, defaultType, 
+                                size, precision, scale):
+            return cursor.var(
+                        cx_Oracle.STRING, 
+                        255, arraysize=cursor.arraysize)
+
+        cursor = conn.cursor()
+        cursor.outputtypehandler = output_type_handler
+        cursor.execute("SELECT 0.1 FROM DUAL")
+        val = cursor.fetchone()[0]
+        cursor.close()
+        char = re.match(r"([\.,])", val).group(1)
+        if char != '.':
+            _detect_decimal = self._detect_decimal
+            self._detect_decimal = \
+                lambda value: _detect_decimal(value.replace(char, '.'))
+            self._to_decimal = \
+                lambda value: Decimal(value.replace(char, '.'))
+        
+    def _detect_decimal(self, value):
+        if "." in value:
+            return Decimal(value)
+        else:
+            return int(value)
+    
+    _to_decimal = Decimal
     
-    @classmethod
-    def dbapi(cls):
-        import cx_Oracle
-        return cx_Oracle
-
     def on_connect(self):
         if self.cx_oracle_ver < (5,):
             # no output type handlers before version 5
             return
         
-        def maybe_decimal(value):
-            if "." in value:
-                return Decimal(value)
-            else:
-                return int(value)
-                
         cx_Oracle = self.dbapi
-        def output_type_handler(cursor, name, defaultType, size, precision, scale):
+        def output_type_handler(cursor, name, defaultType, 
+                                    size, precision, scale):
             # convert all NUMBER with precision + positive scale to Decimal
             # this almost allows "native decimal" mode.
             if defaultType == cx_Oracle.NUMBER and precision and scale > 0:
                 return cursor.var(
                             cx_Oracle.STRING, 
                             255, 
-                            outconverter=Decimal, 
+                            outconverter=self._to_decimal, 
                             arraysize=cursor.arraysize)
             # if NUMBER with zero precision and 0 or neg scale, this appears
             # to indicate "ambiguous".  Use a slower converter that will 
@@ -515,7 +605,7 @@ class OracleDialect_cx_oracle(OracleDialect):
                 return cursor.var(
                             cx_Oracle.STRING, 
                             255, 
-                            outconverter=maybe_decimal, 
+                            outconverter=self._detect_decimal, 
                             arraysize=cursor.arraysize)
             # allow all strings to come back natively as Unicode
             elif defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR):
@@ -578,7 +668,10 @@ class OracleDialect_cx_oracle(OracleDialect):
         return ([], opts)
 
     def _get_server_version_info(self, connection):
-        return tuple(int(x) for x in connection.connection.version.split('.'))
+        return tuple(
+                        int(x) 
+                        for x in connection.connection.version.split('.')
+                    )
 
     def is_disconnect(self, e):
         if isinstance(e, self.dbapi.InterfaceError):
index 04b9d32749903a60c0459d0bdee397bc9ee5c807..eba750533b201b0979aaea8337d03ede63a1a357 100644 (file)
@@ -367,7 +367,7 @@ class CompatFlagsTest(TestBase, AssertsCompiledSQL):
         def server_version_info(self):
             return (8, 2, 5)
             
-        dialect = oracle.dialect()
+        dialect = oracle.dialect(dbapi=testing.db.dialect.dbapi)
         dialect._get_server_version_info = server_version_info
 
         # before connect, assume modern DB
@@ -384,7 +384,8 @@ class CompatFlagsTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(Unicode(50),"VARCHAR(50)",dialect=dialect)
         self.assert_compile(UnicodeText(),"CLOB",dialect=dialect)
 
-        dialect = oracle.dialect(implicit_returning=True)
+        dialect = oracle.dialect(implicit_returning=True, 
+                                    dbapi=testing.db.dialect.dbapi)
         dialect._get_server_version_info = server_version_info
         dialect.initialize(testing.db.connect())
         assert dialect.implicit_returning
@@ -392,7 +393,7 @@ class CompatFlagsTest(TestBase, AssertsCompiledSQL):
 
     def test_default_flags(self):
         """test with no initialization or server version info"""
-        dialect = oracle.dialect()
+        dialect = oracle.dialect(dbapi=testing.db.dialect.dbapi)
         assert dialect._supports_char_length
         assert dialect._supports_nchar
         assert dialect.use_ansi
@@ -403,7 +404,7 @@ class CompatFlagsTest(TestBase, AssertsCompiledSQL):
     def test_ora10_flags(self):
         def server_version_info(self):
             return (10, 2, 5)
-        dialect = oracle.dialect()
+        dialect = oracle.dialect(dbapi=testing.db.dialect.dbapi)
         dialect._get_server_version_info = server_version_info
         dialect.initialize(testing.db.connect())
         assert dialect._supports_char_length
@@ -1043,7 +1044,40 @@ class TypesTest(TestBase, AssertsCompiledSQL):
         finally:
             t.drop(engine)
             
-            
+class EuroNumericTest(TestBase):
+    """test the numeric output_type_handler when using non-US locale for NLS_LANG."""
+    
+    __only_on__ = 'oracle+cx-oracle'
+    
+    def setup(self):
+        self.old_nls_lang = os.environ.get('NLS_LANG', False)
+        os.environ['NLS_LANG'] = "GERMAN"
+        self.engine = testing_engine()
+        
+    def teardown(self):
+        if self.old_nls_lang is not False:
+            os.environ['NLS_LANG'] = self.old_nls_lang
+        else:
+            del os.environ['NLS_LANG']
+        self.engine.dispose()
+        
+    @testing.provide_metadata
+    def test_output_type_handler(self):
+        for stmt, exp, kw in [
+            ("SELECT 0.1 FROM DUAL", Decimal("0.1"), {}),
+            ("SELECT 15 FROM DUAL", 15, {}),
+            ("SELECT CAST(15 AS NUMERIC(3, 1)) FROM DUAL", Decimal("15"), {}),
+            ("SELECT CAST(0.1 AS NUMERIC(5, 2)) FROM DUAL", Decimal("0.1"), {}),
+            ("SELECT :num FROM DUAL", Decimal("2.5"), {'num':Decimal("2.5")})
+        ]:
+            test_exp = self.engine.scalar(stmt, **kw)
+            eq_(
+                test_exp,
+                exp
+            )
+            assert type(test_exp) is type(exp)
+        
+    
 class DontReflectIOTTest(TestBase):
     """test that index overflow tables aren't included in
     table_names."""