]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pickle mutable parents according to key
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Jun 2022 19:41:31 +0000 (15:41 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Jun 2022 20:16:43 +0000 (16:16 -0400)
Fixed bug in :class:`.Mutable` where pickling and unpickling of an ORM
mapped instance would not correctly restore state for mappings that
contained multiple :class:`.Mutable`-enabled attributes.

Fixes: #8133
Change-Id: I508763e0df0d7a624e1169f9a46d7f25404add1e
(cherry picked from commit 69020e416d9836fcc0bc99fcf008563263fb86f3)

doc/build/changelog/unreleased_14/8133.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mutable.py
test/ext/test_mutable.py

diff --git a/doc/build/changelog/unreleased_14/8133.rst b/doc/build/changelog/unreleased_14/8133.rst
new file mode 100644 (file)
index 0000000..36da8ad
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, ext
+    :tickets: 8133
+
+    Fixed bug in :class:`.Mutable` where pickling and unpickling of an ORM
+    mapped instance would not correctly restore state for mappings that
+    contained multiple :class:`.Mutable`-enabled attributes.
index b5217a426779c6af08d8fbe696477c55a51d9783..934ac37a0560a82bc31c8e677cab5b53478641ee 100644 (file)
@@ -354,6 +354,7 @@ pickling process of the parent's object-relational state so that the
 :meth:`MutableBase._parents` collection is restored to all ``Point`` objects.
 
 """
+from collections import defaultdict
 import weakref
 
 from .. import event
@@ -496,12 +497,12 @@ class MutableBase(object):
             val = state.dict.get(key, None)
             if val is not None:
                 if "ext.mutable.values" not in state_dict:
-                    state_dict["ext.mutable.values"] = []
-                state_dict["ext.mutable.values"].append(val)
+                    state_dict["ext.mutable.values"] = defaultdict(list)
+                state_dict["ext.mutable.values"][key].append(val)
 
         def unpickle(state, state_dict):
             if "ext.mutable.values" in state_dict:
-                for val in state_dict["ext.mutable.values"]:
+                for val in state_dict["ext.mutable.values"][key]:
                     val._parents[state] = key
 
         event.listen(parent_cls, "load", load, raw=True, propagate=True)
index 1d88deb7a0e63b1317237ba252873d1609e574aa..ff167b25365f0e36acf1226c0de37816fa75ab14 100644 (file)
@@ -4,7 +4,9 @@ import pickle
 from sqlalchemy import event
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
+from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import util
@@ -16,6 +18,7 @@ from sqlalchemy.orm import attributes
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import Session
 from sqlalchemy.orm.instrumentation import ClassManager
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.testing import assert_raises
@@ -41,6 +44,10 @@ class SubFoo(Foo):
     pass
 
 
+class Foo2(fixtures.BasicEntity):
+    pass
+
+
 class FooWithEq(object):
     def __init__(self, **kw):
         for k in kw:
@@ -102,6 +109,58 @@ class _MutableDictTestFixture(object):
         ClassManager.dispatch._clear()
 
 
+class MiscTest(fixtures.TestBase):
+    @testing.combinations(True, False, argnames="pickleit")
+    def test_pickle_parent_multi_attrs(self, registry, connection, pickleit):
+        """test #8133"""
+
+        local_foo = Table(
+            "lf",
+            registry.metadata,
+            Column("id", Integer, primary_key=True),
+            Column("j1", MutableDict.as_mutable(PickleType)),
+            Column("j2", MutableDict.as_mutable(PickleType)),
+            Column("j3", MutableDict.as_mutable(PickleType)),
+            Column("j4", MutableDict.as_mutable(PickleType)),
+        )
+
+        registry.map_imperatively(Foo2, local_foo)
+        registry.metadata.create_all(connection)
+
+        with Session(connection) as sess:
+
+            data = dict(
+                j1={"a": 1},
+                j2={"b": 2},
+                j3={"c": 3},
+                j4={"d": 4},
+            )
+            lf = Foo2(**data)
+            sess.add(lf)
+            sess.commit()
+
+        all_attrs = {"j1", "j2", "j3", "j4"}
+        for attr in all_attrs:
+            for loads, dumps in picklers():
+                with Session(connection) as sess:
+                    f1 = sess.scalars(select(Foo2)).first()
+                    if pickleit:
+                        f2 = loads(dumps(f1))
+                    else:
+                        f2 = f1
+
+                existing_dict = getattr(f2, attr)
+                existing_dict["q"] = "c"
+                eq_(
+                    inspect(f2).attrs[attr].history,
+                    ([existing_dict], (), ()),
+                )
+                for other_attr in all_attrs.difference([attr]):
+                    a = inspect(f2).attrs[other_attr].history
+                    b = ((), [data[other_attr]], ())
+                    eq_(a, b)
+
+
 class _MutableDictTestBase(_MutableDictTestFixture):
     run_define_tables = "each"