From: Mike Bayer Date: Tue, 3 Oct 2023 12:40:06 +0000 (-0400) Subject: consider indpendent CTE for UPDATE..FROM X-Git-Tag: rel_2_0_22~18^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a2804c621961231b0cc3d9b4f2893417d460e447;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git consider indpendent CTE for UPDATE..FROM Fixed issue where referring to a FROM entry in the SET clause of an UPDATE statement would not include it in the FROM clause of the UPDATE statement, if that entry were nowhere else in the statement; this occurs currently for CTEs that were added using :meth:`.Update.add_cte` to provide the desired CTE at the top of the statement. Fixes: #10408 Change-Id: I6e3c6ca7a00cc884bda7e0f24c62c34c75134e5b --- diff --git a/doc/build/changelog/unreleased_20/10408.rst b/doc/build/changelog/unreleased_20/10408.rst new file mode 100644 index 0000000000..e2fff25817 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10408.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 10408 + + Fixed issue where referring to a FROM entry in the SET clause of an UPDATE + statement would not include it in the FROM clause of the UPDATE statement, + if that entry were nowhere else in the statement; this occurs currently for + CTEs that were added using :meth:`.Update.add_cte` to provide the desired + CTE at the top of the statement. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 4047ba4133..921aed3f9e 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -211,7 +211,11 @@ class DMLState(CompileState): primary_table = all_tables[0] seen = {primary_table} - for crit in statement._where_criteria: + consider = statement._where_criteria + if self._dict_parameters: + consider += tuple(self._dict_parameters.values()) + + for crit in consider: for item in _from_objects(crit): if not seen.intersection(item._cloned_set): froms.append(item) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 64e8732b78..d044212aa6 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -18,6 +18,7 @@ from sqlalchemy.sql import column from sqlalchemy.sql import cte from sqlalchemy.sql import exists from sqlalchemy.sql import func +from sqlalchemy.sql import insert from sqlalchemy.sql import literal from sqlalchemy.sql import select from sqlalchemy.sql import table @@ -1418,6 +1419,31 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): dialect=dialect, ) + def test_insert_update_w_add_cte(self): + """test #10408""" + a = table( + "a", column("id"), column("x"), column("y"), column("next_id") + ) + + insert_a_cte = (insert(a).values(x=10, y=15).returning(a.c.id)).cte( + "insert_a_cte" + ) + + update_query = ( + update(a) + .values(next_id=insert_a_cte.c.id) + .where(a.c.id == 10) + .add_cte(insert_a_cte) + ) + + self.assert_compile( + update_query, + "WITH insert_a_cte AS (INSERT INTO a (x, y) " + "VALUES (:param_1, :param_2) RETURNING a.id) " + "UPDATE a SET next_id=insert_a_cte.id " + "FROM insert_a_cte WHERE a.id = :id_1", + ) + def test_anon_update_cte(self): orders = table("orders", column("region")) stmt = (