]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix issue in bulk_save_objects
authorFederico Caselli <cfederico87@gmail.com>
Sun, 28 Apr 2024 10:01:05 +0000 (12:01 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 29 Apr 2024 19:56:58 +0000 (21:56 +0200)
Fixes issue in :meth:`_orm.Session.bulk_save_objects()` where it would write a
wrong identity key when using ``return_defaults=True``.
The wrong identity key could lead to an index error when entities are then pickled.

Fixes: #11332
Change-Id: I8d095392ad03e8d3408e477476cd5de8a5bca2c0

doc/build/changelog/unreleased_20/11332.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/session.py
test/orm/dml/test_bulk.py
test/orm/test_pickled.py

diff --git a/doc/build/changelog/unreleased_20/11332.rst b/doc/build/changelog/unreleased_20/11332.rst
new file mode 100644 (file)
index 0000000..c8f7486
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11332
+
+    Fixes issue in :meth:`_orm.Session.bulk_save_objects` where it would write a
+    wrong identity key when using ``return_defaults=True``.
+    The wrong identity key could lead to an index error when entities are then pickled.
index d59570bc2023fa64e2e06785f83fe277ac4bcdff..37beb0f2bb40c9f9e25a95225c6486f87693492a 100644 (file)
@@ -78,6 +78,7 @@ def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
@@ -91,6 +92,7 @@ def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
@@ -103,6 +105,7 @@ def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
@@ -220,6 +223,7 @@ def _bulk_insert(
             state.key = (
                 identity_cls,
                 tuple([dict_[key] for key in identity_props]),
+                None,
             )
 
     if use_orm_insert_stmt is not None:
@@ -232,6 +236,7 @@ def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Literal[None] = ...,
@@ -244,6 +249,7 @@ def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Optional[dml.Update] = ...,
@@ -255,6 +261,7 @@ def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Optional[dml.Update] = None,
index 13b906fe2475fd77d15d989a6fef1e189b4b8eda..3963bf1b176048cfd644c91902b7102ece3b0332 100644 (file)
@@ -4591,11 +4591,11 @@ class Session(_SessionClassMethods, EventTarget):
             self._bulk_save_mappings(
                 mapper,
                 states,
-                isupdate,
-                True,
-                return_defaults,
-                update_changed_only,
-                False,
+                isupdate=isupdate,
+                isstates=True,
+                return_defaults=return_defaults,
+                update_changed_only=update_changed_only,
+                render_nulls=False,
             )
 
     def bulk_insert_mappings(
@@ -4674,11 +4674,11 @@ class Session(_SessionClassMethods, EventTarget):
         self._bulk_save_mappings(
             mapper,
             mappings,
-            False,
-            False,
-            return_defaults,
-            False,
-            render_nulls,
+            isupdate=False,
+            isstates=False,
+            return_defaults=return_defaults,
+            update_changed_only=False,
+            render_nulls=render_nulls,
         )
 
     def bulk_update_mappings(
@@ -4720,13 +4720,20 @@ class Session(_SessionClassMethods, EventTarget):
 
         """
         self._bulk_save_mappings(
-            mapper, mappings, True, False, False, False, False
+            mapper,
+            mappings,
+            isupdate=True,
+            isstates=False,
+            return_defaults=False,
+            update_changed_only=False,
+            render_nulls=False,
         )
 
     def _bulk_save_mappings(
         self,
         mapper: Mapper[_O],
         mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+        *,
         isupdate: bool,
         isstates: bool,
         return_defaults: bool,
@@ -4743,17 +4750,17 @@ class Session(_SessionClassMethods, EventTarget):
                     mapper,
                     mappings,
                     transaction,
-                    isstates,
-                    update_changed_only,
+                    isstates=isstates,
+                    update_changed_only=update_changed_only,
                 )
             else:
                 bulk_persistence._bulk_insert(
                     mapper,
                     mappings,
                     transaction,
-                    isstates,
-                    return_defaults,
-                    render_nulls,
+                    isstates=isstates,
+                    return_defaults=return_defaults,
+                    render_nulls=render_nulls,
                 )
             transaction.commit()
 
index 62b435e9cbfe8e0d24adcd508c30b10180f98848..4d24a52ecebb6290026fba65b9d41dba8d2deedb 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import FetchedValue
 from sqlalchemy import ForeignKey
 from sqlalchemy import Identity
 from sqlalchemy import insert
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -147,6 +148,17 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
         if statement_type == "save_objects":
             eq_(objects[0].__dict__["id"], 1)
 
+    def test_bulk_save_objects_defaults_key(self):
+        User = self.classes.User
+
+        pes = [User(name=f"foo{i}") for i in range(3)]
+        s = fixture_session()
+        s.bulk_save_objects(pes, return_defaults=True)
+        key = inspect(pes[0]).key
+
+        s.commit()
+        eq_(inspect(s.get(User, 1)).key, key)
+
     def test_bulk_save_mappings_preserve_order(self):
         (User,) = self.classes("User")
 
index 96dec4a60b70d501cafcf681405061d86f6aec07..18904cc38611f7b65d6fcda27f633fee374fa4af 100644 (file)
@@ -654,6 +654,17 @@ class PickleTest(fixtures.MappedTest):
             )
             is_not_none(collections.collection_adapter(repickled.addresses))
 
+    def test_bulk_save_objects_defaults_pickle(self):
+        "Test for #11332"
+        users = self.tables.users
+
+        self.mapper_registry.map_imperatively(User, users)
+        pes = [User(name=f"foo{i}") for i in range(3)]
+        s = fixture_session()
+        s.bulk_save_objects(pes, return_defaults=True)
+        state = pickle.dumps(pes)
+        pickle.loads(state)
+
 
 class OptionsTest(_Polymorphic):
     def test_options_of_type(self):