From: Federico Caselli Date: Sun, 28 Apr 2024 10:01:05 +0000 (+0200) Subject: Fix issue in bulk_save_objects X-Git-Tag: rel_2_0_30~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=368d88d1ac29db7b5d3933e37d43aebc06ad633b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix issue in bulk_save_objects Fixes issue in :meth:`_orm.Session.bulk_save_objects()` where it would write a wrong identity key when using ``return_defaults=True``. The wrong identity key could lead to an index error when entities are then pickled. Fixes: #11332 Change-Id: I8d095392ad03e8d3408e477476cd5de8a5bca2c0 (cherry picked from commit ade4bdfb0406fadff566aa8d39abe6aa29af521f) --- diff --git a/doc/build/changelog/unreleased_20/11332.rst b/doc/build/changelog/unreleased_20/11332.rst new file mode 100644 index 0000000000..c8f748654c --- /dev/null +++ b/doc/build/changelog/unreleased_20/11332.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 11332 + + Fixes issue in :meth:`_orm.Session.bulk_save_objects` where it would write a + wrong identity key when using ``return_defaults=True``. + The wrong identity key could lead to an index error when entities are then pickled. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 5d2558d953..2ed6a4beaa 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -76,6 +76,7 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -89,6 +90,7 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -101,6 +103,7 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -218,6 +221,7 @@ def _bulk_insert( state.key = ( identity_cls, tuple([dict_[key] for key in identity_props]), + None, ) if use_orm_insert_stmt is not None: @@ -230,6 +234,7 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Literal[None] = ..., @@ -242,6 +247,7 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., @@ -253,6 +259,7 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index acc6895e86..a4bf7c1cec 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -4574,11 +4574,11 @@ class Session(_SessionClassMethods, EventTarget): self._bulk_save_mappings( mapper, states, - isupdate, - True, - return_defaults, - update_changed_only, - False, + isupdate=isupdate, + isstates=True, + return_defaults=return_defaults, + update_changed_only=update_changed_only, + render_nulls=False, ) def bulk_insert_mappings( @@ -4657,11 +4657,11 @@ class Session(_SessionClassMethods, EventTarget): self._bulk_save_mappings( mapper, mappings, - False, - False, - return_defaults, - False, - render_nulls, + isupdate=False, + isstates=False, + return_defaults=return_defaults, + update_changed_only=False, + render_nulls=render_nulls, ) def bulk_update_mappings( @@ -4703,13 +4703,20 @@ class Session(_SessionClassMethods, EventTarget): """ self._bulk_save_mappings( - mapper, mappings, True, False, False, False, False + mapper, + mappings, + isupdate=True, + isstates=False, + return_defaults=False, + update_changed_only=False, + render_nulls=False, ) def _bulk_save_mappings( self, mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + *, isupdate: bool, isstates: bool, return_defaults: bool, @@ -4726,17 +4733,17 @@ class Session(_SessionClassMethods, EventTarget): mapper, mappings, transaction, - isstates, - update_changed_only, + isstates=isstates, + update_changed_only=update_changed_only, ) else: bulk_persistence._bulk_insert( mapper, mappings, transaction, - isstates, - return_defaults, - render_nulls, + isstates=isstates, + return_defaults=return_defaults, + render_nulls=render_nulls, ) transaction.commit() diff --git a/test/orm/dml/test_bulk.py b/test/orm/dml/test_bulk.py index 62b435e9cb..4d24a52ece 100644 --- a/test/orm/dml/test_bulk.py +++ b/test/orm/dml/test_bulk.py @@ -2,6 +2,7 @@ from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import Identity from sqlalchemy import insert +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing @@ -147,6 +148,17 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): if statement_type == "save_objects": eq_(objects[0].__dict__["id"], 1) + def test_bulk_save_objects_defaults_key(self): + User = self.classes.User + + pes = [User(name=f"foo{i}") for i in range(3)] + s = fixture_session() + s.bulk_save_objects(pes, return_defaults=True) + key = inspect(pes[0]).key + + s.commit() + eq_(inspect(s.get(User, 1)).key, key) + def test_bulk_save_mappings_preserve_order(self): (User,) = self.classes("User") diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 96dec4a60b..18904cc386 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -654,6 +654,17 @@ class PickleTest(fixtures.MappedTest): ) is_not_none(collections.collection_adapter(repickled.addresses)) + def test_bulk_save_objects_defaults_pickle(self): + "Test for #11332" + users = self.tables.users + + self.mapper_registry.map_imperatively(User, users) + pes = [User(name=f"foo{i}") for i in range(3)] + s = fixture_session() + s.bulk_save_objects(pes, return_defaults=True) + state = pickle.dumps(pes) + pickle.loads(state) + class OptionsTest(_Polymorphic): def test_options_of_type(self):