From 7fe400f54632835695f7b98f0c1a54424953dfad Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 3 Mar 2020 17:22:30 -0500 Subject: [PATCH] Restore crud flags if visiting_cte is set Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING could then not be SELECTed from directly, as the internal state of the compiler would try to treat the outer SELECT as a DELETE statement itself and access nonexistent state. Fixes: #5181 Change-Id: Icba76f2148c8344baa1c04bac4ab6c6d24f23072 --- doc/build/changelog/unreleased_13/5181.rst | 9 ++++++ lib/sqlalchemy/sql/compiler.py | 1 + lib/sqlalchemy/sql/crud.py | 6 ++-- test/sql/test_cte.py | 35 ++++++++++++++++++++++ 4 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 doc/build/changelog/unreleased_13/5181.rst diff --git a/doc/build/changelog/unreleased_13/5181.rst b/doc/build/changelog/unreleased_13/5181.rst new file mode 100644 index 0000000000..046dc4f381 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5181.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql, postgresql + :tickets: 5181 + + Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING + could then not be SELECTed from directly, as the internal state of the + compiler would try to treat the outer SELECT as a DELETE statement itself + and access nonexistent state. + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d31cf67f88..424282951a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -723,6 +723,7 @@ class SQLCompiler(Compiled): # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} + Compiled.__init__(self, dialect, statement, **kwargs) if ( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 433a5fdfab..e474952ce4 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -44,8 +44,10 @@ def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): restore_isdelete = compiler.isdelete should_restore = ( - restore_isinsert or restore_isupdate or restore_isdelete - ) or len(compiler.stack) > 1 + (restore_isinsert or restore_isupdate or restore_isdelete) + or len(compiler.stack) > 1 + or "visiting_cte" in kw + ) if local_stmt_type is ISINSERT: compiler.isupdate = False diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 4a7a80e770..c9178d5801 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -999,6 +999,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "upsert.quantity FROM upsert))", ) + eq_(insert.compile().isinsert, True) + def test_anon_update_cte(self): orders = table("orders", column("region")) stmt = ( @@ -1016,6 +1018,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.region FROM anon_1", ) + eq_(stmt.select().compile().isupdate, False) + def test_anon_insert_cte(self): orders = table("orders", column("region")) stmt = ( @@ -1028,6 +1032,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "VALUES (:region) RETURNING orders.region) " "SELECT anon_1.region FROM anon_1", ) + eq_(stmt.select().compile().isinsert, False) def test_pg_example_one(self): products = table("products", column("id"), column("date")) @@ -1054,6 +1059,33 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "INSERT INTO products_log (id, date) " "SELECT moved_rows.id, moved_rows.date FROM moved_rows", ) + eq_(stmt.compile().isinsert, True) + eq_(stmt.compile().isdelete, False) + + def test_pg_example_one_select_only(self): + products = table("products", column("id"), column("date")) + + moved_rows = ( + products.delete() + .where( + and_(products.c.date >= "dateone", products.c.date < "datetwo") + ) + .returning(*products.c) + .cte("moved_rows") + ) + + stmt = moved_rows.select() + + self.assert_compile( + stmt, + "WITH moved_rows AS " + "(DELETE FROM products WHERE products.date >= :date_1 " + "AND products.date < :date_2 " + "RETURNING products.id, products.date) " + "SELECT moved_rows.id, moved_rows.date FROM moved_rows", + ) + + eq_(stmt.compile().isdelete, False) def test_pg_example_two(self): products = table("products", column("id"), column("price")) @@ -1076,6 +1108,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT t.id, t.price " "FROM t", ) + eq_(stmt.compile().isupdate, False) def test_pg_example_three(self): @@ -1136,6 +1169,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT pd.id, pd.price " "FROM pd", ) + eq_(stmt.compile().isinsert, False) def test_update_pulls_from_cte(self): products = table("products", column("id"), column("price")) @@ -1154,6 +1188,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE products SET id=:id, price=:price FROM pd " "WHERE products.price = pd.price", ) + eq_(stmt.compile().isupdate, True) def test_standalone_function(self): a = table("a", column("x")) -- 2.39.5