From: Mike Bayer Date: Thu, 1 Aug 2024 13:49:55 +0000 (-0400) Subject: mutate lists in place for return_defaults=True X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7001429a7561b3c55dd52b96dfa419004e535743;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git mutate lists in place for return_defaults=True Fixed regression from version 1.4 in :meth:`_orm.Session.bulk_insert_mappings` where using the :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` parameter would not populate the passed in dictionaries with newly generated primary key values. Fixes: #11661 Change-Id: I331d81a5b04456f107eb868f882d67773b3eec38 --- diff --git a/doc/build/changelog/unreleased_20/11661.rst b/doc/build/changelog/unreleased_20/11661.rst new file mode 100644 index 0000000000..35985d8bba --- /dev/null +++ b/doc/build/changelog/unreleased_20/11661.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 11661 + + Fixed regression from version 1.4 in + :meth:`_orm.Session.bulk_insert_mappings` where using the + :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` parameter + would not populate the passed in dictionaries with newly generated primary + key values. + diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index b53a8302ea..b5134034d6 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -121,13 +121,35 @@ def _bulk_insert( ) if isstates: + if TYPE_CHECKING: + mappings = cast(Iterable[InstanceState[_O]], mappings) + if return_defaults: + # list of states allows us to attach .key for return_defaults case states = [(state, state.dict) for state in mappings] mappings = [dict_ for (state, dict_) in states] else: mappings = [state.dict for state in mappings] else: - mappings = [dict(m) for m in mappings] + if TYPE_CHECKING: + mappings = cast(Iterable[Dict[str, Any]], mappings) + + if return_defaults: + # use dictionaries given, so that newly populated defaults + # can be delivered back to the caller (see #11661). This is **not** + # compatible with other use cases such as a session-executed + # insert() construct, as this will confuse the case of + # insert-per-subclass for joined inheritance cases (see + # test_bulk_statements.py::BulkDMLReturningJoinedInhTest). + # + # So in this conditional, we have **only** called + # session.bulk_insert_mappings() which does not have this + # requirement + mappings = list(mappings) + else: + # for all other cases we need to establish a local dictionary + # so that the incoming dictionaries aren't mutated + mappings = [dict(m) for m in mappings] _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) diff --git a/test/orm/dml/test_bulk.py b/test/orm/dml/test_bulk.py index 4d24a52ece..3159c139da 100644 --- a/test/orm/dml/test_bulk.py +++ b/test/orm/dml/test_bulk.py @@ -90,8 +90,14 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): cls.mapper_registry.map_imperatively(Address, a) cls.mapper_registry.map_imperatively(Order, o) - @testing.combinations("save_objects", "insert_mappings", "insert_stmt") - def test_bulk_save_return_defaults(self, statement_type): + @testing.combinations( + "save_objects", + "insert_mappings", + "insert_stmt", + argnames="statement_type", + ) + @testing.variation("return_defaults", [True, False]) + def test_bulk_save_return_defaults(self, statement_type, return_defaults): (User,) = self.classes("User") s = fixture_session() @@ -102,12 +108,14 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): returning_users_id = " RETURNING users.id" with self.sql_execution_asserter() as asserter: - s.bulk_save_objects(objects, return_defaults=True) + s.bulk_save_objects(objects, return_defaults=return_defaults) elif statement_type == "insert_mappings": data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] returning_users_id = " RETURNING users.id" with self.sql_execution_asserter() as asserter: - s.bulk_insert_mappings(User, data, return_defaults=True) + s.bulk_insert_mappings( + User, data, return_defaults=return_defaults + ) elif statement_type == "insert_stmt": data = [dict(name="u1"), dict(name="u2"), dict(name="u3")] @@ -120,7 +128,10 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): asserter.assert_( Conditional( - testing.db.dialect.insert_executemany_returning + ( + return_defaults + and testing.db.dialect.insert_executemany_returning + ) or statement_type == "insert_stmt", [ CompiledSQL( @@ -130,23 +141,50 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): ), ], [ - CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{"name": "u1"}], - ), - CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{"name": "u2"}], - ), - CompiledSQL( - "INSERT INTO users (name) VALUES (:name)", - [{"name": "u3"}], - ), + Conditional( + return_defaults, + [ + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "u1"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "u2"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "u3"}], + ), + ], + [ + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [ + {"name": "u1"}, + {"name": "u2"}, + {"name": "u3"}, + ], + ), + ], + ) ], ) ) + if statement_type == "save_objects": - eq_(objects[0].__dict__["id"], 1) + if return_defaults: + eq_(objects[0].__dict__["id"], 1) + eq_(inspect(objects[0]).key, (User, (1,), None)) + else: + assert "id" not in objects[0].__dict__ + eq_(inspect(objects[0]).key, None) + elif statement_type == "insert_mappings": + # test for #11661 + if return_defaults: + eq_(data[0]["id"], 1) + else: + assert "id" not in data[0] def test_bulk_save_objects_defaults_key(self): User = self.classes.User