From f4188f571df11905bb4aab107b45298c7130a0ec Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 14 Jun 2022 15:41:31 -0400 Subject: [PATCH] pickle mutable parents according to key Fixed bug in :class:`.Mutable` where pickling and unpickling of an ORM mapped instance would not correctly restore state for mappings that contained multiple :class:`.Mutable`-enabled attributes. Fixes: #8133 Change-Id: I508763e0df0d7a624e1169f9a46d7f25404add1e (cherry picked from commit 69020e416d9836fcc0bc99fcf008563263fb86f3) --- doc/build/changelog/unreleased_14/8133.rst | 7 +++ lib/sqlalchemy/ext/mutable.py | 7 +-- test/ext/test_mutable.py | 59 ++++++++++++++++++++++ 3 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/8133.rst diff --git a/doc/build/changelog/unreleased_14/8133.rst b/doc/build/changelog/unreleased_14/8133.rst new file mode 100644 index 0000000000..36da8ad8e6 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8133.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, ext + :tickets: 8133 + + Fixed bug in :class:`.Mutable` where pickling and unpickling of an ORM + mapped instance would not correctly restore state for mappings that + contained multiple :class:`.Mutable`-enabled attributes. diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index b5217a4267..934ac37a05 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -354,6 +354,7 @@ pickling process of the parent's object-relational state so that the :meth:`MutableBase._parents` collection is restored to all ``Point`` objects. """ +from collections import defaultdict import weakref from .. import event @@ -496,12 +497,12 @@ class MutableBase(object): val = state.dict.get(key, None) if val is not None: if "ext.mutable.values" not in state_dict: - state_dict["ext.mutable.values"] = [] - state_dict["ext.mutable.values"].append(val) + state_dict["ext.mutable.values"] = defaultdict(list) + state_dict["ext.mutable.values"][key].append(val) def unpickle(state, state_dict): if "ext.mutable.values" in state_dict: - for val in state_dict["ext.mutable.values"]: + for val in state_dict["ext.mutable.values"][key]: val._parents[state] = key event.listen(parent_cls, "load", load, raw=True, propagate=True) diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 1d88deb7a0..ff167b2536 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -4,7 +4,9 @@ import pickle from sqlalchemy import event from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import util @@ -16,6 +18,7 @@ from sqlalchemy.orm import attributes from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import Session from sqlalchemy.orm.instrumentation import ClassManager from sqlalchemy.orm.mapper import Mapper from sqlalchemy.testing import assert_raises @@ -41,6 +44,10 @@ class SubFoo(Foo): pass +class Foo2(fixtures.BasicEntity): + pass + + class FooWithEq(object): def __init__(self, **kw): for k in kw: @@ -102,6 +109,58 @@ class _MutableDictTestFixture(object): ClassManager.dispatch._clear() +class MiscTest(fixtures.TestBase): + @testing.combinations(True, False, argnames="pickleit") + def test_pickle_parent_multi_attrs(self, registry, connection, pickleit): + """test #8133""" + + local_foo = Table( + "lf", + registry.metadata, + Column("id", Integer, primary_key=True), + Column("j1", MutableDict.as_mutable(PickleType)), + Column("j2", MutableDict.as_mutable(PickleType)), + Column("j3", MutableDict.as_mutable(PickleType)), + Column("j4", MutableDict.as_mutable(PickleType)), + ) + + registry.map_imperatively(Foo2, local_foo) + registry.metadata.create_all(connection) + + with Session(connection) as sess: + + data = dict( + j1={"a": 1}, + j2={"b": 2}, + j3={"c": 3}, + j4={"d": 4}, + ) + lf = Foo2(**data) + sess.add(lf) + sess.commit() + + all_attrs = {"j1", "j2", "j3", "j4"} + for attr in all_attrs: + for loads, dumps in picklers(): + with Session(connection) as sess: + f1 = sess.scalars(select(Foo2)).first() + if pickleit: + f2 = loads(dumps(f1)) + else: + f2 = f1 + + existing_dict = getattr(f2, attr) + existing_dict["q"] = "c" + eq_( + inspect(f2).attrs[attr].history, + ([existing_dict], (), ()), + ) + for other_attr in all_attrs.difference([attr]): + a = inspect(f2).attrs[other_attr].history + b = ((), [data[other_attr]], ()) + eq_(a, b) + + class _MutableDictTestBase(_MutableDictTestFixture): run_define_tables = "each" -- 2.47.2