]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure correct cast for floats vs. numeric; other fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Apr 2023 14:34:46 +0000 (10:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Apr 2023 19:40:18 +0000 (15:40 -0400)
Fixed regression caused by the fix for :ticket:`9618` where floating point
values would lose precision being inserted in bulk, using either the
psycopg2 or psycopg drivers.

Implemented the :class:`_sqltypes.Double` type for SQL Server, having it
resolve to ``REAL``, or :class:`_mssql.REAL`, at DDL rendering time.

Fixed issue in Oracle dialects where ``Decimal`` returning types such as
:class:`_sqltypes.Numeric` would return floating point values, rather than
``Decimal`` objects, when these columns were used in the
:meth:`_dml.Insert.returning` clause to return INSERTed values.

Fixes: #9701
Change-Id: I8865496a6ccac6d44c19d0ca2a642b63c6172da9

doc/build/changelog/unreleased_20/9701.rst [new file with mode: 0644]
doc/build/dialects/mssql.rst
lib/sqlalchemy/dialects/mssql/__init__.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_insert.py
test/dialect/mssql/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/9701.rst b/doc/build/changelog/unreleased_20/9701.rst
new file mode 100644 (file)
index 0000000..6c36340
--- /dev/null
@@ -0,0 +1,24 @@
+.. change::
+    :tags: bug, postgresql, regression
+    :tickets: 9701
+
+    Fixed regression caused by the fix for :ticket:`9618` where floating point
+    values would lose precision being inserted in bulk, using either the
+    psycopg2 or psycopg drivers.
+
+
+.. change::
+    :tags: bug, mssql
+
+    Implemented the :class:`_sqltypes.Double` type for SQL Server, having it
+    resolve to ``DOUBLE PRECISION``, or :class:`_mssql.DOUBLE_PRECISION`,
+    at DDL rendering time.
+
+
+.. change::
+    :tags: bug, oracle
+
+    Fixed issue in Oracle dialects where ``Decimal`` returning types such as
+    :class:`_sqltypes.Numeric` would return floating point values, rather than
+    ``Decimal`` objects, when these columns were used in the
+    :meth:`_dml.Insert.returning` clause to return INSERTed values.
index 3aff1008738c8be2cab4ea376026e151047e49dc..92334ef85aebc1e2957388ffe777ee646357aeee 100644 (file)
@@ -29,6 +29,7 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
         DATETIME2,
         DATETIMEOFFSET,
         DECIMAL,
+        DOUBLE_PRECISION,
         FLOAT,
         IMAGE,
         INTEGER,
@@ -77,6 +78,8 @@ construction arguments, are as follows:
 .. autoclass:: DATETIMEOFFSET
    :members: __init__
 
+.. autoclass:: DOUBLE_PRECISION
+   :members: __init__
 
 .. autoclass:: IMAGE
    :members: __init__
index 0ef858a971d0f4d4d01d37b059cab0d3361ce872..e6ad5e120a215122f866b0e0ed1da5dd85a71f4b 100644 (file)
@@ -19,6 +19,7 @@ from .base import DATETIME
 from .base import DATETIME2
 from .base import DATETIMEOFFSET
 from .base import DECIMAL
+from .base import DOUBLE_PRECISION
 from .base import FLOAT
 from .base import IMAGE
 from .base import INTEGER
index 4a7e48ab8a0d2bac77e8b98030d7a7372589f415..e66d01a353caea331a8ada3fa18c8d1d6f0152ee 100644 (file)
@@ -1156,7 +1156,7 @@ RESERVED_WORDS = {
 
 
 class REAL(sqltypes.REAL):
-    __visit_name__ = "REAL"
+    """the SQL Server REAL datatype."""
 
     def __init__(self, **kw):
         # REAL is a synonym for FLOAT(24) on SQL server.
@@ -1166,6 +1166,21 @@ class REAL(sqltypes.REAL):
         super().__init__(**kw)
 
 
+class DOUBLE_PRECISION(sqltypes.DOUBLE_PRECISION):
+    """the SQL Server DOUBLE PRECISION datatype.
+
+    .. versionadded:: 2.0.11
+
+    """
+
+    def __init__(self, **kw):
+        # DOUBLE PRECISION is a synonym for FLOAT(53) on SQL server.
+        # it is only accepted as the word "DOUBLE PRECISION" in DDL,
+        # the numeric precision value is not allowed to be present
+        kw.setdefault("precision", 53)
+        super().__init__(**kw)
+
+
 class TINYINT(sqltypes.Integer):
     __visit_name__ = "TINYINT"
 
@@ -1670,6 +1685,7 @@ ischema_names = {
     "varbinary": VARBINARY,
     "bit": BIT,
     "real": REAL,
+    "double precision": DOUBLE_PRECISION,
     "image": IMAGE,
     "xml": XML,
     "timestamp": TIMESTAMP,
@@ -1700,6 +1716,9 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
 
         return " ".join([c for c in (spec, collation) if c is not None])
 
+    def visit_double(self, type_, **kw):
+        return self.visit_DOUBLE_PRECISION(type_, **kw)
+
     def visit_FLOAT(self, type_, **kw):
         precision = getattr(type_, "precision", None)
         if precision is None:
index f6f10c4761f0796b1402a9a71417aca296f7381d..c0e308383b3349b0e9eb86b29a17ab0ad313212b 100644 (file)
@@ -819,6 +819,15 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
                                 outconverter=lambda value: value.read(),
                                 arraysize=len_params,
                             )
+                        elif (
+                            isinstance(type_impl, _OracleNumeric)
+                            and type_impl.asdecimal
+                        ):
+                            out_parameters[name] = self.cursor.var(
+                                decimal.Decimal,
+                                arraysize=len_params,
+                            )
+
                         else:
                             out_parameters[name] = self.cursor.var(
                                 dbtype, arraysize=len_params
index 739cbc5a9d0cf09b850cb32b79e944bd66ba4261..b985180994a52bfbd38dead4af92a5c267784648 100644 (file)
@@ -55,6 +55,10 @@ class _PsycopgNumeric(sqltypes.Numeric):
                 )
 
 
+class _PsycopgFloat(_PsycopgNumeric):
+    __visit_name__ = "float"
+
+
 class _PsycopgHStore(HSTORE):
     def bind_processor(self, dialect):
         if dialect._has_native_hstore:
@@ -104,6 +108,7 @@ class _PGDialect_common_psycopg(PGDialect):
         PGDialect.colspecs,
         {
             sqltypes.Numeric: _PsycopgNumeric,
+            sqltypes.Float: _PsycopgFloat,
             HSTORE: _PsycopgHStore,
             sqltypes.ARRAY: _PsycopgARRAY,
             INT2VECTOR: _PsycopgINT2VECTOR,
index 3332f7ce249aec8135cc4ea907550c590903e205..b59cce3748ae75bab0a0a4cdad25ca2bed346a30 100644 (file)
@@ -1195,6 +1195,10 @@ class SuiteRequirements(Requirements):
 
         return exclusions.closed()
 
+    @property
+    def float_or_double_precision_behaves_generically(self):
+        return exclusions.closed()
+
     @property
     def precision_generic_float_type(self):
         """target backend will return native floating point numbers with at
index ae54f6bcd47d2608a3cb7020a7a87476fb2ee0a5..d49eb3284ffecd0e6c3c8ab88de55e200c5893ee 100644 (file)
@@ -1,13 +1,20 @@
 # mypy: ignore-errors
 
+from decimal import Decimal
+
+from . import testing
 from .. import fixtures
 from ..assertions import eq_
 from ..config import requirements
 from ..schema import Column
 from ..schema import Table
+from ... import Double
+from ... import Float
+from ... import Identity
 from ... import Integer
 from ... import literal
 from ... import literal_column
+from ... import Numeric
 from ... import select
 from ... import String
 
@@ -378,5 +385,109 @@ class ReturningTest(fixtures.TablesTest):
 
         eq_(rall, pks.all())
 
+    @testing.combinations(
+        (Double(), 8.5514716, True),
+        (
+            Double(53),
+            8.5514716,
+            True,
+            testing.requires.float_or_double_precision_behaves_generically,
+        ),
+        (Float(), 8.5514, False),
+        (
+            Float(8),
+            8.5514,
+            True,
+            testing.requires.float_or_double_precision_behaves_generically,
+        ),
+        (
+            Numeric(precision=15, scale=12, asdecimal=False),
+            8.5514716,
+            True,
+            testing.requires.literal_float_coercion,
+        ),
+        (
+            Numeric(precision=15, scale=12, asdecimal=True),
+            Decimal("8.5514716"),
+            False,
+        ),
+        argnames="type_,value,do_rounding",
+    )
+    @testing.variation("sort_by_parameter_order", [True, False])
+    @testing.variation("multiple_rows", [True, False])
+    def test_insert_w_floats(
+        self,
+        connection,
+        metadata,
+        sort_by_parameter_order,
+        type_,
+        value,
+        do_rounding,
+        multiple_rows,
+    ):
+        """test #9701.
+
+        this tests insertmanyvalues as well as decimal / floating point
+        RETURNING types
+
+        """
+
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("value", type_),
+        )
+
+        t.create(connection)
+
+        result = connection.execute(
+            t.insert().returning(
+                t.c.id,
+                t.c.value,
+                sort_by_parameter_order=bool(sort_by_parameter_order),
+            ),
+            [{"value": value} for i in range(10)]
+            if multiple_rows
+            else {"value": value},
+        )
+
+        if multiple_rows:
+            i_range = range(1, 11)
+        else:
+            i_range = range(1, 2)
+
+        # we want to test only that we are getting floating points back
+        # with some degree of the original value maintained, that it is not
+        # being truncated to an integer.  there's too much variation in how
+        # drivers return floats, which should not be relied upon to be
+        # exact, for us to just compare as is (works for PG drivers but not
+        # others) so we use rounding here.  There's precedent for this
+        # in suite/test_types.py::NumericTest as well
+
+        if do_rounding:
+            eq_(
+                {(id_, round(val_, 5)) for id_, val_ in result},
+                {(id_, round(value, 5)) for id_ in i_range},
+            )
+
+            eq_(
+                {
+                    round(val_, 5)
+                    for val_ in connection.scalars(select(t.c.value))
+                },
+                {round(value, 5)},
+            )
+        else:
+            eq_(
+                set(result),
+                {(id_, value) for id_ in i_range},
+            )
+
+            eq_(
+                set(connection.scalars(select(t.c.value))),
+                {value},
+            )
+
 
 __all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
index d94984e0e614f402752040052e6e1151d82d26fe..916d4252fcd1d3a0770987b334c9eb09ade85934 100644 (file)
@@ -300,6 +300,8 @@ class TypeDDLTest(fixtures.TestBase):
             (types.Float, [None], {}, "FLOAT"),
             (types.Float, [12], {}, "FLOAT(12)"),
             (mssql.MSReal, [], {}, "REAL"),
+            (types.Double, [], {}, "DOUBLE PRECISION"),
+            (types.Double, [53], {}, "DOUBLE PRECISION"),
             (types.Integer, [], {}, "INTEGER"),
             (types.BigInteger, [], {}, "BIGINT"),
             (mssql.MSTinyInteger, [], {}, "TINYINT"),
index 9a8500ac3ecf8c5ee6bbc2b49d9a94edf7124ae4..3c72cd07df7f07592727deddef229d1666e2bdf5 100644 (file)
@@ -1395,6 +1395,10 @@ class DefaultRequirements(SuiteRequirements):
             "postgresql+pg8000", "seems to work on pg14 only, not earlier?"
         )
 
+    @property
+    def float_or_double_precision_behaves_generically(self):
+        return skip_if(["oracle", "mysql", "mariadb"])
+
     @property
     def precision_generic_float_type(self):
         """target backend will return native floating point numbers with at