]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
consider indpendent CTE for UPDATE..FROM
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Oct 2023 12:40:06 +0000 (08:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 3 Oct 2023 17:29:01 +0000 (13:29 -0400)
Fixed issue where referring to a FROM entry in the SET clause of an UPDATE
statement would not include it in the FROM clause of the UPDATE statement,
if that entry were nowhere else in the statement; this occurs currently for
CTEs that were added using :meth:`.Update.add_cte` to provide the desired
CTE at the top of the statement.

Fixes: #10408
Change-Id: I6e3c6ca7a00cc884bda7e0f24c62c34c75134e5b

doc/build/changelog/unreleased_20/10408.rst [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
test/sql/test_cte.py

diff --git a/doc/build/changelog/unreleased_20/10408.rst b/doc/build/changelog/unreleased_20/10408.rst
new file mode 100644 (file)
index 0000000..e2fff25
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 10408
+
+    Fixed issue where referring to a FROM entry in the SET clause of an UPDATE
+    statement would not include it in the FROM clause of the UPDATE statement,
+    if that entry were nowhere else in the statement; this occurs currently for
+    CTEs that were added using :meth:`.Update.add_cte` to provide the desired
+    CTE at the top of the statement.
index 4047ba4133f0ccafe14e0b5d287d9fe8a4faac89..921aed3f9e73b92a8d1f8b9fb53d2f9a11301f27 100644 (file)
@@ -211,7 +211,11 @@ class DMLState(CompileState):
         primary_table = all_tables[0]
         seen = {primary_table}
 
-        for crit in statement._where_criteria:
+        consider = statement._where_criteria
+        if self._dict_parameters:
+            consider += tuple(self._dict_parameters.values())
+
+        for crit in consider:
             for item in _from_objects(crit):
                 if not seen.intersection(item._cloned_set):
                     froms.append(item)
index 64e8732b7885d9d4110eaaba4314a74bcdc830dd..d044212aa60230069cee8782f6f835065d38cf7c 100644 (file)
@@ -18,6 +18,7 @@ from sqlalchemy.sql import column
 from sqlalchemy.sql import cte
 from sqlalchemy.sql import exists
 from sqlalchemy.sql import func
+from sqlalchemy.sql import insert
 from sqlalchemy.sql import literal
 from sqlalchemy.sql import select
 from sqlalchemy.sql import table
@@ -1418,6 +1419,31 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                 dialect=dialect,
             )
 
+    def test_insert_update_w_add_cte(self):
+        """test #10408"""
+        a = table(
+            "a", column("id"), column("x"), column("y"), column("next_id")
+        )
+
+        insert_a_cte = (insert(a).values(x=10, y=15).returning(a.c.id)).cte(
+            "insert_a_cte"
+        )
+
+        update_query = (
+            update(a)
+            .values(next_id=insert_a_cte.c.id)
+            .where(a.c.id == 10)
+            .add_cte(insert_a_cte)
+        )
+
+        self.assert_compile(
+            update_query,
+            "WITH insert_a_cte AS (INSERT INTO a (x, y) "
+            "VALUES (:param_1, :param_2) RETURNING a.id) "
+            "UPDATE a SET next_id=insert_a_cte.id "
+            "FROM insert_a_cte WHERE a.id = :id_1",
+        )
+
     def test_anon_update_cte(self):
         orders = table("orders", column("region"))
         stmt = (