]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix SQL syntax for CAST with explicit collation
authorGord Thompson <gord@gordthompson.com>
Sat, 24 Jun 2023 22:05:42 +0000 (16:05 -0600)
committerGord Thompson <gord@gordthompson.com>
Tue, 27 Jun 2023 13:49:12 +0000 (07:49 -0600)
Fixes: #9932
Change-Id: I557e00cfc0725e2f247103dea484a7e818592f7f

doc/build/changelog/unreleased_20/9932.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_20/9932.rst b/doc/build/changelog/unreleased_20/9932.rst
new file mode 100644 (file)
index 0000000..71f395c
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, mssql, sql
+    :tickets: 9932
+
+    Fixed issue where performing :class:`.Cast` to a string type with an
+    explicit collation would render the COLLATE clause inside the CAST
+    function, which resulted in a syntax error.
index 79092ec6619507fbffde4543c709477d006eec29..185d71a70ec35e0f60e04470584b3ae0e3e2c8f8 100644 (file)
@@ -2768,9 +2768,12 @@ class SQLCompiler(Compiled):
         return type_coerce.typed_expression._compiler_dispatch(self, **kw)
 
     def visit_cast(self, cast, **kwargs):
-        return "CAST(%s AS %s)" % (
+        type_clause = cast.typeclause._compiler_dispatch(self, **kwargs)
+        match = re.match("(.*)( COLLATE .*)", type_clause)
+        return "CAST(%s AS %s)%s" % (
             cast.clause._compiler_dispatch(self, **kwargs),
-            cast.typeclause._compiler_dispatch(self, **kwargs),
+            match.group(1) if match else type_clause,
+            match.group(2) if match else "",
         )
 
     def _format_frame_clause(self, range_, **kw):
index 63ccb961f12b4ddf1bcacceaa0fe434adbfee276..655897afbb0df0cdb9a1093dc18d38ce6cd274ba 100644 (file)
@@ -2969,6 +2969,48 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=sqlite.dialect(),
         )
 
+    @testing.combinations(
+        (
+            "default",
+            None,
+            "SELECT CAST(t1.txt AS VARCHAR(10)) AS txt FROM t1",
+            None,
+        ),
+        (
+            "explicit_mssql",
+            "Latin1_General_CI_AS",
+            "SELECT CAST(t1.txt AS VARCHAR(10)) COLLATE Latin1_General_CI_AS AS txt FROM t1",  # noqa
+            mssql.dialect(),
+        ),
+        (
+            "explicit_mysql",
+            "utf8mb4_unicode_ci",
+            "SELECT CAST(t1.txt AS CHAR(10)) AS txt FROM t1",
+            mysql.dialect(),
+        ),
+        (
+            "explicit_postgresql",
+            "en_US",
+            'SELECT CAST(t1.txt AS VARCHAR(10)) COLLATE "en_US" AS txt FROM t1',  # noqa
+            postgresql.dialect(),
+        ),
+        (
+            "explicit_sqlite",
+            "NOCASE",
+            'SELECT CAST(t1.txt AS VARCHAR(10)) COLLATE "NOCASE" AS txt FROM t1',  # noqa
+            sqlite.dialect(),
+        ),
+        id_="iaaa",
+    )
+    def test_cast_with_collate(self, collation_name, expected_sql, dialect):
+        t1 = Table(
+            "t1",
+            MetaData(),
+            Column("txt", String(10, collation=collation_name)),
+        )
+        stmt = select(func.cast(t1.c.txt, t1.c.txt.type))
+        self.assert_compile(stmt, expected_sql, dialect=dialect)
+
     def test_over(self):
         self.assert_compile(func.row_number().over(), "row_number() OVER ()")
         self.assert_compile(