From: Mike Bayer Date: Sat, 17 Apr 2021 04:30:29 +0000 (-0400) Subject: pass asfrom correctly in compilers X-Git-Tag: rel_1_4_9~1^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=901f7a2b534e4bbc88d7c6894541223cb0dd968d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pass asfrom correctly in compilers Fixed an argument error in the default and PostgreSQL compilers that would interfere with an UPDATE..FROM or DELETE..FROM..USING statement that was then SELECTed from as a CTE. The incorrect pattern was also fixed in the mysql and sybase dialects. MySQL supports CTEs but not "returning". Fixes: #6303 Change-Id: Ic94805611a5ec443749fb6b1fd8a1326b0d83ef7 --- diff --git a/doc/build/changelog/unreleased_14/6303.rst b/doc/build/changelog/unreleased_14/6303.rst new file mode 100644 index 0000000000..8c7a9f2300 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6303.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, postgresql, sql, regression + :tickets: 6303 + + Fixed an argument error in the default and PostgreSQL compilers that + would interfere with an UPDATE..FROM or DELETE..FROM..USING statement + that was then SELECTed from as a CTE. diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index c5113b054b..d4c70a78e6 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1783,8 +1783,9 @@ class MySQLCompiler(compiler.SQLCompiler): return None def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + kw["asfrom"] = True return ", ".join( - t._compiler_dispatch(self, asfrom=True, **kw) + t._compiler_dispatch(self, **kw) for t in [from_table] + list(extra_froms) ) @@ -1806,8 +1807,9 @@ class MySQLCompiler(compiler.SQLCompiler): self, delete_stmt, from_table, extra_froms, from_hints, **kw ): """Render the DELETE .. USING clause specific to MySQL.""" + kw["asfrom"] = True return "USING " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + t._compiler_dispatch(self, fromhints=from_hints, **kw) for t in [from_table] + extra_froms ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 0e9968031a..47a933479f 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2420,8 +2420,9 @@ class PGCompiler(compiler.SQLCompiler): def update_from_clause( self, update_stmt, from_table, extra_froms, from_hints, **kw ): + kw["asfrom"] = True return "FROM " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + t._compiler_dispatch(self, fromhints=from_hints, **kw) for t in extra_froms ) @@ -2429,8 +2430,9 @@ class PGCompiler(compiler.SQLCompiler): self, delete_stmt, from_table, extra_froms, from_hints, **kw ): """Render the DELETE .. USING clause specific to PostgreSQL.""" + kw["asfrom"] = True return "USING " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + t._compiler_dispatch(self, fromhints=from_hints, **kw) for t in extra_froms ) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 7c10973e62..5b6aefe9e5 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -564,8 +564,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler): self, delete_stmt, from_table, extra_froms, from_hints, **kw ): """Render the DELETE .. FROM clause specific to Sybase.""" + kw["asfrom"] = True return "FROM " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + t._compiler_dispatch(self, fromhints=from_hints, **kw) for t in [from_table] + extra_froms ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 0c701cb523..b8e418d996 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3881,16 +3881,18 @@ class StrSQLCompiler(SQLCompiler): def update_from_clause( self, update_stmt, from_table, extra_froms, from_hints, **kw ): + kw["asfrom"] = True return "FROM " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + t._compiler_dispatch(self, fromhints=from_hints, **kw) for t in extra_froms ) def delete_extra_from_clause( self, update_stmt, from_table, extra_froms, from_hints, **kw ): + kw["asfrom"] = True return ", " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + t._compiler_dispatch(self, fromhints=from_hints, **kw) for t in extra_froms ) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index fc2c9b40d7..7752a9b815 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,3 +1,4 @@ +from sqlalchemy import testing from sqlalchemy.dialects import mssql from sqlalchemy.engine import default from sqlalchemy.exc import CompileError @@ -976,6 +977,71 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): eq_(insert.compile().isinsert, True) + @testing.combinations( + ("default_enhanced",), + ("postgresql",), + ) + def test_select_from_update_cte(self, dialect): + t1 = table("table_1", column("id"), column("val")) + + t2 = table("table_2", column("id"), column("val")) + + upd = ( + t1.update() + .values(val=t2.c.val) + .where(t1.c.id == t2.c.id) + .returning(t1.c.id, t1.c.val) + ) + + cte = upd.cte("update_cte") + + qry = select(cte) + + self.assert_compile( + qry, + "WITH update_cte AS (UPDATE table_1 SET val=table_2.val " + "FROM table_2 WHERE table_1.id = table_2.id " + "RETURNING table_1.id, table_1.val) " + "SELECT update_cte.id, update_cte.val FROM update_cte", + dialect=dialect, + ) + + @testing.combinations( + ("default_enhanced",), + ("postgresql",), + ) + def test_select_from_delete_cte(self, dialect): + t1 = table("table_1", column("id"), column("val")) + + t2 = table("table_2", column("id"), column("val")) + + dlt = ( + t1.delete().where(t1.c.id == t2.c.id).returning(t1.c.id, t1.c.val) + ) + + cte = dlt.cte("delete_cte") + + qry = select(cte) + + if dialect == "postgresql": + self.assert_compile( + qry, + "WITH delete_cte AS (DELETE FROM table_1 USING table_2 " + "WHERE table_1.id = table_2.id RETURNING table_1.id, " + "table_1.val) SELECT delete_cte.id, delete_cte.val " + "FROM delete_cte", + dialect=dialect, + ) + else: + self.assert_compile( + qry, + "WITH delete_cte AS (DELETE FROM table_1 , table_2 " + "WHERE table_1.id = table_2.id " + "RETURNING table_1.id, table_1.val) " + "SELECT delete_cte.id, delete_cte.val FROM delete_cte", + dialect=dialect, + ) + def test_anon_update_cte(self): orders = table("orders", column("region")) stmt = (