]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support casting to ``FLOAT`` in MySQL and MariaDb.
authorFederico Caselli <cfederico87@gmail.com>
Wed, 30 Dec 2020 19:36:27 +0000 (20:36 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 30 Dec 2020 21:26:23 +0000 (22:26 +0100)
Fixes: #5808
Change-Id: I8106ddcf681eec3cb3a67d853586702f6e844b9d

doc/build/changelog/unreleased_13/5808.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_13/5808.rst b/doc/build/changelog/unreleased_13/5808.rst
new file mode 100644 (file)
index 0000000..b6625c0
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, mysql
+    :tickets: 5808
+
+    Casting to ``FLOAT`` is now supported in MySQL >= (8, 0, 17) and
+    MariaDb >= (10, 4, 5).
\ No newline at end of file
index 7a4d3261f95bb9c890c9d88b42d2ede8bb25a41b..a4b5835413a241d8c9c897f9cfe6fc7115e0b6d0 100644 (file)
@@ -1624,6 +1624,11 @@ class MySQLCompiler(compiler.SQLCompiler):
             return self.dialect.type_compiler.process(type_).replace(
                 "NUMERIC", "DECIMAL"
             )
+        elif (
+            isinstance(type_, sqltypes.Float)
+            and self.dialect._support_float_cast
+        ):
+            return self.dialect.type_compiler.process(type_)
         else:
             return None
 
@@ -1631,7 +1636,7 @@ class MySQLCompiler(compiler.SQLCompiler):
         type_ = self.process(cast.typeclause)
         if type_ is None:
             util.warn(
-                "Datatype %s does not support CAST on MySQL; "
+                "Datatype %s does not support CAST on MySQL/MariaDb; "
                 "the CAST will be skipped."
                 % self.dialect.type_compiler.process(cast.typeclause.type)
             )
@@ -2899,6 +2904,17 @@ class MySQLDialect(default.DefaultDialect):
                     "series, to avoid these issues." % (mdb_version,)
                 )
 
+    @property
+    def _support_float_cast(self):
+        if not self.server_version_info:
+            return False
+        elif self.is_mariadb:
+            # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+            return self.server_version_info >= (10, 4, 5)
+        else:
+            # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature  # noqa
+            return self.server_version_info >= (8, 0, 17)
+
     @property
     def _is_mariadb(self):
         return self.is_mariadb
index 2993f96b8b0c27b52ae94f0eda795a670dff3df5..62292b9daad78b379247b1ba3a4c0deeac93a0cb 100644 (file)
@@ -710,7 +710,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_unsupported_cast_literal_bind(self):
         expr = cast(column("foo", Integer) + 5, Float)
 
-        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+        with expect_warnings(
+            "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+        ):
             self.assert_compile(expr, "(foo + 5)", literal_binds=True)
 
     m = mysql
@@ -734,11 +736,35 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_unsupported_casts(self, type_, expected):
 
         t = sql.table("t", sql.column("col"))
-        with expect_warnings("Datatype .* does not support CAST on MySQL;"):
+        with expect_warnings(
+            "Datatype .* does not support CAST on MySQL/MariaDb;"
+        ):
             self.assert_compile(cast(t.c.col, type_), expected)
 
+    @testing.combinations(
+        (m.FLOAT, "CAST(t.col AS FLOAT)"),
+        (Float, "CAST(t.col AS FLOAT)"),
+        (FLOAT, "CAST(t.col AS FLOAT)"),
+        (m.DOUBLE, "CAST(t.col AS DOUBLE)"),
+        (m.FLOAT, "CAST(t.col AS FLOAT)"),
+        argnames="type_,expected",
+    )
+    @testing.combinations(True, False, argnames="maria_db")
+    def test_float_cast(self, type_, expected, maria_db):
+
+        dialect = mysql.dialect()
+        if maria_db:
+            dialect.is_mariadb = maria_db
+            dialect.server_version_info = (10, 4, 5)
+        else:
+            dialect.server_version_info = (8, 0, 17)
+        t = sql.table("t", sql.column("col"))
+        self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect)
+
     def test_cast_grouped_expression_non_castable(self):
-        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+        with expect_warnings(
+            "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+        ):
             self.assert_compile(
                 cast(sql.column("x") + sql.column("y"), Float), "(x + y)"
             )