]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Turn oracle BINARY_DOUBLE, BINARY_FLOAT, DOUBLE_PRECISION into floats
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 23 May 2018 20:22:48 +0000 (16:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 May 2018 14:29:10 +0000 (10:29 -0400)
The Oracle BINARY_FLOAT and BINARY_DOUBLE datatypes now participate within
cx_Oracle.setinputsizes(), passing along NATIVE_FLOAT, so as to support the
NaN value.  Additionally, :class:`.oracle.BINARY_FLOAT`,
:class:`.oracle.BINARY_DOUBLE` and :class:`.oracle.DOUBLE_PRECISION` now
subclass :class:`.Float`, since these are floating point datatypes, not
decimal.  These datatypes were already defaulting the
:paramref:`.Float.asdecimal` flag to False in line with what
:class:`.Float` already does.

Added reflection capabilities for the :class:`.oracle.BINARY_FLOAT`,
:class:`.oracle.BINARY_DOUBLE` datatypes.

Change-Id: Id99b912e83052654a17d07dc92b4dcb958cb7600
Fixes: #4264
doc/build/changelog/unreleased_12/4264.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/type_api.py
test/dialect/oracle/test_reflection.py
test/dialect/oracle/test_types.py

diff --git a/doc/build/changelog/unreleased_12/4264.rst b/doc/build/changelog/unreleased_12/4264.rst
new file mode 100644 (file)
index 0000000..77878f8
--- /dev/null
@@ -0,0 +1,21 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 4264
+    :versions: 1.3.0b1
+
+    The Oracle BINARY_FLOAT and BINARY_DOUBLE datatypes now participate within
+    cx_Oracle.setinputsizes(), passing along NATIVE_FLOAT, so as to support the
+    NaN value.  Additionally, :class:`.oracle.BINARY_FLOAT`,
+    :class:`.oracle.BINARY_DOUBLE` and :class:`.oracle.DOUBLE_PRECISION` now
+    subclass :class:`.Float`, since these are floating point datatypes, not
+    decimal.  These datatypes were already defaulting the
+    :paramref:`.Float.asdecimal` flag to False in line with what
+    :class:`.Float` already does.
+
+.. change::
+    :tags: bug, oracle
+    :versions: 1.3.0b1
+
+    Added reflection capabilities for the :class:`.oracle.BINARY_FLOAT`,
+    :class:`.oracle.BINARY_DOUBLE` datatypes.
+
index e55a9cbc6a030d0952e3fd1f5221997a20c4e359..39acbf28d8507630612ac25ffa6f99da28ad2ac7 100644 (file)
@@ -411,38 +411,17 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer):
             return sqltypes.Integer
 
 
-class DOUBLE_PRECISION(sqltypes.Numeric):
+class DOUBLE_PRECISION(sqltypes.Float):
     __visit_name__ = 'DOUBLE_PRECISION'
 
-    def __init__(self, precision=None, scale=None, asdecimal=None):
-        if asdecimal is None:
-            asdecimal = False
-
-        super(DOUBLE_PRECISION, self).__init__(
-            precision=precision, scale=scale, asdecimal=asdecimal)
-
 
-class BINARY_DOUBLE(sqltypes.Numeric):
+class BINARY_DOUBLE(sqltypes.Float):
     __visit_name__ = 'BINARY_DOUBLE'
 
-    def __init__(self, precision=None, scale=None, asdecimal=None):
-        if asdecimal is None:
-            asdecimal = False
-
-        super(BINARY_DOUBLE, self).__init__(
-            precision=precision, scale=scale, asdecimal=asdecimal)
 
-
-class BINARY_FLOAT(sqltypes.Numeric):
+class BINARY_FLOAT(sqltypes.Float):
     __visit_name__ = 'BINARY_FLOAT'
 
-    def __init__(self, precision=None, scale=None, asdecimal=None):
-        if asdecimal is None:
-            asdecimal = False
-
-        super(BINARY_FLOAT, self).__init__(
-            precision=precision, scale=scale, asdecimal=asdecimal)
-
 
 class BFILE(sqltypes.LargeBinary):
     __visit_name__ = 'BFILE'
@@ -536,6 +515,8 @@ ischema_names = {
     'FLOAT': FLOAT,
     'DOUBLE PRECISION': DOUBLE_PRECISION,
     'LONG': LONG,
+    'BINARY_DOUBLE': BINARY_DOUBLE,
+    'BINARY_FLOAT': BINARY_FLOAT
 }
 
 
@@ -585,17 +566,25 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler):
     def visit_BINARY_FLOAT(self, type_, **kw):
         return self._generate_numeric(type_, "BINARY_FLOAT", **kw)
 
+    def visit_FLOAT(self, type_, **kw):
+        # don't support conversion between decimal/binary
+        # precision yet
+        kw['no_precision'] = True
+        return self._generate_numeric(type_, "FLOAT", **kw)
+
     def visit_NUMBER(self, type_, **kw):
         return self._generate_numeric(type_, "NUMBER", **kw)
 
-    def _generate_numeric(self, type_, name, precision=None, scale=None, **kw):
+    def _generate_numeric(
+            self, type_, name, precision=None,
+            scale=None, no_precision=False, **kw):
         if precision is None:
             precision = type_.precision
 
         if scale is None:
             scale = getattr(type_, 'scale', None)
 
-        if precision is None:
+        if no_precision or precision is None:
             return name
         elif scale is None:
             n = "%(name)s(%(precision)s)"
@@ -1418,6 +1407,9 @@ class OracleDialect(default.DefaultDialect):
                     coltype = INTEGER()
                 else:
                     coltype = NUMBER(precision, scale)
+            elif coltype == 'FLOAT':
+                # TODO: support "precision" here as "binary_precision"
+                coltype = FLOAT()
             elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'):
                 coltype = self.ischema_names.get(coltype)(length)
             elif 'WITH TIME ZONE' in coltype:
index 0bd682d19eeff75b1b82c9a5f7a18a391799ded2..2fbb2074c5680e5481dd7e76ad025119ad363dc9 100644 (file)
@@ -240,7 +240,6 @@ class _OracleInteger(sqltypes.Integer):
         return handler
 
 
-
 class _OracleNumeric(sqltypes.Numeric):
     is_number = False
 
@@ -323,6 +322,19 @@ class _OracleNumeric(sqltypes.Numeric):
         return handler
 
 
+class _OracleBinaryFloat(_OracleNumeric):
+    def get_dbapi_type(self, dbapi):
+        return dbapi.NATIVE_FLOAT
+
+
+class _OracleBINARY_FLOAT(_OracleBinaryFloat, oracle.BINARY_FLOAT):
+    pass
+
+
+class _OracleBINARY_DOUBLE(_OracleBinaryFloat, oracle.BINARY_DOUBLE):
+    pass
+
+
 class _OracleNUMBER(_OracleNumeric):
     is_number = True
 
@@ -597,6 +609,8 @@ class OracleDialect_cx_oracle(OracleDialect):
     colspecs = {
         sqltypes.Numeric: _OracleNumeric,
         sqltypes.Float: _OracleNumeric,
+        oracle.BINARY_FLOAT: _OracleBINARY_FLOAT,
+        oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE,
         sqltypes.Integer: _OracleInteger,
         oracle.NUMBER: _OracleNUMBER,
 
@@ -654,7 +668,7 @@ class OracleDialect_cx_oracle(OracleDialect):
                 cx_Oracle.NCLOB, cx_Oracle.CLOB, cx_Oracle.LOB,
                 cx_Oracle.NCHAR, cx_Oracle.FIXED_NCHAR,
                 cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP,
-                _OracleInteger
+                _OracleInteger, _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE
             }
 
         self._is_cx_oracle_6 = self.cx_oracle_ver >= (6, )
index ea806deaeee873dd976b9e440ecfe0ac91d71b56..4d5f338bf23e7f3cabf0640c408051f9bb8c5e29 100644 (file)
@@ -1129,7 +1129,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         key_to_dbapi_type = {}
         for bindparam in self.compiled.bind_names:
             key = self.compiled.bind_names[bindparam]
-            dialect_impl = bindparam.type.dialect_impl(self.dialect)
+            dialect_impl = bindparam.type._unwrapped_dialect_impl(self.dialect)
             dialect_impl_cls = type(dialect_impl)
             dbtype = dialect_impl.get_dbapi_type(self.dialect.dbapi)
             if dbtype is not None and (
index 1d1d6e089a1bf96496786b05763726de2e7350e0..6a323683bfe5dd97c3dbb037550531b91f8d2dba 100644 (file)
@@ -445,6 +445,20 @@ class TypeEngine(Visitable):
         except KeyError:
             return self._dialect_info(dialect)['impl']
 
+    def _unwrapped_dialect_impl(self, dialect):
+        """Return the 'unwrapped' dialect impl for this type.
+
+        For a type that applies wrapping logic (e.g. TypeDecorator), give
+        us the real, actual dialect-level type that is used.
+
+        This is used by TypeDecorator itself as well at least one case where
+        dialects need to check that a particular specific dialect-level
+        type is in use, within the :meth:`.DefaultDialect.set_input_sizes`
+        method.
+
+        """
+        return self.dialect_impl(dialect)
+
     def _cached_literal_processor(self, dialect):
         """Return a dialect-specific literal processor for this type."""
         try:
@@ -922,7 +936,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         # otherwise adapt the impl type, link
         # to a copy of this TypeDecorator and return
         # that.
-        typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect)
+        typedesc = self._unwrapped_dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
             raise AssertionError('Type object %s does not properly '
@@ -989,6 +1003,20 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         """
         return self.impl
 
+    def _unwrapped_dialect_impl(self, dialect):
+        """Return the 'unwrapped' dialect impl for this type.
+
+        For a type that applies wrapping logic (e.g. TypeDecorator), give
+        us the real, actual dialect-level type that is used.
+
+        This is used by TypeDecorator itself as well at least one case where
+        dialects need to check that a particular specific dialect-level
+        type is in use, within the :meth:`.DefaultDialect.set_input_sizes`
+        method.
+
+        """
+        return self.load_dialect_impl(dialect).dialect_impl(dialect)
+
     def __getattr__(self, key):
         """Proxy all other undefined accessors to the underlying
         implementation."""
index 190fd9f38efabb537011d35ff7c84f3e39349eec..f749e513acf2c204d20a41b2b9d90b606eb61152 100644 (file)
@@ -13,7 +13,8 @@ from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\
     ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \
     union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \
     PrimaryKeyConstraint, FLOAT, INTEGER
-from sqlalchemy.dialects.oracle.base import NUMBER
+from sqlalchemy.dialects.oracle.base import NUMBER, BINARY_DOUBLE, \
+    BINARY_FLOAT, DOUBLE_PRECISION
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing.schema import Table, Column
@@ -578,3 +579,19 @@ class TypeReflectionTest(fixtures.TestBase):
             (NUMBER, NUMBER(),),
         ]
         self._run_test(specs, ['precision', 'scale'])
+
+    def test_float_types(self):
+        specs = [
+            (DOUBLE_PRECISION(), FLOAT()),
+            # when binary_precision is supported
+            # (DOUBLE_PRECISION(), oracle.FLOAT(binary_precision=126)),
+            (BINARY_DOUBLE(), BINARY_DOUBLE()),
+            (BINARY_FLOAT(), BINARY_FLOAT()),
+            (FLOAT(5), FLOAT(),),
+            # when binary_precision is supported
+            # (FLOAT(5), oracle.FLOAT(binary_precision=5),),
+            (FLOAT(), FLOAT()),
+            # when binary_precision is supported
+            # (FLOAT(5), oracle.FLOAT(binary_precision=126),),
+        ]
+        self._run_test(specs, ['precision'])
index a13a578b608bf26747cd072c25b4107a388cffcc..9e4fa5996b14dbdf119f4cfb65ab5ba7cc62f01d 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy import Integer, Text, LargeBinary, Unicode, UniqueConstraint,\
     ForeignKeyConstraint, Sequence, Float, DateTime, cast, UnicodeText, \
     union, except_, type_coerce, or_, outerjoin, DATE, NCHAR, outparam, \
     PrimaryKeyConstraint, FLOAT
+from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.util import u, b
 from sqlalchemy import util
 from sqlalchemy.testing import assert_raises, assert_raises_message
@@ -336,6 +337,75 @@ class TypesTest(fixtures.TestBase):
             [(decimal.Decimal("Infinity"), ), (decimal.Decimal("-Infinity"), )]
         )
 
+    @testing.provide_metadata
+    def test_numeric_nan_float(self):
+        m = self.metadata
+        t1 = Table('t1', m,
+                   Column("intcol", Integer),
+                   Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)))
+        t1.create()
+        t1.insert().execute([
+            dict(
+                intcol=1,
+                numericcol=float("nan")
+            ),
+            dict(
+                intcol=2,
+                numericcol=float("-nan")
+            ),
+        ])
+
+        eq_(
+            [
+                tuple(str(col) for col in row)
+                for row in select([t1.c.numericcol]).
+                order_by(t1.c.intcol).execute()
+            ],
+            [('nan', ), ('nan', )]
+        )
+
+        eq_(
+            [
+                tuple(str(col) for col in row)
+                for row in testing.db.execute(
+                    "select numericcol from t1 order by intcol"
+                )
+            ],
+            [('nan', ), ('nan', )]
+
+        )
+
+    # needs https://github.com/oracle/python-cx_Oracle/issues/184#issuecomment-391399292
+    @testing.provide_metadata
+    def _dont_test_numeric_nan_decimal(self):
+        m = self.metadata
+        t1 = Table('t1', m,
+                   Column("intcol", Integer),
+                   Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True)))
+        t1.create()
+        t1.insert().execute([
+            dict(
+                intcol=1,
+                numericcol=decimal.Decimal("NaN")
+            ),
+            dict(
+                intcol=2,
+                numericcol=decimal.Decimal("-NaN")
+            ),
+        ])
+
+        eq_(
+            select([t1.c.numericcol]).
+            order_by(t1.c.intcol).execute().fetchall(),
+            [(decimal.Decimal("NaN"), ), (decimal.Decimal("NaN"), )]
+        )
+
+        eq_(
+            testing.db.execute(
+                "select numericcol from t1 order by intcol").fetchall(),
+            [(decimal.Decimal("NaN"), ), (decimal.Decimal("NaN"), )]
+        )
+
     @testing.provide_metadata
     def test_numerics_broken_inspection(self):
         """Numeric scenarios where Oracle type info is 'broken',
@@ -831,9 +901,31 @@ class SetInputSizesTest(fixtures.TestBase):
 
     @testing.provide_metadata
     def _test_setinputsizes(self, datatype, value, sis_value):
+        class TestTypeDec(TypeDecorator):
+            impl = NullType()
+
+            def load_dialect_impl(self, dialect):
+                if dialect.name == 'oracle':
+                    return dialect.type_descriptor(datatype)
+                else:
+                    return self.impl
+
         m = self.metadata
-        t1 = Table('t1', m, Column('foo', datatype))
-        t1.create()
+        # Oracle can have only one column of type LONG so we make three
+        # tables rather than one table w/ three columns
+        t1 = Table(
+            't1', m,
+            Column('foo', datatype),
+        )
+        t2 = Table(
+            't2', m,
+            Column('foo', NullType().with_variant(datatype, "oracle")),
+        )
+        t3 = Table(
+            't3', m,
+            Column('foo', TestTypeDec())
+        )
+        m.create_all()
 
         class CursorWrapper(object):
             # cx_oracle cursor can't be modified so we have to
@@ -853,24 +945,26 @@ class SetInputSizesTest(fixtures.TestBase):
 
         with testing.db.connect() as conn:
             connection_fairy = conn.connection
-            with mock.patch.object(
-                    connection_fairy, "cursor",
-                    lambda: CursorWrapper(connection_fairy)
-            ):
-                conn.execute(
-                    t1.insert(), {"foo": value}
-                )
-
-            if sis_value:
-                eq_(
-                    conn.info['mock'].mock_calls,
-                    [mock.call.setinputsizes(foo=sis_value)]
-                )
-            else:
-                eq_(
-                    conn.info['mock'].mock_calls,
-                    [mock.call.setinputsizes()]
-                )
+            for tab in [t1, t2, t3]:
+                with mock.patch.object(
+                        connection_fairy, "cursor",
+                        lambda: CursorWrapper(connection_fairy)
+                ):
+                    conn.execute(
+                        tab.insert(),
+                        {"foo": value}
+                    )
+
+                if sis_value:
+                    eq_(
+                        conn.info['mock'].mock_calls,
+                        [mock.call.setinputsizes(foo=sis_value)]
+                    )
+                else:
+                    eq_(
+                        conn.info['mock'].mock_calls,
+                        [mock.call.setinputsizes()]
+                    )
 
     def test_smallint_setinputsizes(self):
         self._test_setinputsizes(
@@ -887,6 +981,21 @@ class SetInputSizesTest(fixtures.TestBase):
     def test_float_setinputsizes(self):
         self._test_setinputsizes(Float(15), 25.34534, None)
 
+    def test_binary_double_setinputsizes(self):
+        self._test_setinputsizes(
+            oracle.BINARY_DOUBLE, 25.34534,
+            testing.db.dialect.dbapi.NATIVE_FLOAT)
+
+    def test_binary_float_setinputsizes(self):
+        self._test_setinputsizes(
+            oracle.BINARY_FLOAT, 25.34534,
+            testing.db.dialect.dbapi.NATIVE_FLOAT)
+
+    def test_double_precision_setinputsizes(self):
+        self._test_setinputsizes(
+            oracle.DOUBLE_PRECISION, 25.34534,
+            None)
+
     def test_unicode(self):
         self._test_setinputsizes(
             Unicode(30), u("test"), testing.db.dialect.dbapi.NCHAR)