From: Mike Bayer Date: Tue, 14 Jun 2022 19:41:31 +0000 (-0400) Subject: pickle mutable parents according to key X-Git-Tag: rel_2_0_0b1~235^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=042bccd2e4d6fbfcdf70ede760b29f78771f4b22;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pickle mutable parents according to key Fixed bug in :class:`.Mutable` where pickling and unpickling of an ORM mapped instance would not correctly restore state for mappings that contained multiple :class:`.Mutable`-enabled attributes. Fixes: #8133 Change-Id: I508763e0df0d7a624e1169f9a46d7f25404add1e --- diff --git a/doc/build/changelog/unreleased_14/8133.rst b/doc/build/changelog/unreleased_14/8133.rst new file mode 100644 index 0000000000..36da8ad8e6 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8133.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, ext + :tickets: 8133 + + Fixed bug in :class:`.Mutable` where pickling and unpickling of an ORM + mapped instance would not correctly restore state for mappings that + contained multiple :class:`.Mutable`-enabled attributes. diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index ba7f9b0a41..1ae5aee8b3 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -355,6 +355,7 @@ pickling process of the parent's object-relational state so that the :meth:`MutableBase._parents` collection is restored to all ``Point`` objects. """ +from collections import defaultdict import weakref from .. import event @@ -496,12 +497,12 @@ class MutableBase: val = state.dict.get(key, None) if val is not None: if "ext.mutable.values" not in state_dict: - state_dict["ext.mutable.values"] = [] - state_dict["ext.mutable.values"].append(val) + state_dict["ext.mutable.values"] = defaultdict(list) + state_dict["ext.mutable.values"][key].append(val) def unpickle(state, state_dict): if "ext.mutable.values" in state_dict: - for val in state_dict["ext.mutable.values"]: + for val in state_dict["ext.mutable.values"][key]: val._parents[state] = key event.listen(parent_cls, "load", load, raw=True, propagate=True) diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index 81cecb08b3..1ef9afbd87 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -4,7 +4,9 @@ import pickle from sqlalchemy import event from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.mutable import MutableComposite @@ -15,6 +17,7 @@ from sqlalchemy.orm import attributes from sqlalchemy.orm import column_property from sqlalchemy.orm import composite from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import Session from sqlalchemy.orm.instrumentation import ClassManager from sqlalchemy.orm.mapper import Mapper from sqlalchemy.testing import assert_raises @@ -40,6 +43,10 @@ class SubFoo(Foo): pass +class Foo2(fixtures.BasicEntity): + pass + + class FooWithEq: def __init__(self, **kw): for k in kw: @@ -101,6 +108,58 @@ class _MutableDictTestFixture: ClassManager.dispatch._clear() +class MiscTest(fixtures.TestBase): + @testing.combinations(True, False, argnames="pickleit") + def test_pickle_parent_multi_attrs(self, registry, connection, pickleit): + """test #8133""" + + local_foo = Table( + "lf", + registry.metadata, + Column("id", Integer, primary_key=True), + Column("j1", MutableDict.as_mutable(PickleType)), + Column("j2", MutableDict.as_mutable(PickleType)), + Column("j3", MutableDict.as_mutable(PickleType)), + Column("j4", MutableDict.as_mutable(PickleType)), + ) + + registry.map_imperatively(Foo2, local_foo) + registry.metadata.create_all(connection) + + with Session(connection) as sess: + + data = dict( + j1={"a": 1}, + j2={"b": 2}, + j3={"c": 3}, + j4={"d": 4}, + ) + lf = Foo2(**data) + sess.add(lf) + sess.commit() + + all_attrs = {"j1", "j2", "j3", "j4"} + for attr in all_attrs: + for loads, dumps in picklers(): + with Session(connection) as sess: + f1 = sess.scalars(select(Foo2)).first() + if pickleit: + f2 = loads(dumps(f1)) + else: + f2 = f1 + + existing_dict = getattr(f2, attr) + existing_dict["q"] = "c" + eq_( + inspect(f2).attrs[attr].history, + ([existing_dict], (), ()), + ) + for other_attr in all_attrs.difference([attr]): + a = inspect(f2).attrs[other_attr].history + b = ((), [data[other_attr]], ()) + eq_(a, b) + + class _MutableDictTestBase(_MutableDictTestFixture): run_define_tables = "each"