git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Pickling fixes for ORM / Core
author: Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Oct 2021 16:00:52 +0000 (12:00 -0400)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Oct 2021 17:52:23 +0000 (13:52 -0400)
Fixed regression where ORM loaded objects could not be pickled in cases
where loader options making use of ``"*"`` were used in certain
combinations, such as combining the :func:`_orm.joinedload` loader strategy
with ``raiseload('*')`` of sub-elements.

Fixes: #7134
Fixed issue where SQL queries using the
:meth:`_functions.FunctionElement.within_group` construct could not be
pickled, typically when using the ``sqlalchemy.ext.serializer`` extension
but also for general generic pickling.

Fixes: #6520
Change-Id: Ib73fd49c875e6da9898493c190f610e68b88ec72

doc/build/changelog/unreleased_14/6520.rst [new file with mode: 0644]
doc/build/changelog/unreleased_14/7134.rst [new file with mode: 0644]
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/sql/elements.py
test/orm/test_pickled.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_14/6520.rst b/doc/build/changelog/unreleased_14/6520.rst
new file mode 100644 (file)
index 0000000..88defe7
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6520
+
+    Fixed issue where SQL queries using the
+    :meth:`_functions.FunctionElement.within_group` construct could not be
+    pickled, typically when using the ``sqlalchemy.ext.serializer`` extension
+    but also for general generic pickling.
diff --git a/doc/build/changelog/unreleased_14/7134.rst b/doc/build/changelog/unreleased_14/7134.rst
new file mode 100644 (file)
index 0000000..e785db1
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 7134
+
+    Fixed regression where ORM loaded objects could not be pickled in cases
+    where loader options making use of ``"*"`` were used in certain
+    combinations, such as combining the :func:`_orm.joinedload` loader strategy
+    with ``raiseload('*')`` of sub-elements.
+
index 0327605d529851077815a040e181146881faa44e..47bec83c9451979e8ac608b8b457b2ae83135b12 100644 (file)
@@ -11,7 +11,7 @@
 from itertools import chain
 import logging
 
-from .base import class_mapper
+from . import base as orm_base
 from .. import exc
 from .. import inspection
 from .. import util
@@ -66,7 +66,7 @@ class PathRegistry(HasCacheKey):
 
     def __eq__(self, other):
         try:
-            return other is not None and self.path == other.path
+            return other is not None and self.path == other._path_for_compare
         except AttributeError:
             util.warn(
                 "Comparison of PathRegistry to %r is not supported"
@@ -76,7 +76,7 @@ class PathRegistry(HasCacheKey):
 
     def __ne__(self, other):
         try:
-            return other is None or self.path != other.path
+            return other is None or self.path != other._path_for_compare
         except AttributeError:
             util.warn(
                 "Comparison of PathRegistry to %r is not supported"
@@ -84,6 +84,10 @@ class PathRegistry(HasCacheKey):
             )
             return True
 
+    @property
+    def _path_for_compare(self):
+        return self.path
+
     def set(self, attributes, key, value):
         log.debug("set '%s' on path '%s' to '%s'", key, self, value)
         attributes[(key, self.natural_path)] = value
@@ -131,21 +135,45 @@ class PathRegistry(HasCacheKey):
     def _serialize_path(cls, path):
         return list(
             zip(
-                [m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
-                [path[i].key for i in range(1, len(path), 2)] + [None],
+                [
+                    m.class_ if (m.is_mapper or m.is_aliased_class) else str(m)
+                    for m in [path[i] for i in range(0, len(path), 2)]
+                ],
+                [
+                    path[i].key if (path[i].is_property) else str(path[i])
+                    for i in range(1, len(path), 2)
+                ]
+                + [None],
             )
         )
 
     @classmethod
     def _deserialize_path(cls, path):
+        def _deserialize_mapper_token(mcls):
+            return (
+                # note: we likely dont want configure=True here however
+                # this is maintained at the moment for backwards compatibility
+                orm_base._inspect_mapped_class(mcls, configure=True)
+                if mcls not in PathToken._intern
+                else PathToken._intern[mcls]
+            )
+
+        def _deserialize_key_token(mcls, key):
+            if key is None:
+                return None
+            elif key in PathToken._intern:
+                return PathToken._intern[key]
+            else:
+                return orm_base._inspect_mapped_class(
+                    mcls, configure=True
+                ).attrs[key]
+
         p = tuple(
             chain(
                 *[
                     (
-                        class_mapper(mcls),
-                        class_mapper(mcls).attrs[key]
-                        if key is not None
-                        else None,
+                        _deserialize_mapper_token(mcls),
+                        _deserialize_key_token(mcls, key),
                     )
                     for mcls, key in path
                 ]
@@ -224,13 +252,16 @@ class RootRegistry(PathRegistry):
     is_root = True
 
     def __getitem__(self, entity):
-        return entity._path_registry
+        if entity in PathToken._intern:
+            return PathToken._intern[entity]
+        else:
+            return entity._path_registry
 
 
 PathRegistry.root = RootRegistry()
 
 
-class PathToken(HasCacheKey, str):
+class PathToken(orm_base.InspectionAttr, HasCacheKey, str):
     """cacheable string token"""
 
     _intern = {}
@@ -238,6 +269,10 @@ class PathToken(HasCacheKey, str):
     def _gen_cache_key(self, anon_map, bindparams):
         return (str(self),)
 
+    @property
+    def _path_for_compare(self):
+        return None
+
     @classmethod
     def intern(cls, strvalue):
         if strvalue in cls._intern:
@@ -445,6 +480,8 @@ class AbstractEntityRegistry(PathRegistry):
     def __getitem__(self, entity):
         if isinstance(entity, (int, slice)):
             return self.path[entity]
+        elif entity in PathToken._intern:
+            return TokenRegistry(self, PathToken._intern[entity])
         else:
             return PropRegistry(self, entity)
 
index 6f1756af341c77f76f979e27fb8fa54d5cc51619..e49665019a36940ca8af0fe1432df0e78753c359 100644 (file)
@@ -4319,6 +4319,9 @@ class WithinGroup(ColumnElement):
                 *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
             )
 
+    def __reduce__(self):
+        return self.__class__, (self.element,) + tuple(self.order_by)
+
     def over(self, partition_by=None, order_by=None, range_=None, rows=None):
         """Produce an OVER clause against this :class:`.WithinGroup`
         construct.
index 8f4683d15c6eb4d9c38621e442c426256f565b68..e33ec9a5c5a52e2d23edff91916149a3053e715a 100644 (file)
@@ -494,42 +494,50 @@ class PickleTest(fixtures.MappedTest):
         eq_(sa.inspect(u2).info["some_key"], "value")
 
     @testing.requires.non_broken_pickle
-    def test_unbound_options(self):
+    @testing.combinations(
+        lambda User: sa.orm.joinedload(User.addresses),
+        lambda: sa.orm.joinedload("addresses"),
+        lambda: sa.orm.defer("name"),
+        lambda User: sa.orm.defer(User.name),
+        lambda Address: sa.orm.joinedload("addresses").joinedload(
+            Address.dingaling
+        ),
+        lambda: sa.orm.joinedload("addresses").raiseload("*"),
+        lambda: sa.orm.raiseload("*"),
+    )
+    def test_unbound_options(self, test_case):
         sess, User, Address, Dingaling = self._option_test_fixture()
 
-        for opt in [
-            sa.orm.joinedload(User.addresses),
-            sa.orm.joinedload("addresses"),
-            sa.orm.defer("name"),
-            sa.orm.defer(User.name),
-            sa.orm.joinedload("addresses").joinedload(Address.dingaling),
-        ]:
-            opt2 = pickle.loads(pickle.dumps(opt))
-            eq_(opt.path, opt2.path)
+        opt = testing.resolve_lambda(test_case, User=User, Address=Address)
+        opt2 = pickle.loads(pickle.dumps(opt))
+        eq_(opt.path, opt2.path)
 
         u1 = sess.query(User).options(opt).first()
         pickle.loads(pickle.dumps(u1))
 
     @testing.requires.non_broken_pickle
-    def test_bound_options(self):
+    @testing.combinations(
+        lambda User: sa.orm.Load(User).joinedload(User.addresses),
+        lambda User: sa.orm.Load(User).joinedload("addresses"),
+        lambda User: sa.orm.Load(User).joinedload("addresses").raiseload("*"),
+        lambda User: sa.orm.Load(User).defer("name"),
+        lambda User: sa.orm.Load(User).defer(User.name),
+        lambda User, Address: sa.orm.Load(User)
+        .joinedload("addresses")
+        .joinedload(Address.dingaling),
+        lambda User, Address: sa.orm.Load(User)
+        .joinedload("addresses", innerjoin=True)
+        .joinedload(Address.dingaling),
+    )
+    def test_bound_options(self, test_case):
         sess, User, Address, Dingaling = self._option_test_fixture()
 
-        for opt in [
-            sa.orm.Load(User).joinedload(User.addresses),
-            sa.orm.Load(User).joinedload("addresses"),
-            sa.orm.Load(User).defer("name"),
-            sa.orm.Load(User).defer(User.name),
-            sa.orm.Load(User)
-            .joinedload("addresses")
-            .joinedload(Address.dingaling),
-            sa.orm.Load(User)
-            .joinedload("addresses", innerjoin=True)
-            .joinedload(Address.dingaling),
-        ]:
-            opt2 = pickle.loads(pickle.dumps(opt))
-            eq_(opt.path, opt2.path)
-            eq_(opt.context.keys(), opt2.context.keys())
-            eq_(opt.local_opts, opt2.local_opts)
+        opt = testing.resolve_lambda(test_case, User=User, Address=Address)
+
+        opt2 = pickle.loads(pickle.dumps(opt))
+        eq_(opt.path, opt2.path)
+        eq_(opt.context.keys(), opt2.context.keys())
+        eq_(opt.local_opts, opt2.local_opts)
 
         u1 = sess.query(User).options(opt).first()
         pickle.loads(pickle.dumps(u1))
index 43b505c997cdc3f2725af7df007d6dae8ebd1374..f3fb724c073fd3eae8a7058f4a3c7d9fe49ce171 100644 (file)
@@ -542,6 +542,31 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "row_number() OVER ()",
         )
 
+    def test_pickle_within_group(self):
+        """test #6520"""
+
+        # TODO: the test/sql package lacks a comprehensive pickling
+        # test suite even though there are __reduce__ methods in several
+        # places in sql/elements.py.   likely as part of
+        # test/sql/test_compare.py might be a place this can happen but
+        # this still relies upon a strategy for table metadata as we have
+        # in serializer.
+
+        f1 = func.percentile_cont(literal(1)).within_group()
+
+        self.assert_compile(
+            util.pickle.loads(util.pickle.dumps(f1)),
+            "percentile_cont(:param_1) WITHIN GROUP (ORDER BY )",
+        )
+
+        f1 = func.percentile_cont(literal(1)).within_group(
+            column("q"), column("p").desc()
+        )
+        self.assert_compile(
+            util.pickle.loads(util.pickle.dumps(f1)),
+            "percentile_cont(:param_1) WITHIN GROUP (ORDER BY q, p DESC)",
+        )
+
     def test_functions_with_cols(self):
         users = table(
             "users", column("id"), column("name"), column("fullname")