]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure float are not implemented as numeric
authorFederico Caselli <cfederico87@gmail.com>
Wed, 26 Apr 2023 19:40:38 +0000 (21:40 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 May 2023 14:14:58 +0000 (10:14 -0400)
Fixed the base class for dialect-specific float/double types; Oracle
:class:`_oracle.BINARY_DOUBLE` now subclasses :class:`_sqltypes.Double`,
and internal types for :class:`_sqltypes.Float` for asyncpg and pg8000 now
correctly subclass :class:`_sqltypes.Float`.

Added suite tests to ensure that floating point types, such as
class:`_types.Float` and :class:`_types.Double` are not resolved as
class:`_types.Numeric` in the dialect, since it may not compatible in
all cases, such as when casting a value.

Change-Id: I20b814e8e029d57921d9728a55f2570f74c35c87

doc/build/changelog/unreleased_20/suite_float_tests.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/types.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/suite_float_tests.rst b/doc/build/changelog/unreleased_20/suite_float_tests.rst
new file mode 100644 (file)
index 0000000..06e7fcd
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, sql
+
+    Fixed the base class for dialect-specific float/double types; Oracle
+    :class:`_oracle.BINARY_DOUBLE` now subclasses :class:`_sqltypes.Double`,
+    and internal types for :class:`_sqltypes.Float` for asyncpg and pg8000 now
+    correctly subclass :class:`_sqltypes.Float`.
index 1c252ab897b85bc6de5b495535d7c9d5090c4712..62028c767386df9458285c2973b006673c86fbb0 100644 (file)
@@ -97,7 +97,7 @@ class FLOAT(sqltypes.FLOAT):
         self.binary_precision = binary_precision
 
 
-class BINARY_DOUBLE(sqltypes.Float):
+class BINARY_DOUBLE(sqltypes.Double):
     __visit_name__ = "BINARY_DOUBLE"
 
 
index c879205e4a18ae70670588a7f40ed699d059adb5..a25502b9078cdbb8930058b9a381e7f28aef1fb4 100644 (file)
@@ -322,7 +322,7 @@ class AsyncpgNumeric(sqltypes.Numeric):
                 )
 
 
-class AsyncpgFloat(AsyncpgNumeric):
+class AsyncpgFloat(AsyncpgNumeric, sqltypes.Float):
     __visit_name__ = "float"
     render_bind_cast = True
 
index 3f01b00e85dfb6e49127d517da14acaae51738f1..a32d375c7bf3df376bac099ef73bbfea1ae803dd 100644 (file)
@@ -148,7 +148,7 @@ class _PGNumeric(sqltypes.Numeric):
                 )
 
 
-class _PGFloat(_PGNumeric):
+class _PGFloat(_PGNumeric, sqltypes.Float):
     __visit_name__ = "float"
     render_bind_cast = True
 
index b59cce3748ae75bab0a0a4cdad25ca2bed346a30..c1d6a14aaa84d9f1ecf120c8ac805cf3b97b063e 100644 (file)
@@ -1240,6 +1240,12 @@ class SuiteRequirements(Requirements):
 
         return exclusions.open()
 
+    @property
+    def float_is_numeric(self):
+        """target backend uses Numeric for Float/Dual"""
+
+        return exclusions.open()
+
     @property
     def text_type(self):
         """Target database must support an unbounded Text() "
index 72f1e8c10f68866849d8b1056c19352b3131e2ba..92781cc1b393e32501c4623762710b0e6cd4f46a 100644 (file)
@@ -13,6 +13,7 @@ from .. import fixtures
 from .. import mock
 from ..assertions import eq_
 from ..assertions import is_
+from ..assertions import ne_
 from ..config import requirements
 from ..schema import Column
 from ..schema import Table
@@ -47,6 +48,7 @@ from ... import UUID
 from ... import Uuid
 from ...orm import declarative_base
 from ...orm import Session
+from ...sql import sqltypes
 from ...sql.sqltypes import LargeBinary
 from ...sql.sqltypes import PickleType
 
@@ -1090,6 +1092,15 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
             Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
         )
 
+    @testing.combinations(sqltypes.Float, sqltypes.Double, argnames="cls_")
+    @testing.requires.float_is_numeric
+    def test_float_is_not_numeric(self, connection, cls_):
+        target_type = cls_().dialect_impl(connection.dialect)
+        numeric_type = sqltypes.Numeric().dialect_impl(connection.dialect)
+
+        ne_(target_type.__visit_name__, numeric_type.__visit_name__)
+        ne_(target_type.__class__, numeric_type.__class__)
+
 
 class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     __backend__ = True
index 3c72cd07df7f07592727deddef229d1666e2bdf5..ae72002382f70f594184254f2238cfc13327033c 100644 (file)
@@ -1446,6 +1446,10 @@ class DefaultRequirements(SuiteRequirements):
     def fetch_null_from_numeric(self):
         return skip_if(("mssql+pyodbc", None, None, "crashes due to bug #351"))
 
+    @property
+    def float_is_numeric(self):
+        return exclusions.fails_if(["oracle"])
+
     @property
     def duplicate_key_raises_integrity_error(self):
         return exclusions.open()