]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
try to omit unnecessary cols for ORM bulk insert + returning
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Apr 2023 13:48:04 +0000 (09:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Apr 2023 14:24:41 +0000 (10:24 -0400)
Fixed bug in ORM bulk insert feature where additional unnecessary columns
would be rendered in the INSERT statement if RETURNING of individual
columns were requested.

Fixes: #9685
Change-Id: Ibf5f06ab017215c7c9bd8850c3a006f73fe78c68

doc/build/changelog/unreleased_20/9685.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/persistence.py
test/orm/dml/test_bulk_statements.py

diff --git a/doc/build/changelog/unreleased_20/9685.rst b/doc/build/changelog/unreleased_20/9685.rst
new file mode 100644 (file)
index 0000000..7100a96
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9685
+
+    Fixed bug in ORM bulk insert feature where additional unnecessary columns
+    would be rendered in the INSERT statement if RETURNING of individual columns
+    were requested.
index f9d9d6a433bb5384247b66414cc6f7251d40095a..c096dc3e5d39353da8e0b3cc2c332f9c9f9823e4 100644 (file)
@@ -464,8 +464,9 @@ class ORMDMLState(AbstractORMCompileState):
         compiler,
         orm_level_statement,
         dml_level_statement,
+        dml_mapper,
+        *,
         use_supplemental_cols=True,
-        dml_mapper=None,
     ):
         """establish ORM column handlers for an INSERT, UPDATE, or DELETE
         which uses explicit returning().
@@ -504,7 +505,17 @@ class ORMDMLState(AbstractORMCompileState):
 
             if use_supplemental_cols:
                 dml_level_statement = dml_level_statement.return_defaults(
-                    supplemental_cols=cols_to_return
+                    # this is a little weird looking, but by passing
+                    # primary key as the main list of cols, this tells
+                    # return_defaults to omit server-default cols.  Since
+                    # we have cols_to_return, just return what we asked for
+                    # (plus primary key, which ORM persistence needs since
+                    # we likely set bookkeeping=True here, which is another
+                    # whole thing...).   We dont want to clutter the
+                    # statement up with lots of other cols the user didn't
+                    # ask for.  see #9685
+                    *dml_mapper.primary_key,
+                    supplemental_cols=cols_to_return,
                 )
             else:
                 dml_level_statement = dml_level_statement.returning(
@@ -1280,6 +1291,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
             compiler,
             orm_level_statement,
             statement,
+            dml_mapper=mapper,
             use_supplemental_cols=False,
         )
         self.statement = statement
@@ -1314,8 +1326,8 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
             compiler,
             orm_level_statement,
             statement,
-            use_supplemental_cols=True,
             dml_mapper=emit_insert_mapper,
+            use_supplemental_cols=True,
         )
 
         self.statement = statement
@@ -1425,6 +1437,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             compiler,
             orm_level_statement,
             new_stmt,
+            dml_mapper=mapper,
             use_supplemental_cols=use_supplemental_cols,
         )
 
@@ -1795,6 +1808,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
             compiler,
             orm_level_statement,
             new_stmt,
+            dml_mapper=mapper,
             use_supplemental_cols=use_supplemental_cols,
         )
 
index a331d4ed85cdcf983061b96f35a5411f81877498..a12156eb528a517d64c560ebff8131e6279aac63 100644 (file)
@@ -1076,12 +1076,16 @@ def _emit_insert_statements(
             else:
                 do_executemany = False
 
-            if not has_all_defaults and base_mapper._prefer_eager_defaults(
-                connection.dialect, table
-            ):
-                statement = statement.return_defaults(
-                    *mapper._server_default_cols[table]
-                )
+            if use_orm_insert_stmt is None:
+                if (
+                    not has_all_defaults
+                    and base_mapper._prefer_eager_defaults(
+                        connection.dialect, table
+                    )
+                ):
+                    statement = statement.return_defaults(
+                        *mapper._server_default_cols[table]
+                    )
 
             if mapper.version_id_col is not None:
                 statement = statement.return_defaults(mapper.version_id_col)
index b4628af670b974845bf3c785426c164a3c145e65..7a9f3324f8856e90fcfc5359144eb49820467a28 100644 (file)
@@ -33,7 +33,7 @@ from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 
 
-class InsertStmtTest(fixtures.TestBase):
+class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
     def test_no_returning_error(self, decl_base):
         class A(fixtures.ComparableEntity, decl_base):
             __tablename__ = "a"
@@ -89,6 +89,48 @@ class InsertStmtTest(fixtures.TestBase):
             [("d3", 5), ("d4", 6)],
         )
 
+    @testing.requires.insert_returning
+    def test_insert_returning_cols_dont_give_me_defaults(self, decl_base):
+        """test #9685"""
+
+        class User(decl_base):
+            __tablename__ = "users"
+
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+
+            name: Mapped[str] = mapped_column()
+            other_thing: Mapped[Optional[str]]
+            server_thing: Mapped[str] = mapped_column(server_default="thing")
+
+        decl_base.metadata.create_all(testing.db)
+        insert_stmt = insert(User).returning(User.id)
+
+        s = fixture_session()
+
+        with self.sql_execution_asserter() as asserter:
+            result = s.execute(
+                insert_stmt,
+                [
+                    {"name": "some name 1"},
+                    {"name": "some name 2"},
+                    {"name": "some name 3"},
+                ],
+            )
+
+        eq_(result.all(), [(1,), (2,), (3,)])
+
+        asserter.assert_(
+            CompiledSQL(
+                "INSERT INTO users (name) VALUES (:name) "
+                "RETURNING users.id",
+                [
+                    {"name": "some name 1"},
+                    {"name": "some name 2"},
+                    {"name": "some name 3"},
+                ],
+            ),
+        )
+
     @testing.requires.insert_returning
     def test_insert_from_select_col_property(self, decl_base):
         """test #9273"""
@@ -191,17 +233,12 @@ class BulkDMLReturningInhTest:
         with self.sql_execution_asserter() as asserter:
             result = s.execute(stmt, values)
 
-        if inspect(B).single:
-            single_inh = ", a.bd, a.zcol, a.q"
-        else:
-            single_inh = ""
-
         if use_returning:
             asserter.assert_(
                 CompiledSQL(
                     "INSERT INTO a (type, data, xcol) VALUES "
                     "(:type, :data, :xcol) "
-                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    "RETURNING a.id, a.type, a.data, a.xcol, a.y",
                     [
                         {"type": "a", "data": "d3", "xcol": 5},
                         {"type": "a", "data": "d4", "xcol": 6},
@@ -209,13 +246,13 @@ class BulkDMLReturningInhTest:
                 ),
                 CompiledSQL(
                     "INSERT INTO a (type, data) VALUES (:type, :data) "
-                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    "RETURNING a.id, a.type, a.data, a.xcol, a.y",
                     [{"type": "a", "data": "d5"}],
                 ),
                 CompiledSQL(
                     "INSERT INTO a (type, data, xcol, y) "
                     "VALUES (:type, :data, :xcol, :y) "
-                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    "RETURNING a.id, a.type, a.data, a.xcol, a.y",
                     [
                         {"type": "a", "data": "d6", "xcol": 8, "y": 9},
                         {"type": "a", "data": "d7", "xcol": 12, "y": 12},
@@ -224,7 +261,7 @@ class BulkDMLReturningInhTest:
                 CompiledSQL(
                     "INSERT INTO a (type, data, xcol) "
                     "VALUES (:type, :data, :xcol) "
-                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    "RETURNING a.id, a.type, a.data, a.xcol, a.y",
                     [{"type": "a", "data": "d8", "xcol": 7}],
                 ),
             )
@@ -258,17 +295,18 @@ class BulkDMLReturningInhTest:
             )
 
         if use_returning:
-            eq_(
-                result.scalars().all(),
-                [
-                    A(data="d3", id=mock.ANY, type="a", x=5, y=None),
-                    A(data="d4", id=mock.ANY, type="a", x=6, y=None),
-                    A(data="d5", id=mock.ANY, type="a", x=None, y=None),
-                    A(data="d6", id=mock.ANY, type="a", x=8, y=9),
-                    A(data="d7", id=mock.ANY, type="a", x=12, y=12),
-                    A(data="d8", id=mock.ANY, type="a", x=7, y=None),
-                ],
-            )
+            with self.assert_statement_count(testing.db, 0):
+                eq_(
+                    result.scalars().all(),
+                    [
+                        A(data="d3", id=mock.ANY, type="a", x=5, y=None),
+                        A(data="d4", id=mock.ANY, type="a", x=6, y=None),
+                        A(data="d5", id=mock.ANY, type="a", x=None, y=None),
+                        A(data="d6", id=mock.ANY, type="a", x=8, y=9),
+                        A(data="d7", id=mock.ANY, type="a", x=12, y=12),
+                        A(data="d8", id=mock.ANY, type="a", x=7, y=None),
+                    ],
+                )
 
     @testing.combinations(
         "strings",