]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
mutate lists in place for return_defaults=True
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Aug 2024 13:49:55 +0000 (09:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Aug 2024 17:11:57 +0000 (13:11 -0400)
Fixed regression from version 1.4 in
:meth:`_orm.Session.bulk_insert_mappings` where using the
:paramref:`_orm.Session.bulk_insert_mappings.return_defaults` parameter
would not populate the passed in dictionaries with newly generated primary
key values.

Fixes: #11661
Change-Id: I331d81a5b04456f107eb868f882d67773b3eec38

doc/build/changelog/unreleased_20/11661.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
test/orm/dml/test_bulk.py

diff --git a/doc/build/changelog/unreleased_20/11661.rst b/doc/build/changelog/unreleased_20/11661.rst
new file mode 100644 (file)
index 0000000..35985d8
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 11661
+
+    Fixed regression from version 1.4 in
+    :meth:`_orm.Session.bulk_insert_mappings` where using the
+    :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` parameter
+    would not populate the passed in dictionaries with newly generated primary
+    key values.
+
index b53a8302eacb6e671d3290755e0625cb5ffb1e7e..b5134034d6c0623a4bd340791f0b4a78b3167664 100644 (file)
@@ -121,13 +121,35 @@ def _bulk_insert(
         )
 
     if isstates:
+        if TYPE_CHECKING:
+            mappings = cast(Iterable[InstanceState[_O]], mappings)
+
         if return_defaults:
+            # list of states allows us to attach .key for return_defaults case
             states = [(state, state.dict) for state in mappings]
             mappings = [dict_ for (state, dict_) in states]
         else:
             mappings = [state.dict for state in mappings]
     else:
-        mappings = [dict(m) for m in mappings]
+        if TYPE_CHECKING:
+            mappings = cast(Iterable[Dict[str, Any]], mappings)
+
+        if return_defaults:
+            # use dictionaries given, so that newly populated defaults
+            # can be delivered back to the caller (see #11661). This is **not**
+            # compatible with other use cases such as a session-executed
+            # insert() construct, as this will confuse the case of
+            # insert-per-subclass for joined inheritance cases (see
+            # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
+            #
+            # So in this conditional, we have **only** called
+            # session.bulk_insert_mappings() which does not have this
+            # requirement
+            mappings = list(mappings)
+        else:
+            # for all other cases we need to establish a local dictionary
+            # so that the incoming dictionaries aren't mutated
+            mappings = [dict(m) for m in mappings]
         _expand_composites(mapper, mappings)
 
     connection = session_transaction.connection(base_mapper)
index 4d24a52ecebb6290026fba65b9d41dba8d2deedb..3159c139da2bb8446a580c4a397a46090aff415f 100644 (file)
@@ -90,8 +90,14 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
         cls.mapper_registry.map_imperatively(Address, a)
         cls.mapper_registry.map_imperatively(Order, o)
 
-    @testing.combinations("save_objects", "insert_mappings", "insert_stmt")
-    def test_bulk_save_return_defaults(self, statement_type):
+    @testing.combinations(
+        "save_objects",
+        "insert_mappings",
+        "insert_stmt",
+        argnames="statement_type",
+    )
+    @testing.variation("return_defaults", [True, False])
+    def test_bulk_save_return_defaults(self, statement_type, return_defaults):
         (User,) = self.classes("User")
 
         s = fixture_session()
@@ -102,12 +108,14 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
 
             returning_users_id = " RETURNING users.id"
             with self.sql_execution_asserter() as asserter:
-                s.bulk_save_objects(objects, return_defaults=True)
+                s.bulk_save_objects(objects, return_defaults=return_defaults)
         elif statement_type == "insert_mappings":
             data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
             returning_users_id = " RETURNING users.id"
             with self.sql_execution_asserter() as asserter:
-                s.bulk_insert_mappings(User, data, return_defaults=True)
+                s.bulk_insert_mappings(
+                    User, data, return_defaults=return_defaults
+                )
         elif statement_type == "insert_stmt":
             data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
 
@@ -120,7 +128,10 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
 
         asserter.assert_(
             Conditional(
-                testing.db.dialect.insert_executemany_returning
+                (
+                    return_defaults
+                    and testing.db.dialect.insert_executemany_returning
+                )
                 or statement_type == "insert_stmt",
                 [
                     CompiledSQL(
@@ -130,23 +141,50 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
                     ),
                 ],
                 [
-                    CompiledSQL(
-                        "INSERT INTO users (name) VALUES (:name)",
-                        [{"name": "u1"}],
-                    ),
-                    CompiledSQL(
-                        "INSERT INTO users (name) VALUES (:name)",
-                        [{"name": "u2"}],
-                    ),
-                    CompiledSQL(
-                        "INSERT INTO users (name) VALUES (:name)",
-                        [{"name": "u3"}],
-                    ),
+                    Conditional(
+                        return_defaults,
+                        [
+                            CompiledSQL(
+                                "INSERT INTO users (name) VALUES (:name)",
+                                [{"name": "u1"}],
+                            ),
+                            CompiledSQL(
+                                "INSERT INTO users (name) VALUES (:name)",
+                                [{"name": "u2"}],
+                            ),
+                            CompiledSQL(
+                                "INSERT INTO users (name) VALUES (:name)",
+                                [{"name": "u3"}],
+                            ),
+                        ],
+                        [
+                            CompiledSQL(
+                                "INSERT INTO users (name) VALUES (:name)",
+                                [
+                                    {"name": "u1"},
+                                    {"name": "u2"},
+                                    {"name": "u3"},
+                                ],
+                            ),
+                        ],
+                    )
                 ],
             )
         )
+
         if statement_type == "save_objects":
-            eq_(objects[0].__dict__["id"], 1)
+            if return_defaults:
+                eq_(objects[0].__dict__["id"], 1)
+                eq_(inspect(objects[0]).key, (User, (1,), None))
+            else:
+                assert "id" not in objects[0].__dict__
+                eq_(inspect(objects[0]).key, None)
+        elif statement_type == "insert_mappings":
+            # test for #11661
+            if return_defaults:
+                eq_(data[0]["id"], 1)
+            else:
+                assert "id" not in data[0]
 
     def test_bulk_save_objects_defaults_key(self):
         User = self.classes.User