return self.dialect.type_compiler.process(type_).replace(
"NUMERIC", "DECIMAL"
)
+ elif (
+ isinstance(type_, sqltypes.Float)
+ and self.dialect._support_float_cast
+ ):
+ return self.dialect.type_compiler.process(type_)
else:
return None
type_ = self.process(cast.typeclause)
if type_ is None:
util.warn(
- "Datatype %s does not support CAST on MySQL; "
+ "Datatype %s does not support CAST on MySQL/MariaDb; "
"the CAST will be skipped."
% self.dialect.type_compiler.process(cast.typeclause.type)
)
"series, to avoid these issues." % (mdb_version,)
)
+ @property
+ def _support_float_cast(self):
+ if not self.server_version_info:
+ return False
+ elif self.is_mariadb:
+ # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+ return self.server_version_info >= (10, 4, 5)
+ else:
+ # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa
+ return self.server_version_info >= (8, 0, 17)
+
@property
def _is_mariadb(self):
return self.is_mariadb
def test_unsupported_cast_literal_bind(self):
expr = cast(column("foo", Integer) + 5, Float)
- with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+ with expect_warnings(
+ "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+ ):
self.assert_compile(expr, "(foo + 5)", literal_binds=True)
m = mysql
def test_unsupported_casts(self, type_, expected):
t = sql.table("t", sql.column("col"))
- with expect_warnings("Datatype .* does not support CAST on MySQL;"):
+ with expect_warnings(
+ "Datatype .* does not support CAST on MySQL/MariaDb;"
+ ):
self.assert_compile(cast(t.c.col, type_), expected)
+ @testing.combinations(
+ (m.FLOAT, "CAST(t.col AS FLOAT)"),
+ (Float, "CAST(t.col AS FLOAT)"),
+ (FLOAT, "CAST(t.col AS FLOAT)"),
+ (m.DOUBLE, "CAST(t.col AS DOUBLE)"),
+ (m.FLOAT, "CAST(t.col AS FLOAT)"),
+ argnames="type_,expected",
+ )
+ @testing.combinations(True, False, argnames="maria_db")
+ def test_float_cast(self, type_, expected, maria_db):
+
+ dialect = mysql.dialect()
+ if maria_db:
+ dialect.is_mariadb = maria_db
+ dialect.server_version_info = (10, 4, 5)
+ else:
+ dialect.server_version_info = (8, 0, 17)
+ t = sql.table("t", sql.column("col"))
+ self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect)
+
def test_cast_grouped_expression_non_castable(self):
- with expect_warnings("Datatype FLOAT does not support CAST on MySQL;"):
+ with expect_warnings(
+ "Datatype FLOAT does not support CAST on MySQL/MariaDb;"
+ ):
self.assert_compile(
cast(sql.column("x") + sql.column("y"), Float), "(x + y)"
)