]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pass asfrom correctly in compilers
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Apr 2021 04:30:29 +0000 (00:30 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Apr 2021 04:37:47 +0000 (00:37 -0400)
Fixed an argument error in the default and PostgreSQL compilers that
would interfere with an UPDATE..FROM or DELETE..FROM..USING statement
that was then SELECTed from as a CTE.

The incorrect pattern was also fixed in the mysql and sybase dialects.
MySQL supports CTEs but not "returning".

Fixes: #6303
Change-Id: Ic94805611a5ec443749fb6b1fd8a1326b0d83ef7

doc/build/changelog/unreleased_14/6303.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sybase/base.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_14/6303.rst b/doc/build/changelog/unreleased_14/6303.rst
new file mode 100644 (file)
index 0000000..8c7a9f2
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, postgresql, sql, regression
+    :tickets: 6303
+
+    Fixed an argument error in the default and PostgreSQL compilers that
+    would interfere with an UPDATE..FROM or DELETE..FROM..USING statement
+    that was then SELECTed from as a CTE.
index c5113b054b6051c35b476b775463fd28b3a1f3ca..d4c70a78e6b013ad6b7b300abcc464928e92c59f 100644 (file)
@@ -1783,8 +1783,9 @@ class MySQLCompiler(compiler.SQLCompiler):
             return None
 
     def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+        kw["asfrom"] = True
         return ", ".join(
-            t._compiler_dispatch(self, asfrom=True, **kw)
+            t._compiler_dispatch(self, **kw)
             for t in [from_table] + list(extra_froms)
         )
 
@@ -1806,8 +1807,9 @@ class MySQLCompiler(compiler.SQLCompiler):
         self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         """Render the DELETE .. USING clause specific to MySQL."""
+        kw["asfrom"] = True
         return "USING " + ", ".join(
-            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+            t._compiler_dispatch(self, fromhints=from_hints, **kw)
             for t in [from_table] + extra_froms
         )
 
index 0e9968031a5497495a4c6aa334bfa23a75108294..47a933479f781d7e6fc61205d9b50c786ec19654 100644 (file)
@@ -2420,8 +2420,9 @@ class PGCompiler(compiler.SQLCompiler):
     def update_from_clause(
         self, update_stmt, from_table, extra_froms, from_hints, **kw
     ):
+        kw["asfrom"] = True
         return "FROM " + ", ".join(
-            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+            t._compiler_dispatch(self, fromhints=from_hints, **kw)
             for t in extra_froms
         )
 
@@ -2429,8 +2430,9 @@ class PGCompiler(compiler.SQLCompiler):
         self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         """Render the DELETE .. USING clause specific to PostgreSQL."""
+        kw["asfrom"] = True
         return "USING " + ", ".join(
-            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+            t._compiler_dispatch(self, fromhints=from_hints, **kw)
             for t in extra_froms
         )
 
index 7c10973e62c54bb109eb251684ab811f62faadbb..5b6aefe9e5df01fb0fc59e259c7d64ba2a3fe6db 100644 (file)
@@ -564,8 +564,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler):
         self, delete_stmt, from_table, extra_froms, from_hints, **kw
     ):
         """Render the DELETE .. FROM clause specific to Sybase."""
+        kw["asfrom"] = True
         return "FROM " + ", ".join(
-            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+            t._compiler_dispatch(self, fromhints=from_hints, **kw)
             for t in [from_table] + extra_froms
         )
 
index 0c701cb52336195b5dcae41f041915426e933a86..b8e418d996256f1806a1a9dd4e062b754dee7a3f 100644 (file)
@@ -3881,16 +3881,18 @@ class StrSQLCompiler(SQLCompiler):
     def update_from_clause(
         self, update_stmt, from_table, extra_froms, from_hints, **kw
     ):
+        kw["asfrom"] = True
         return "FROM " + ", ".join(
-            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+            t._compiler_dispatch(self, fromhints=from_hints, **kw)
             for t in extra_froms
         )
 
     def delete_extra_from_clause(
         self, update_stmt, from_table, extra_froms, from_hints, **kw
     ):
+        kw["asfrom"] = True
         return ", " + ", ".join(
-            t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+            t._compiler_dispatch(self, fromhints=from_hints, **kw)
             for t in extra_froms
         )
 
index fc2c9b40d755b02c9fb5edd52d72bc9ff8491340..7752a9b815559688797f971a6a999f304cab9daf 100644 (file)
@@ -1,3 +1,4 @@
+from sqlalchemy import testing
 from sqlalchemy.dialects import mssql
 from sqlalchemy.engine import default
 from sqlalchemy.exc import CompileError
@@ -976,6 +977,71 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
         eq_(insert.compile().isinsert, True)
 
+    @testing.combinations(
+        ("default_enhanced",),
+        ("postgresql",),
+    )
+    def test_select_from_update_cte(self, dialect):
+        t1 = table("table_1", column("id"), column("val"))
+
+        t2 = table("table_2", column("id"), column("val"))
+
+        upd = (
+            t1.update()
+            .values(val=t2.c.val)
+            .where(t1.c.id == t2.c.id)
+            .returning(t1.c.id, t1.c.val)
+        )
+
+        cte = upd.cte("update_cte")
+
+        qry = select(cte)
+
+        self.assert_compile(
+            qry,
+            "WITH update_cte AS (UPDATE table_1 SET val=table_2.val "
+            "FROM table_2 WHERE table_1.id = table_2.id "
+            "RETURNING table_1.id, table_1.val) "
+            "SELECT update_cte.id, update_cte.val FROM update_cte",
+            dialect=dialect,
+        )
+
+    @testing.combinations(
+        ("default_enhanced",),
+        ("postgresql",),
+    )
+    def test_select_from_delete_cte(self, dialect):
+        t1 = table("table_1", column("id"), column("val"))
+
+        t2 = table("table_2", column("id"), column("val"))
+
+        dlt = (
+            t1.delete().where(t1.c.id == t2.c.id).returning(t1.c.id, t1.c.val)
+        )
+
+        cte = dlt.cte("delete_cte")
+
+        qry = select(cte)
+
+        if dialect == "postgresql":
+            self.assert_compile(
+                qry,
+                "WITH delete_cte AS (DELETE FROM table_1 USING table_2 "
+                "WHERE table_1.id = table_2.id RETURNING table_1.id, "
+                "table_1.val) SELECT delete_cte.id, delete_cte.val "
+                "FROM delete_cte",
+                dialect=dialect,
+            )
+        else:
+            self.assert_compile(
+                qry,
+                "WITH delete_cte AS (DELETE FROM table_1 , table_2 "
+                "WHERE table_1.id = table_2.id "
+                "RETURNING table_1.id, table_1.val) "
+                "SELECT delete_cte.id, delete_cte.val FROM delete_cte",
+                dialect=dialect,
+            )
+
     def test_anon_update_cte(self):
         orders = table("orders", column("region"))
         stmt = (