]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Serialize the context dictionary in Load objects
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Aug 2019 15:44:09 +0000 (11:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Aug 2019 18:32:32 +0000 (14:32 -0400)
Fixed bug where :class:`.Load` objects were not pickleable due to
mapper/relationship state in the internal context dictionary.  These
objects are now converted to picklable using similar techniques as that of
other elements within the loader option system that have long been
serializable.

Fixes: #4823
Change-Id: Id2a0d8b640ac475c86d6416ad540671e66d410e5
(cherry picked from commit cd2ccee9d807eb601db2d242ce4cdfa8acb98111)

doc/build/changelog/unreleased_13/4823.rst [new file with mode: 0644]
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/strategy_options.py
test/orm/test_pickled.py
test/orm/test_utils.py

diff --git a/doc/build/changelog/unreleased_13/4823.rst b/doc/build/changelog/unreleased_13/4823.rst
new file mode 100644 (file)
index 0000000..7541330
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 4823
+
+    Fixed bug where :class:`.Load` objects were not pickleable due to
+    mapper/relationship state in the internal context dictionary.  These
+    objects are now converted to picklable using similar techniques as that of
+    other elements within the loader option system that have long been
+    serializable.
index 4803dbecb99f246e0838b3805e6e72062f663ce4..2f680a3a162e4955cbf0f85182b782d10f961abc 100644 (file)
@@ -100,8 +100,8 @@ class PathRegistry(object):
     def __reduce__(self):
         return _unreduce_path, (self.serialize(),)
 
-    def serialize(self):
-        path = self.path
+    @classmethod
+    def _serialize_path(cls, path):
         return list(
             zip(
                 [m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
@@ -110,10 +110,7 @@ class PathRegistry(object):
         )
 
     @classmethod
-    def deserialize(cls, path):
-        if path is None:
-            return None
-
+    def _deserialize_path(cls, path):
         p = tuple(
             chain(
                 *[
@@ -129,6 +126,35 @@ class PathRegistry(object):
         )
         if p and p[-1] is None:
             p = p[0:-1]
+        return p
+
+    @classmethod
+    def serialize_context_dict(cls, dict_, tokens):
+        return [
+            ((key, cls._serialize_path(path)), value)
+            for (key, path), value in [
+                (k, v)
+                for k, v in dict_.items()
+                if isinstance(k, tuple) and k[0] in tokens
+            ]
+        ]
+
+    @classmethod
+    def deserialize_context_dict(cls, serialized):
+        return util.OrderedDict(
+            ((key, tuple(cls._deserialize_path(path))), value)
+            for (key, path), value in serialized
+        )
+
+    def serialize(self):
+        path = self.path
+        return self._serialize_path(path)
+
+    @classmethod
+    def deserialize(cls, path):
+        if path is None:
+            return None
+        p = cls._deserialize_path(path)
         return cls.coerce(p)
 
     @classmethod
index 178342d227893fef813d913c960ff5d1ddbfa19e..3928cc5a495cdaca3edb80230f678251c1b8a2f5 100644 (file)
@@ -464,12 +464,16 @@ class Load(Generative, MapperOption):
 
     def __getstate__(self):
         d = self.__dict__.copy()
+        d["context"] = PathRegistry.serialize_context_dict(
+            d["context"], ("loader",)
+        )
         d["path"] = self.path.serialize()
         return d
 
     def __setstate__(self, state):
         self.__dict__.update(state)
         self.path = PathRegistry.deserialize(self.path)
+        self.context = PathRegistry.deserialize_context_dict(self.context)
 
     def _chop_path(self, to_chop, path):
         i = -1
index f2ef8e2e9a7b0c461c5b06109e29c06fb1173e1f..399c881ac5bfb9ea4eb61827da2619a255a1ece9 100644 (file)
@@ -96,6 +96,32 @@ class PickleTest(fixtures.MappedTest):
             test_needs_fk=True,
         )
 
+    def _option_test_fixture(self):
+        users, addresses, dingalings = (
+            self.tables.users,
+            self.tables.addresses,
+            self.tables.dingalings,
+        )
+
+        mapper(
+            User,
+            users,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
+        mapper(
+            Address,
+            addresses,
+            properties={"dingaling": relationship(Dingaling)},
+        )
+        mapper(Dingaling, dingalings)
+        sess = create_session()
+        u1 = User(name="ed")
+        u1.addresses.append(Address(email_address="ed@bar.com"))
+        sess.add(u1)
+        sess.flush()
+        sess.expunge_all()
+        return sess, User, Address, Dingaling
+
     def test_transient(self):
         users, addresses = (self.tables.users, self.tables.addresses)
 
@@ -418,40 +444,65 @@ class PickleTest(fixtures.MappedTest):
         eq_(sa.inspect(u2).info["some_key"], "value")
 
     @testing.requires.non_broken_pickle
-    def test_options_with_descriptors(self):
-        users, addresses, dingalings = (
-            self.tables.users,
-            self.tables.addresses,
-            self.tables.dingalings,
-        )
+    def test_unbound_options(self):
+        sess, User, Address, Dingaling = self._option_test_fixture()
 
-        mapper(
-            User,
-            users,
-            properties={"addresses": relationship(Address, backref="user")},
-        )
-        mapper(
-            Address,
-            addresses,
-            properties={"dingaling": relationship(Dingaling)},
-        )
-        mapper(Dingaling, dingalings)
-        sess = create_session()
-        u1 = User(name="ed")
-        u1.addresses.append(Address(email_address="ed@bar.com"))
-        sess.add(u1)
-        sess.flush()
-        sess.expunge_all()
+        for opt in [
+            sa.orm.joinedload(User.addresses),
+            sa.orm.joinedload("addresses"),
+            sa.orm.defer("name"),
+            sa.orm.defer(User.name),
+            sa.orm.joinedload("addresses").joinedload(Address.dingaling),
+        ]:
+            opt2 = pickle.loads(pickle.dumps(opt))
+            eq_(opt.path, opt2.path)
+
+        u1 = sess.query(User).options(opt).first()
+        pickle.loads(pickle.dumps(u1))
+
+    @testing.requires.non_broken_pickle
+    def test_bound_options(self):
+        sess, User, Address, Dingaling = self._option_test_fixture()
+
+        for opt in [
+            sa.orm.Load(User).joinedload(User.addresses),
+            sa.orm.Load(User).joinedload("addresses"),
+            sa.orm.Load(User).defer("name"),
+            sa.orm.Load(User).defer(User.name),
+            sa.orm.Load(User)
+            .joinedload("addresses")
+            .joinedload(Address.dingaling),
+            sa.orm.Load(User)
+            .joinedload("addresses", innerjoin=True)
+            .joinedload(Address.dingaling),
+        ]:
+            opt2 = pickle.loads(pickle.dumps(opt))
+            eq_(opt.path, opt2.path)
+            eq_(opt.context.keys(), opt2.context.keys())
+            eq_(opt.local_opts, opt2.local_opts)
+
+        u1 = sess.query(User).options(opt).first()
+        pickle.loads(pickle.dumps(u1))
+
+    @testing.requires.non_broken_pickle
+    def test_became_bound_options(self):
+        sess, User, Address, Dingaling = self._option_test_fixture()
 
         for opt in [
             sa.orm.joinedload(User.addresses),
             sa.orm.joinedload("addresses"),
             sa.orm.defer("name"),
             sa.orm.defer(User.name),
-            sa.orm.joinedload("addresses"Address.dingaling),
+            sa.orm.joinedload("addresses").joinedload(Address.dingaling),
         ]:
+            q = sess.query(User).options(opt)
+            opt = [
+                v for v in q._attributes.values() if isinstance(v, sa.orm.Load)
+            ][0]
+
             opt2 = pickle.loads(pickle.dumps(opt))
             eq_(opt.path, opt2.path)
+            eq_(opt.local_opts, opt2.local_opts)
 
         u1 = sess.query(User).options(opt).first()
         pickle.loads(pickle.dumps(u1))
index 45dd2e38bda72ab65e17f41e62a5cdc780e04404..e47fc3f267a2e426e15eb4a458afdf6ea8e5d4bc 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import Table
+from sqlalchemy import util
 from sqlalchemy.ext.hybrid import hybrid_method
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import aliased
@@ -743,6 +744,62 @@ class PathRegistryTest(_fixtures.FixtureTest):
         eq_(p2.serialize(), [(User, "addresses"), (Address, None)])
         eq_(p3.serialize(), [(User, "addresses")])
 
+    def test_serialize_context_dict(self):
+        reg = util.OrderedDict()
+        umapper = inspect(self.classes.User)
+        amapper = inspect(self.classes.Address)
+
+        p1 = PathRegistry.coerce((umapper, umapper.attrs.addresses))
+        p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper))
+        p3 = PathRegistry.coerce((amapper, amapper.attrs.email_address))
+
+        p1.set(reg, "p1key", "p1value")
+        p2.set(reg, "p2key", "p2value")
+        p3.set(reg, "p3key", "p3value")
+        eq_(
+            reg,
+            {
+                ("p1key", p1.path): "p1value",
+                ("p2key", p2.path): "p2value",
+                ("p3key", p3.path): "p3value",
+            },
+        )
+
+        serialized = PathRegistry.serialize_context_dict(
+            reg, ("p1key", "p2key")
+        )
+        eq_(
+            serialized,
+            [
+                (("p1key", p1.serialize()), "p1value"),
+                (("p2key", p2.serialize()), "p2value"),
+            ],
+        )
+
+    def test_deseralize_context_dict(self):
+        umapper = inspect(self.classes.User)
+        amapper = inspect(self.classes.Address)
+
+        p1 = PathRegistry.coerce((umapper, umapper.attrs.addresses))
+        p2 = PathRegistry.coerce((umapper, umapper.attrs.addresses, amapper))
+        p3 = PathRegistry.coerce((amapper, amapper.attrs.email_address))
+
+        serialized = [
+            (("p1key", p1.serialize()), "p1value"),
+            (("p2key", p2.serialize()), "p2value"),
+            (("p3key", p3.serialize()), "p3value"),
+        ]
+        deserialized = PathRegistry.deserialize_context_dict(serialized)
+
+        eq_(
+            deserialized,
+            {
+                ("p1key", p1.path): "p1value",
+                ("p2key", p2.path): "p2value",
+                ("p3key", p3.path): "p3value",
+            },
+        )
+
     def test_deseralize(self):
         User = self.classes.User
         Address = self.classes.Address