]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support casting to ``FLOAT`` in MySQL and MariaDb.
authorFederico Caselli <cfederico87@gmail.com>
Wed, 30 Dec 2020 19:36:27 +0000 (20:36 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 30 Dec 2020 19:42:24 +0000 (20:42 +0100)
Fixes: #5808
Change-Id: I8106ddcf681eec3cb3a67d853586702f6e844b9d
(cherry picked from commit 0da7225ac16b966c1cc5f1b2afde4eb6856183aa)

doc/build/changelog/unreleased_13/5808.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_13/5808.rst b/doc/build/changelog/unreleased_13/5808.rst
new file mode 100644 (file)
index 0000000..b6625c0
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, mysql
+    :tickets: 5808
+
+    Casting to ``FLOAT`` is now supported in MySQL >= (8, 0, 17) and
+    MariaDb >= (10, 4, 5).
\ No newline at end of file
index 5962f7faba095a2a008a0eb0dc22941f1af33efa..65fede74a4e2ac846ffed713f2d814fbf272172a 100644 (file)
@@ -1505,6 +1505,11 @@ class MySQLCompiler(compiler.SQLCompiler):
             return self.dialect.type_compiler.process(type_).replace(
                 "NUMERIC", "DECIMAL"
             )
+        elif (
+            isinstance(type_, sqltypes.Float)
+            and self.dialect._support_float_cast
+        ):
+            return self.dialect.type_compiler.process(type_)
         else:
             return None
 
@@ -1520,7 +1525,7 @@ class MySQLCompiler(compiler.SQLCompiler):
         type_ = self.process(cast.typeclause)
         if type_ is None:
             util.warn(
-                "Datatype %s does not support CAST on MySQL; "
+                "Datatype %s does not support CAST on MySQL/MariaDb; "
                 "the CAST will be skipped."
                 % self.dialect.type_compiler.process(cast.typeclause.type)
             )
@@ -2659,6 +2664,17 @@ class MySQLDialect(default.DefaultDialect):
                     "series, to avoid these issues." % (mdb_version,)
                 )
 
+    @property
+    def _support_float_cast(self):
+        if not self.server_version_info:
+            return False
+        elif self._is_mariadb:
+            # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+            return self.server_version_info >= (10, 4, 5)
+        else:
+            # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature  # noqa
+            return self.server_version_info >= (8, 0, 17)
+
     @property
     def _is_mariadb(self):
         return (
index 6a0202b22948325b7d283bcf3778f9d41af0fc52..85732692e23d02ebaa08e412dccfa1fe297ff422 100644 (file)
@@ -718,7 +718,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_unsupported_cast_literal_bind(self):
         expr = cast(column("foo", Integer) + 5, Float)
 
-        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+        with expect_warnings(
+            "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+        ):
             self.assert_compile(expr, "(foo + 5)", literal_binds=True)
 
         dialect = mysql.MySQLDialect()
@@ -754,9 +756,30 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_unsupported_casts(self, type_, expected):
 
         t = sql.table("t", sql.column("col"))
-        with expect_warnings("Datatype .* does not support CAST on MySQL;"):
+        with expect_warnings(
+            "Datatype .* does not support CAST on MySQL/MariaDb;"
+        ):
             self.assert_compile(cast(t.c.col, type_), expected)
 
+    @testing.combinations(
+        (m.FLOAT, "CAST(t.col AS FLOAT)"),
+        (Float, "CAST(t.col AS FLOAT)"),
+        (FLOAT, "CAST(t.col AS FLOAT)"),
+        (m.DOUBLE, "CAST(t.col AS DOUBLE)"),
+        (m.FLOAT, "CAST(t.col AS FLOAT)"),
+        argnames="type_,expected",
+    )
+    @testing.combinations(True, False, argnames="maria_db")
+    def test_float_cast(self, type_, expected, maria_db):
+
+        dialect = mysql.dialect()
+        if maria_db:
+            dialect.server_version_info = (10, 4, 5, "MariaDB")
+        else:
+            dialect.server_version_info = (8, 0, 17)
+        t = sql.table("t", sql.column("col"))
+        self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect)
+
     def test_no_cast_pre_4(self):
         self.assert_compile(
             cast(Column("foo", Integer), String), "CAST(foo AS CHAR)"
@@ -769,7 +792,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             )
 
     def test_cast_grouped_expression_non_castable(self):
-        with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+        with expect_warnings(
+            "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+        ):
             self.assert_compile(
                 cast(sql.column("x") + sql.column("y"), Float), "(x + y)"
             )