]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix issue in bulk_save_objects
authorFederico Caselli <cfederico87@gmail.com>
Sun, 28 Apr 2024 10:01:05 +0000 (12:01 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 29 Apr 2024 19:58:33 +0000 (21:58 +0200)
Fixes issue in :meth:`_orm.Session.bulk_save_objects()` where it would write a
wrong identity key when using ``return_defaults=True``.
The wrong identity key could lead to an index error when entities are then pickled.

Fixes: #11332
Change-Id: I8d095392ad03e8d3408e477476cd5de8a5bca2c0
(cherry picked from commit ade4bdfb0406fadff566aa8d39abe6aa29af521f)

doc/build/changelog/unreleased_20/11332.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/session.py
test/orm/dml/test_bulk.py
test/orm/test_pickled.py

diff --git a/doc/build/changelog/unreleased_20/11332.rst b/doc/build/changelog/unreleased_20/11332.rst
new file mode 100644 (file)
index 0000000..c8f7486
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11332
+
+    Fixes issue in :meth:`_orm.Session.bulk_save_objects` where it would write a
+    wrong identity key when using ``return_defaults=True``.
+    The wrong identity key could lead to an index error when entities are then pickled.
index 5d2558d9530f577b9f0c8ecc434cd2341b7dcd38..2ed6a4beaacde576add7e5109d195ebc0ac7fda5 100644 (file)
@@ -76,6 +76,7 @@ def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
@@ -89,6 +90,7 @@ def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
@@ -101,6 +103,7 @@ def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
@@ -218,6 +221,7 @@ def _bulk_insert(
             state.key = (
                 identity_cls,
                 tuple([dict_[key] for key in identity_props]),
+                None,
             )
 
     if use_orm_insert_stmt is not None:
@@ -230,6 +234,7 @@ def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Literal[None] = ...,
@@ -242,6 +247,7 @@ def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Optional[dml.Update] = ...,
@@ -253,6 +259,7 @@ def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
+    *,
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Optional[dml.Update] = None,
index acc6895e86f36fad32ab17ea81537a39650127ae..a4bf7c1cecf46f87dfd8348e64e1de6d858fecfc 100644 (file)
@@ -4574,11 +4574,11 @@ class Session(_SessionClassMethods, EventTarget):
             self._bulk_save_mappings(
                 mapper,
                 states,
-                isupdate,
-                True,
-                return_defaults,
-                update_changed_only,
-                False,
+                isupdate=isupdate,
+                isstates=True,
+                return_defaults=return_defaults,
+                update_changed_only=update_changed_only,
+                render_nulls=False,
             )
 
     def bulk_insert_mappings(
@@ -4657,11 +4657,11 @@ class Session(_SessionClassMethods, EventTarget):
         self._bulk_save_mappings(
             mapper,
             mappings,
-            False,
-            False,
-            return_defaults,
-            False,
-            render_nulls,
+            isupdate=False,
+            isstates=False,
+            return_defaults=return_defaults,
+            update_changed_only=False,
+            render_nulls=render_nulls,
         )
 
     def bulk_update_mappings(
@@ -4703,13 +4703,20 @@ class Session(_SessionClassMethods, EventTarget):
 
         """
         self._bulk_save_mappings(
-            mapper, mappings, True, False, False, False, False
+            mapper,
+            mappings,
+            isupdate=True,
+            isstates=False,
+            return_defaults=False,
+            update_changed_only=False,
+            render_nulls=False,
         )
 
     def _bulk_save_mappings(
         self,
         mapper: Mapper[_O],
         mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+        *,
         isupdate: bool,
         isstates: bool,
         return_defaults: bool,
@@ -4726,17 +4733,17 @@ class Session(_SessionClassMethods, EventTarget):
                     mapper,
                     mappings,
                     transaction,
-                    isstates,
-                    update_changed_only,
+                    isstates=isstates,
+                    update_changed_only=update_changed_only,
                 )
             else:
                 bulk_persistence._bulk_insert(
                     mapper,
                     mappings,
                     transaction,
-                    isstates,
-                    return_defaults,
-                    render_nulls,
+                    isstates=isstates,
+                    return_defaults=return_defaults,
+                    render_nulls=render_nulls,
                 )
             transaction.commit()
 
index 62b435e9cbfe8e0d24adcd508c30b10180f98848..4d24a52ecebb6290026fba65b9d41dba8d2deedb 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import FetchedValue
 from sqlalchemy import ForeignKey
 from sqlalchemy import Identity
 from sqlalchemy import insert
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -147,6 +148,17 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
         if statement_type == "save_objects":
             eq_(objects[0].__dict__["id"], 1)
 
+    def test_bulk_save_objects_defaults_key(self):
+        User = self.classes.User
+
+        pes = [User(name=f"foo{i}") for i in range(3)]
+        s = fixture_session()
+        s.bulk_save_objects(pes, return_defaults=True)
+        key = inspect(pes[0]).key
+
+        s.commit()
+        eq_(inspect(s.get(User, 1)).key, key)
+
     def test_bulk_save_mappings_preserve_order(self):
         (User,) = self.classes("User")
 
index 96dec4a60b70d501cafcf681405061d86f6aec07..18904cc38611f7b65d6fcda27f633fee374fa4af 100644 (file)
@@ -654,6 +654,17 @@ class PickleTest(fixtures.MappedTest):
             )
             is_not_none(collections.collection_adapter(repickled.addresses))
 
+    def test_bulk_save_objects_defaults_pickle(self):
+        "Test for #11332"
+        users = self.tables.users
+
+        self.mapper_registry.map_imperatively(User, users)
+        pes = [User(name=f"foo{i}") for i in range(3)]
+        s = fixture_session()
+        s.bulk_save_objects(pes, return_defaults=True)
+        state = pickle.dumps(pes)
+        pickle.loads(state)
+
 
 class OptionsTest(_Polymorphic):
     def test_options_of_type(self):