From: Mike Bayer Date: Tue, 3 Mar 2020 22:22:30 +0000 (-0500) Subject: Restore crud flags if visiting_cte is set X-Git-Tag: rel_1_3_14~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2a9687ed40427d3adb00e1c7f6944ff4e3539c0b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Restore crud flags if visiting_cte is set Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING could then not be SELECTed from directly, as the internal state of the compiler would try to treat the outer SELECT as a DELETE statement itself and access nonexistent state. Fixes: #5181 Change-Id: Icba76f2148c8344baa1c04bac4ab6c6d24f23072 (cherry picked from commit 7fe400f54632835695f7b98f0c1a54424953dfad) --- diff --git a/doc/build/changelog/unreleased_13/5181.rst b/doc/build/changelog/unreleased_13/5181.rst new file mode 100644 index 0000000000..046dc4f381 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5181.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql, postgresql + :tickets: 5181 + + Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING + could then not be SELECTed from directly, as the internal state of the + compiler would try to treat the outer SELECT as a DELETE statement itself + and access nonexistent state. + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 562cd31ea8..8a1e424f22 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -586,6 +586,7 @@ class SQLCompiler(Compiled): # a map which tracks "truncated" names based on # dialect.label_length or dialect.max_identifier_length self.truncated_names = {} + Compiled.__init__(self, dialect, statement, **kwargs) if ( diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 4e524523e7..3011145836 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -42,8 +42,10 @@ def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): restore_isdelete = compiler.isdelete should_restore = ( - restore_isinsert or restore_isupdate or restore_isdelete - ) or len(compiler.stack) > 1 + (restore_isinsert or restore_isupdate or restore_isdelete) + or len(compiler.stack) > 1 + or "visiting_cte" in kw + ) if local_stmt_type is ISINSERT: compiler.isupdate = False diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index cb6fc54279..5639ec35cb 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -995,6 +995,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "upsert.quantity FROM upsert))", ) + eq_(insert.compile().isinsert, True) + def test_anon_update_cte(self): orders = table("orders", column("region")) stmt = ( @@ -1012,6 +1014,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT anon_1.region FROM anon_1", ) + eq_(stmt.select().compile().isupdate, False) + def test_anon_insert_cte(self): orders = table("orders", column("region")) stmt = ( @@ -1024,6 +1028,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "VALUES (:region) RETURNING orders.region) " "SELECT anon_1.region FROM anon_1", ) + eq_(stmt.select().compile().isinsert, False) def test_pg_example_one(self): products = table("products", column("id"), column("date")) @@ -1050,6 +1055,33 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "INSERT INTO products_log (id, date) " "SELECT moved_rows.id, moved_rows.date FROM moved_rows", ) + eq_(stmt.compile().isinsert, True) + eq_(stmt.compile().isdelete, False) + + def test_pg_example_one_select_only(self): + products = table("products", column("id"), column("date")) + + moved_rows = ( + products.delete() + .where( + and_(products.c.date >= "dateone", products.c.date < "datetwo") + ) + .returning(*products.c) + .cte("moved_rows") + ) + + stmt = moved_rows.select() + + self.assert_compile( + stmt, + "WITH moved_rows AS " + "(DELETE FROM products WHERE products.date >= :date_1 " + "AND products.date < :date_2 " + "RETURNING products.id, products.date) " + "SELECT moved_rows.id, moved_rows.date FROM moved_rows", + ) + + eq_(stmt.compile().isdelete, False) def test_pg_example_two(self): products = table("products", column("id"), column("price")) @@ -1072,6 +1104,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT t.id, t.price " "FROM t", ) + eq_(stmt.compile().isupdate, False) def test_pg_example_three(self): @@ -1132,6 +1165,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT pd.id, pd.price " "FROM pd", ) + eq_(stmt.compile().isinsert, False) def test_update_pulls_from_cte(self): products = table("products", column("id"), column("price")) @@ -1150,6 +1184,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE products SET id=:id, price=:price FROM pd " "WHERE products.price = pd.price", ) + eq_(stmt.compile().isupdate, True) def test_standalone_function(self): a = table("a", column("x"))