]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Restore crud flags if visiting_cte is set
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Mar 2020 22:22:30 +0000 (17:22 -0500)
committermike bayer <mike_mp@zzzcomputing.com>
Wed, 4 Mar 2020 02:22:54 +0000 (02:22 +0000)
Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING
could then not be SELECTed from directly, as the internal state of the
compiler would try to treat the outer SELECT as a DELETE statement itself
and access nonexistent state.

Fixes: #5181
Change-Id: Icba76f2148c8344baa1c04bac4ab6c6d24f23072
(cherry picked from commit 7fe400f54632835695f7b98f0c1a54424953dfad)

doc/build/changelog/unreleased_13/5181.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_13/5181.rst b/doc/build/changelog/unreleased_13/5181.rst
new file mode 100644 (file)
index 0000000..046dc4f
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql, postgresql
+    :tickets: 5181
+
+    Fixed bug where a CTE of an INSERT/UPDATE/DELETE that also uses RETURNING
+    could then not be SELECTed from directly, as the internal state of the
+    compiler would try to treat the outer SELECT as a DELETE statement itself
+    and access nonexistent state.
+
index 562cd31ea8ca7d89903b1c01895a21c5f5ba6668..8a1e424f22530b9037b78a98e195fdac2783b5c0 100644 (file)
@@ -586,6 +586,7 @@ class SQLCompiler(Compiled):
         # a map which tracks "truncated" names based on
         # dialect.label_length or dialect.max_identifier_length
         self.truncated_names = {}
+
         Compiled.__init__(self, dialect, statement, **kwargs)
 
         if (
index 4e524523e78dae55ff37f4a73a1e8439fa37abe8..30111458362e8a96e0f90dd80cf1d9d15b7dcff0 100644 (file)
@@ -42,8 +42,10 @@ def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
     restore_isdelete = compiler.isdelete
 
     should_restore = (
-        restore_isinsert or restore_isupdate or restore_isdelete
-    ) or len(compiler.stack) > 1
+        (restore_isinsert or restore_isupdate or restore_isdelete)
+        or len(compiler.stack) > 1
+        or "visiting_cte" in kw
+    )
 
     if local_stmt_type is ISINSERT:
         compiler.isupdate = False
index cb6fc54279152b5be8e825050eb4ab478381a79e..5639ec35cbb96fbe8b4bae80151d5ad45e0e004b 100644 (file)
@@ -995,6 +995,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "upsert.quantity FROM upsert))",
         )
 
+        eq_(insert.compile().isinsert, True)
+
     def test_anon_update_cte(self):
         orders = table("orders", column("region"))
         stmt = (
@@ -1012,6 +1014,8 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT anon_1.region FROM anon_1",
         )
 
+        eq_(stmt.select().compile().isupdate, False)
+
     def test_anon_insert_cte(self):
         orders = table("orders", column("region"))
         stmt = (
@@ -1024,6 +1028,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "VALUES (:region) RETURNING orders.region) "
             "SELECT anon_1.region FROM anon_1",
         )
+        eq_(stmt.select().compile().isinsert, False)
 
     def test_pg_example_one(self):
         products = table("products", column("id"), column("date"))
@@ -1050,6 +1055,33 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "INSERT INTO products_log (id, date) "
             "SELECT moved_rows.id, moved_rows.date FROM moved_rows",
         )
+        eq_(stmt.compile().isinsert, True)
+        eq_(stmt.compile().isdelete, False)
+
+    def test_pg_example_one_select_only(self):
+        products = table("products", column("id"), column("date"))
+
+        moved_rows = (
+            products.delete()
+            .where(
+                and_(products.c.date >= "dateone", products.c.date < "datetwo")
+            )
+            .returning(*products.c)
+            .cte("moved_rows")
+        )
+
+        stmt = moved_rows.select()
+
+        self.assert_compile(
+            stmt,
+            "WITH moved_rows AS "
+            "(DELETE FROM products WHERE products.date >= :date_1 "
+            "AND products.date < :date_2 "
+            "RETURNING products.id, products.date) "
+            "SELECT moved_rows.id, moved_rows.date FROM moved_rows",
+        )
+
+        eq_(stmt.compile().isdelete, False)
 
     def test_pg_example_two(self):
         products = table("products", column("id"), column("price"))
@@ -1072,6 +1104,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT t.id, t.price "
             "FROM t",
         )
+        eq_(stmt.compile().isupdate, False)
 
     def test_pg_example_three(self):
 
@@ -1132,6 +1165,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT pd.id, pd.price "
             "FROM pd",
         )
+        eq_(stmt.compile().isinsert, False)
 
     def test_update_pulls_from_cte(self):
         products = table("products", column("id"), column("price"))
@@ -1150,6 +1184,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
             "UPDATE products SET id=:id, price=:price FROM pd "
             "WHERE products.price = pd.price",
         )
+        eq_(stmt.compile().isupdate, True)
 
     def test_standalone_function(self):
         a = table("a", column("x"))