From: Federico Caselli Date: Wed, 30 Dec 2020 19:36:27 +0000 (+0100) Subject: Support casting to ``FLOAT`` in MySQL and MariaDb. X-Git-Tag: rel_1_4_0b2~73^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=95d8f401839bdbe1399fb7d656c11024072f32b0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support casting to ``FLOAT`` in MySQL and MariaDb. Fixes: #5808 Change-Id: I8106ddcf681eec3cb3a67d853586702f6e844b9d --- diff --git a/doc/build/changelog/unreleased_13/5808.rst b/doc/build/changelog/unreleased_13/5808.rst new file mode 100644 index 0000000000..b6625c0507 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5808.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, mysql + :tickets: 5808 + + Casting to ``FLOAT`` is now supported in MySQL >= (8, 0, 17) and + MariaDb >= (10, 4, 5). \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7a4d3261f9..a4b5835413 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1624,6 +1624,11 @@ class MySQLCompiler(compiler.SQLCompiler): return self.dialect.type_compiler.process(type_).replace( "NUMERIC", "DECIMAL" ) + elif ( + isinstance(type_, sqltypes.Float) + and self.dialect._support_float_cast + ): + return self.dialect.type_compiler.process(type_) else: return None @@ -1631,7 +1636,7 @@ class MySQLCompiler(compiler.SQLCompiler): type_ = self.process(cast.typeclause) if type_ is None: util.warn( - "Datatype %s does not support CAST on MySQL; " + "Datatype %s does not support CAST on MySQL/MariaDb; " "the CAST will be skipped." % self.dialect.type_compiler.process(cast.typeclause.type) ) @@ -2899,6 +2904,17 @@ class MySQLDialect(default.DefaultDialect): "series, to avoid these issues." % (mdb_version,) ) + @property + def _support_float_cast(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + return self.server_version_info >= (10, 4, 5) + else: + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + return self.server_version_info >= (8, 0, 17) + @property def _is_mariadb(self): return self.is_mariadb diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 2993f96b8b..62292b9daa 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -710,7 +710,9 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_unsupported_cast_literal_bind(self): expr = cast(column("foo", Integer) + 5, Float) - with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"): + with expect_warnings( + "Datatype FLOAT does not support CAST on MySQL/MariaDb;" + ): self.assert_compile(expr, "(foo + 5)", literal_binds=True) m = mysql @@ -734,11 +736,35 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): def test_unsupported_casts(self, type_, expected): t = sql.table("t", sql.column("col")) - with expect_warnings("Datatype .* does not support CAST on MySQL;"): + with expect_warnings( + "Datatype .* does not support CAST on MySQL/MariaDb;" + ): self.assert_compile(cast(t.c.col, type_), expected) + @testing.combinations( + (m.FLOAT, "CAST(t.col AS FLOAT)"), + (Float, "CAST(t.col AS FLOAT)"), + (FLOAT, "CAST(t.col AS FLOAT)"), + (m.DOUBLE, "CAST(t.col AS DOUBLE)"), + (m.FLOAT, "CAST(t.col AS FLOAT)"), + argnames="type_,expected", + ) + @testing.combinations(True, False, argnames="maria_db") + def test_float_cast(self, type_, expected, maria_db): + + dialect = mysql.dialect() + if maria_db: + dialect.is_mariadb = maria_db + dialect.server_version_info = (10, 4, 5) + else: + dialect.server_version_info = (8, 0, 17) + t = sql.table("t", sql.column("col")) + self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect) + def test_cast_grouped_expression_non_castable(self): - with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"): + with expect_warnings( + "Datatype FLOAT does not support CAST on MySQL/MariaDb;" + ): self.assert_compile( cast(sql.column("x") + sql.column("y"), Float), "(x + y)" )