]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use _generate_columns_plus_names for ddl returning c populate
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 26 Aug 2025 18:47:34 +0000 (14:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 26 Aug 2025 22:13:51 +0000 (18:13 -0400)
Improved the implementation of :meth:`.UpdateBase.returning` to use more
robust logic in setting up the ``.c`` collection of a derived statement
such as a CTE.  This fixes issues related to RETURNING clauses that feature
expressions based on returned columns with or without qualifying labels.

Co-authored-by: Juhyeong Ko <dury.ko@gmail.com>
Fixes: #12271
Change-Id: Id0d486d4304002f1affdec2e7662ac2965936f2a
(cherry picked from commit 4c4011b50bf8f2f6acca86b11ae3d900b30034a0)

doc/build/changelog/unreleased_20/12271.rst [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/selectable.py
test/sql/test_returning.py

diff --git a/doc/build/changelog/unreleased_20/12271.rst b/doc/build/changelog/unreleased_20/12271.rst
new file mode 100644 (file)
index 0000000..1cc53cf
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 12271
+
+    Improved the implementation of :meth:`.UpdateBase.returning` to use more
+    robust logic in setting up the ``.c`` collection of a derived statement
+    such as a CTE.  This fixes issues related to RETURNING clauses that feature
+    expressions based on returned columns with or without qualifying labels.
index f5071146be2464e2def98f73f4372c8a35ce2999..51da9fa33b80a2572fcff076b818f7e10bc7bdd8 100644 (file)
@@ -426,13 +426,26 @@ class UpdateBase(
         primary_key: ColumnSet,
         foreign_keys: Set[KeyedColumnElement[Any]],
     ) -> None:
-        columns._populate_separate_keys(
-            col._make_proxy(
-                fromclause, primary_key=primary_key, foreign_keys=foreign_keys
+        prox = [
+            c._make_proxy(
+                fromclause,
+                key=proxy_key,
+                name=required_label_name,
+                name_is_truncatable=True,
+                primary_key=primary_key,
+                foreign_keys=foreign_keys,
             )
-            for col in self._all_selected_columns
-            if is_column_element(col)
-        )
+            for (
+                required_label_name,
+                proxy_key,
+                fallback_label_name,
+                c,
+                repeated,
+            ) in (self._generate_columns_plus_names(False))
+            if is_column_element(c)
+        ]
+
+        columns._populate_separate_keys(prox)
 
     def params(self, *arg: Any, **kw: Any) -> NoReturn:
         """Set the parameters for the statement.
index 6c2c7ca1705725f99d4ec3c6993944a93a0312c7..9e7c647d93719a1d60abc26f7b0e700d4e549dd9 100644 (file)
@@ -2353,10 +2353,11 @@ class SelectsRows(ReturnsRows):
         cols: Optional[_SelectIterable] = None,
     ) -> List[_ColumnsPlusNames]:
         """Generate column names as rendered in a SELECT statement by
-        the compiler.
+        the compiler, as well as tokens used to populate the .c. collection
+        on a :class:`.FromClause`.
 
         This is distinct from the _column_naming_convention generator that's
-        intended for population of .c collections and similar, which has
+        intended for population of the Select.selected_columns collection,
         different rules.   the collection returned here calls upon the
         _column_naming_convention as well.
 
index 6cccd01d4a9c9e12c9b77b5bb9ff389d8494c80f..5d5a21dacd65efb07c71ec281a124d2e62a3c09f 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_not_none
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import provision
@@ -234,6 +235,79 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
             "RETURNING t.x, t.y, t.z) SELECT c.z FROM c",
         )
 
+    def test_dml_returning_c_labels_one(self):
+        """tests for #12271"""
+
+        tbl = table("tbl", column("id"))
+
+        stmt = (
+            update(tbl)
+            .values(id=20)
+            .returning(tbl.c.id, (tbl.c.id * -1).label("smth"))
+            .cte()
+        )
+
+        self.assert_compile(
+            select(stmt.c.id, stmt.c.smth),
+            "WITH anon_1 AS (UPDATE tbl SET id=:param_1 "
+            "RETURNING tbl.id, tbl.id * :id_1 AS smth) "
+            "SELECT anon_1.id, anon_1.smth FROM anon_1",
+            dialect="default",
+        )
+
+    def test_dml_returning_c_labels_two(self):
+        """tests for #12271"""
+
+        tbl = table("tbl", column("id"))
+
+        stmt = insert(tbl).returning(tbl.c.id, (tbl.c.id * -1)).cte()
+
+        self.assert_compile(
+            select(stmt.c.id),
+            "WITH anon_1 AS (INSERT INTO tbl (id) VALUES (:id) "
+            "RETURNING tbl.id, tbl.id * :id_1 AS anon_2) "
+            "SELECT anon_1.id FROM anon_1",
+            dialect="default",
+        )
+
+    def test_dml_returning_c_labels_three(self, table_fixture):
+        """tests for #12271"""
+        t = table_fixture
+
+        stmt = (
+            delete(t)
+            .returning(t.c.id, (t.c.id * -1).label("negative_id"))
+            .cte()
+        )
+
+        eq_(list(stmt.c.keys()), ["id", "negative_id"])
+        eq_(stmt.c.negative_id.name, "negative_id")
+
+    def test_dml_returning_c_labels_four(self, table_fixture):
+        """tests for #12271"""
+        t = table_fixture
+
+        stmt = (
+            delete(t)
+            .returning(
+                t.c.id, t.c.id * -1, t.c.id + 10, t.c.id - 10, -1 * t.c.id
+            )
+            .cte()
+        )
+
+        eq_(len(stmt.c), 5)
+        is_not_none(stmt.c.id)
+        assert all(col is not None for col in stmt.c)
+        self.assert_compile(
+            select(stmt),
+            "WITH anon_1 AS (DELETE FROM foo RETURNING foo.id, "
+            "foo.id * :id_1 AS anon_2, foo.id + :id_2 AS anon_3, "
+            "foo.id - :id_3 AS anon_4, :id_4 * foo.id AS anon_5) "
+            "SELECT anon_1.id, anon_1.anon_2, anon_1.anon_3, anon_1.anon_4, "
+            "anon_1.anon_5 FROM anon_1",
+            dialect="default",
+        )
+
 
 class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
     __requires__ = ("insert_returning",)