]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement RETURNING * for ORM DML
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Aug 2023 14:47:11 +0000 (10:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Aug 2023 15:56:58 +0000 (11:56 -0400)
Implemented the "RETURNING '*'" use case for ORM enabled DML statements.
This will render in as many cases as possible and return the unfiltered
result set, however is not supported for multi-parameter "ORM bulk INSERT"
statements that have specific column rendering requirements.

Fixes: #10192
Change-Id: I04297d08eacb9ad1d5fd6d9dd21afefb8e9dc0b1

doc/build/changelog/unreleased_20/10192.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
test/orm/dml/test_bulk_statements.py
test/orm/dml/test_update_delete_where.py
test/orm/test_core_compilation.py
test/orm/test_loading.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/10192.rst b/doc/build/changelog/unreleased_20/10192.rst
new file mode 100644 (file)
index 0000000..1d59861
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 10192
+
+    Implemented the "RETURNING '*'" use case for ORM enabled DML statements.
+    This will render in as many cases as possible and return the unfiltered
+    result set, however is not supported for multi-parameter "ORM bulk INSERT"
+    statements that have specific column rendering requirements.
+
index 063a2cbe481df8cec6ddace8ebf33ddf0b72ed7c..d38dfa9ce19fcd2350d95f950575fecf4a580758 100644 (file)
@@ -578,7 +578,10 @@ class ORMDMLState(AbstractORMCompileState):
         execution_context = result.context
         compile_state = execution_context.compiled.compile_state
 
-        if compile_state.from_statement_ctx:
+        if (
+            compile_state.from_statement_ctx
+            and not compile_state.from_statement_ctx.compile_options._is_star
+        ):
             load_options = execution_options.get(
                 "_sa_orm_load_options", QueryContext.default_load_options
             )
@@ -1374,6 +1377,16 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
             use_supplemental_cols=True,
         )
 
+        if (
+            self.from_statement_ctx is not None
+            and self.from_statement_ctx.compile_options._is_star
+        ):
+            raise sa_exc.CompileError(
+                "Can't use RETURNING * with bulk ORM INSERT.  "
+                "Please use a different INSERT form, such as INSERT..VALUES "
+                "or INSERT with a Core Connection"
+            )
+
         self.statement = statement
 
 
index e961e59347eb7a16e05073da7d8915e3f984ae2e..63c4e86c63845832420852884e9ca3edcc5d29ed 100644 (file)
@@ -838,6 +838,13 @@ class ORMFromStatementCompileState(ORMCompileState):
             "plugin_subject", None
         )
         adapter = DMLReturningColFilter(target_mapper, dml_mapper)
+
+        if self.compile_options._is_star and (len(self._entities) != 1):
+            raise sa_exc.CompileError(
+                "Can't generate ORM query that includes multiple expressions "
+                "at the same time as '*'; query for '*' alone if present"
+            )
+
         for entity in self._entities:
             entity.setup_dml_returning_compile_state(self, adapter)
 
@@ -2969,7 +2976,6 @@ class _ColumnEntity(_QueryEntity):
             column = compile_state.compound_eager_adapter.columns[column]
 
         getter = result._getter(column)
-
         ret = getter, self._label_name, self._extra_entities
         self._row_processor = ret
 
@@ -3031,6 +3037,13 @@ class _RawColumnEntity(_ColumnEntity):
     def corresponds_to(self, entity):
         return False
 
+    def setup_dml_returning_compile_state(
+        self,
+        compile_state: ORMCompileState,
+        adapter: DMLReturningColFilter,
+    ) -> None:
+        return self.setup_compile_state(compile_state)
+
     def setup_compile_state(self, compile_state):
         current_adapter = compile_state._get_current_adapter()
         if current_adapter:
index 2888aeaf9e1a69ee12ae9a477e085d455891df8f..d9c91a57073f1e63c8dc9118690732012c57c14c 100644 (file)
@@ -192,6 +192,85 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
             ),
         )
 
+    @testing.requires.insert_returning
+    @testing.requires.returning_star
+    @testing.variation(
+        "insert_type",
+        ["bulk", ("values", testing.requires.multivalues_inserts), "single"],
+    )
+    def test_insert_returning_star(self, decl_base, insert_type):
+        """test #10192"""
+
+        class User(decl_base):
+            __tablename__ = "users"
+
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+
+            name: Mapped[str] = mapped_column()
+            other_thing: Mapped[Optional[str]]
+            server_thing: Mapped[str] = mapped_column(server_default="thing")
+
+        decl_base.metadata.create_all(testing.db)
+        insert_stmt = insert(User).returning(literal_column("*"))
+
+        s = fixture_session()
+
+        if insert_type.bulk or insert_type.single:
+            with expect_raises_message(
+                exc.CompileError,
+                r"Can't use RETURNING \* with bulk ORM INSERT.",
+            ):
+                if insert_type.bulk:
+                    s.execute(
+                        insert_stmt,
+                        [
+                            {"name": "some name 1"},
+                            {"name": "some name 2"},
+                            {"name": "some name 3"},
+                        ],
+                    )
+                else:
+                    s.execute(
+                        insert_stmt,
+                        {"name": "some name 1"},
+                    )
+            return
+        elif insert_type.values:
+            with self.sql_execution_asserter() as asserter:
+                result = s.execute(
+                    insert_stmt.values(
+                        [
+                            {"name": "some name 1"},
+                            {"name": "some name 2"},
+                            {"name": "some name 3"},
+                        ],
+                    )
+                )
+
+            eq_(
+                result.all(),
+                [
+                    (1, "some name 1", None, "thing"),
+                    (2, "some name 2", None, "thing"),
+                    (3, "some name 3", None, "thing"),
+                ],
+            )
+            asserter.assert_(
+                CompiledSQL(
+                    "INSERT INTO users (name) VALUES (:name_m0), "
+                    "(:name_m1), (:name_m2) RETURNING *",
+                    [
+                        {
+                            "name_m0": "some name 1",
+                            "name_m1": "some name 2",
+                            "name_m2": "some name 3",
+                        }
+                    ],
+                ),
+            )
+        else:
+            insert_type.fail()
+
     @testing.requires.insert_returning
     @testing.skip_if(
         "oracle", "oracle doesn't like the no-FROM SELECT inside of an INSERT"
@@ -587,6 +666,60 @@ class UpdateStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
                     CompiledSQL("UPDATE a SET x=:x, y=:y", [{"x": 5, "y": 9}]),
                 )
 
+    @testing.variation("multi_row", ["multirow", "singlerow", "listwsingle"])
+    @testing.requires.update_returning
+    @testing.requires.returning_star
+    def test_bulk_update_returning_star(self, decl_base, multi_row):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+
+            x: Mapped[int]
+            y: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        s = fixture_session()
+
+        s.add_all(
+            [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+        )
+        s.commit()
+
+        stmt = update(A).returning(literal_column("*"))
+
+        if multi_row.multirow:
+            data = [
+                {"x": 3, "y": 8},
+                {"x": 5, "y": 9},
+                {"x": 12, "y": 15},
+            ]
+
+            stmt = stmt.execution_options(synchronize_session=None)
+        elif multi_row.listwsingle:
+            data = [
+                {"x": 5, "y": 9},
+            ]
+
+            stmt = stmt.execution_options(synchronize_session=None)
+        elif multi_row.singlerow:
+            data = {"x": 5, "y": 9}
+        else:
+            multi_row.fail()
+
+        if multi_row.multirow or multi_row.listwsingle:
+            with expect_raises_message(
+                exc.InvalidRequestError, "No primary key value supplied"
+            ):
+                s.execute(stmt, data)
+                return
+        else:
+            result = s.execute(stmt, data)
+            eq_(result.all(), [(1, 5, 9), (2, 5, 9), (3, 5, 9)])
+
     def test_bulk_update_w_where_one(self, decl_base):
         """test use case in #9595"""
 
index a524ddd14ac53c4ba873c84d2d716efc68b9d1e4..7f76d735d35b594aedd5341dc76611a9d6f18d4f 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import lambda_stmt
+from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import select
@@ -1094,6 +1095,31 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([25, 37, 29, 27])),
         )
 
+    @testing.requires.update_returning
+    @testing.requires.returning_star
+    def test_update_returning_star(self):
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+        stmt = (
+            update(User)
+            .where(User.age > 29)
+            .values({"age": User.age - 10})
+            .returning(literal_column("*"))
+        )
+
+        result = sess.execute(stmt)
+        eq_(result.all(), [(2, "jack", 37), (4, "jane", 27)])
+
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([25, 37, 29, 27])),
+        )
+
     @testing.combinations(True, False, argnames="implicit_returning")
     def test_update_fetch_returning(self, implicit_returning):
         if implicit_returning:
@@ -1255,6 +1281,28 @@ class UpdateDeleteTest(fixtures.MappedTest):
         # to point to the class, so you can test eq with sets
         eq_(set(result.all()), expected)
 
+    @testing.requires.delete_returning
+    @testing.requires.returning_star
+    def test_delete_returning_star(self):
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+        in_(john, sess)
+        in_(jack, sess)
+
+        stmt = delete(User).where(User.age > 29).returning(literal_column("*"))
+
+        result = sess.execute(stmt)
+        eq_(result.all(), [(2, "jack", 47), (4, "jane", 37)])
+
+        in_(john, sess)
+        not_in(jack, sess)
+        in_(jill, sess)
+        not_in(jane, sess)
+
     @testing.combinations(True, False, argnames="implicit_returning")
     def test_delete_fetch_returning(self, implicit_returning):
         if implicit_returning:
index a8d38de3e3c0cb03bdf8a48846c4c7487ca5a794..06482562b97b9e455a72bca98dc8a8ac85209303 100644 (file)
@@ -2834,8 +2834,6 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         )
 
     def test_update_from_entity(self):
-        from sqlalchemy.sql import update
-
         User = self.classes.User
         self.assert_compile(
             update(User), "UPDATE users SET id=:id, name=:name"
@@ -2854,8 +2852,6 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         )
 
     def test_delete_from_entity(self):
-        from sqlalchemy.sql import delete
-
         User = self.classes.User
         self.assert_compile(delete(User), "DELETE FROM users")
 
@@ -2866,8 +2862,6 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         )
 
     def test_insert_from_entity(self):
-        from sqlalchemy.sql import insert
-
         User = self.classes.User
         self.assert_compile(
             insert(User), "INSERT INTO users (id, name) VALUES (:id, :name)"
@@ -2879,6 +2873,27 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
             checkparams={"name": "ed"},
         )
 
+    def test_update_returning_star(self):
+        User = self.classes.User
+        self.assert_compile(
+            update(User).returning(literal_column("*")),
+            "UPDATE users SET id=:id, name=:name RETURNING *",
+        )
+
+    def test_delete_returning_star(self):
+        User = self.classes.User
+        self.assert_compile(
+            delete(User).returning(literal_column("*")),
+            "DELETE FROM users RETURNING *",
+        )
+
+    def test_insert_returning_star(self):
+        User = self.classes.User
+        self.assert_compile(
+            insert(User).returning(literal_column("*")),
+            "INSERT INTO users (id, name) VALUES (:id, :name) RETURNING *",
+        )
+
     def test_col_prop_builtin_function(self):
         class Foo:
             pass
index d0b5c9d8f9c40c026063feb33f829eb395f08f6c..8be0f3168d8207980b894f979fce6c4e0919ee01 100644 (file)
@@ -1,9 +1,12 @@
+from sqlalchemy import delete
 from sqlalchemy import exc
+from sqlalchemy import insert
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import update
 from sqlalchemy.orm import loading
 from sqlalchemy.orm import relationship
 from sqlalchemy.testing import is_true
@@ -77,7 +80,8 @@ class SelectStarTest(_fixtures.FixtureTest):
         lambda User, star: (star, text("some text")),
         argnames="testcase",
     )
-    def test_no_star_orm_combinations(self, exprtype, testcase):
+    @testing.variation("stmt_type", ["select", "update", "insert", "delete"])
+    def test_no_star_orm_combinations(self, exprtype, testcase, stmt_type):
         """test for #8235"""
         User = self.classes.User
 
@@ -91,7 +95,17 @@ class SelectStarTest(_fixtures.FixtureTest):
             assert False
 
         args = testing.resolve_lambda(testcase, User=User, star=star)
-        stmt = select(*args).select_from(User)
+
+        if stmt_type.select:
+            stmt = select(*args).select_from(User)
+        elif stmt_type.insert:
+            stmt = insert(User).returning(*args)
+        elif stmt_type.update:
+            stmt = update(User).values({"data": "foo"}).returning(*args)
+        elif stmt_type.delete:
+            stmt = delete(User).returning(*args)
+        else:
+            stmt_type.fail()
 
         s = fixture_session()
 
index 61cb139338c22cf12fccb2fd38466f7d6cd92a5e..e0941da1b925c4fd91b6e225d22dfc17c5be1097 100644 (file)
@@ -461,6 +461,12 @@ class DefaultRequirements(SuiteRequirements):
     def computed_columns_on_update_returning(self):
         return self.computed_columns + skip_if("oracle")
 
+    @property
+    def returning_star(self):
+        """backend supports RETURNING *"""
+
+        return skip_if(["oracle", "mssql"])
+
     @property
     def correlated_outer_joins(self):
         """Target must support an outer join to a subquery which