From 12e15ca0c74da1f69a868b1db971d7cad7128c64 Mon Sep 17 00:00:00 2001 From: Gord Thompson Date: Sat, 24 Jun 2023 16:05:42 -0600 Subject: [PATCH] Fix SQL syntax for CAST with explicit collation Fixes: #9932 Change-Id: I557e00cfc0725e2f247103dea484a7e818592f7f --- doc/build/changelog/unreleased_20/9932.rst | 7 ++++ lib/sqlalchemy/sql/compiler.py | 7 ++-- test/sql/test_compiler.py | 42 ++++++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9932.rst diff --git a/doc/build/changelog/unreleased_20/9932.rst b/doc/build/changelog/unreleased_20/9932.rst new file mode 100644 index 0000000000..71f395c651 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9932.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, mssql, sql + :tickets: 9932 + + Fixed issue where performing :class:`.Cast` to a string type with an + explicit collation would render the COLLATE clause inside the CAST + function, which resulted in a syntax error. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 79092ec661..185d71a70e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2768,9 +2768,12 @@ class SQLCompiler(Compiled): return type_coerce.typed_expression._compiler_dispatch(self, **kw) def visit_cast(self, cast, **kwargs): - return "CAST(%s AS %s)" % ( + type_clause = cast.typeclause._compiler_dispatch(self, **kwargs) + match = re.match("(.*)( COLLATE .*)", type_clause) + return "CAST(%s AS %s)%s" % ( cast.clause._compiler_dispatch(self, **kwargs), - cast.typeclause._compiler_dispatch(self, **kwargs), + match.group(1) if match else type_clause, + match.group(2) if match else "", ) def _format_frame_clause(self, range_, **kw): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 63ccb961f1..655897afbb 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2969,6 +2969,48 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): dialect=sqlite.dialect(), ) + @testing.combinations( + ( + "default", + None, + "SELECT CAST(t1.txt AS VARCHAR(10)) AS txt FROM t1", + None, + ), + ( + "explicit_mssql", + "Latin1_General_CI_AS", + "SELECT CAST(t1.txt AS VARCHAR(10)) COLLATE Latin1_General_CI_AS AS txt FROM t1", # noqa + mssql.dialect(), + ), + ( + "explicit_mysql", + "utf8mb4_unicode_ci", + "SELECT CAST(t1.txt AS CHAR(10)) AS txt FROM t1", + mysql.dialect(), + ), + ( + "explicit_postgresql", + "en_US", + 'SELECT CAST(t1.txt AS VARCHAR(10)) COLLATE "en_US" AS txt FROM t1', # noqa + postgresql.dialect(), + ), + ( + "explicit_sqlite", + "NOCASE", + 'SELECT CAST(t1.txt AS VARCHAR(10)) COLLATE "NOCASE" AS txt FROM t1', # noqa + sqlite.dialect(), + ), + id_="iaaa", + ) + def test_cast_with_collate(self, collation_name, expected_sql, dialect): + t1 = Table( + "t1", + MetaData(), + Column("txt", String(10, collation=collation_name)), + ) + stmt = select(func.cast(t1.c.txt, t1.c.txt.type)) + self.assert_compile(stmt, expected_sql, dialect=dialect) + def test_over(self): self.assert_compile(func.row_number().over(), "row_number() OVER ()") self.assert_compile( -- 2.39.5