From ade4bdfb0406fadff566aa8d39abe6aa29af521f Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sun, 28 Apr 2024 12:01:05 +0200 Subject: [PATCH] Fix issue in bulk_save_objects Fixes issue in :meth:`_orm.Session.bulk_save_objects()` where it would write a wrong identity key when using ``return_defaults=True``. The wrong identity key could lead to an index error when entities are then pickled. Fixes: #11332 Change-Id: I8d095392ad03e8d3408e477476cd5de8a5bca2c0 --- doc/build/changelog/unreleased_20/11332.rst | 7 ++++ lib/sqlalchemy/orm/bulk_persistence.py | 7 ++++ lib/sqlalchemy/orm/session.py | 39 ++++++++++++--------- test/orm/dml/test_bulk.py | 12 +++++++ test/orm/test_pickled.py | 11 ++++++ 5 files changed, 60 insertions(+), 16 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11332.rst diff --git a/doc/build/changelog/unreleased_20/11332.rst b/doc/build/changelog/unreleased_20/11332.rst new file mode 100644 index 0000000000..c8f748654c --- /dev/null +++ b/doc/build/changelog/unreleased_20/11332.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 11332 + + Fixes issue in :meth:`_orm.Session.bulk_save_objects` where it would write a + wrong identity key when using ``return_defaults=True``. + The wrong identity key could lead to an index error when entities are then pickled. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index d59570bc20..37beb0f2bb 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -78,6 +78,7 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -91,6 +92,7 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -103,6 +105,7 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -220,6 +223,7 @@ def _bulk_insert( state.key = ( identity_cls, tuple([dict_[key] for key in identity_props]), + None, ) if use_orm_insert_stmt is not None: @@ -232,6 +236,7 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Literal[None] = ..., @@ -244,6 +249,7 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., @@ -255,6 +261,7 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, + *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 13b906fe24..3963bf1b17 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -4591,11 +4591,11 @@ class Session(_SessionClassMethods, EventTarget): self._bulk_save_mappings( mapper, states, - isupdate, - True, - return_defaults, - update_changed_only, - False, + isupdate=isupdate, + isstates=True, + return_defaults=return_defaults, + update_changed_only=update_changed_only, + render_nulls=False, ) def bulk_insert_mappings( @@ -4674,11 +4674,11 @@ class Session(_SessionClassMethods, EventTarget): self._bulk_save_mappings( mapper, mappings, - False, - False, - return_defaults, - False, - render_nulls, + isupdate=False, + isstates=False, + return_defaults=return_defaults, + update_changed_only=False, + render_nulls=render_nulls, ) def bulk_update_mappings( @@ -4720,13 +4720,20 @@ class Session(_SessionClassMethods, EventTarget): """ self._bulk_save_mappings( - mapper, mappings, True, False, False, False, False + mapper, + mappings, + isupdate=True, + isstates=False, + return_defaults=False, + update_changed_only=False, + render_nulls=False, ) def _bulk_save_mappings( self, mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + *, isupdate: bool, isstates: bool, return_defaults: bool, @@ -4743,17 +4750,17 @@ class Session(_SessionClassMethods, EventTarget): mapper, mappings, transaction, - isstates, - update_changed_only, + isstates=isstates, + update_changed_only=update_changed_only, ) else: bulk_persistence._bulk_insert( mapper, mappings, transaction, - isstates, - return_defaults, - render_nulls, + isstates=isstates, + return_defaults=return_defaults, + render_nulls=render_nulls, ) transaction.commit() diff --git a/test/orm/dml/test_bulk.py b/test/orm/dml/test_bulk.py index 62b435e9cb..4d24a52ece 100644 --- a/test/orm/dml/test_bulk.py +++ b/test/orm/dml/test_bulk.py @@ -2,6 +2,7 @@ from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import Identity from sqlalchemy import insert +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing @@ -147,6 +148,17 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest): if statement_type == "save_objects": eq_(objects[0].__dict__["id"], 1) + def test_bulk_save_objects_defaults_key(self): + User = self.classes.User + + pes = [User(name=f"foo{i}") for i in range(3)] + s = fixture_session() + s.bulk_save_objects(pes, return_defaults=True) + key = inspect(pes[0]).key + + s.commit() + eq_(inspect(s.get(User, 1)).key, key) + def test_bulk_save_mappings_preserve_order(self): (User,) = self.classes("User") diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 96dec4a60b..18904cc386 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -654,6 +654,17 @@ class PickleTest(fixtures.MappedTest): ) is_not_none(collections.collection_adapter(repickled.addresses)) + def test_bulk_save_objects_defaults_pickle(self): + "Test for #11332" + users = self.tables.users + + self.mapper_registry.map_imperatively(User, users) + pes = [User(name=f"foo{i}") for i in range(3)] + s = fixture_session() + s.bulk_save_objects(pes, return_defaults=True) + state = pickle.dumps(pes) + pickle.loads(state) + class OptionsTest(_Polymorphic): def test_options_of_type(self): -- 2.47.2