From: Mike Bayer
Date: Wed, 13 Oct 2021 16:00:52 +0000 (-0400)
Subject: Pickling fixes for ORM / Core
X-Git-Tag: rel_1_4_26~19^2
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=53ad3cf4e9b02f841fff960ec95870110f6c7bcb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

Pickling fixes for ORM / Core

Fixed regression where ORM loaded objects could not be pickled in cases
where loader options making use of ``"*"`` were used in certain
combinations, such as combining the :func:`_orm.joinedload` loader
strategy with ``raiseload('*')`` of sub-elements.

Fixes: #7134

Fixed issue where SQL queries using the
:meth:`_functions.FunctionElement.within_group` construct could not be
pickled, typically when using the ``sqlalchemy.ext.serializer``
extension but also for general generic pickling.

Fixes: #6520

Change-Id: Ib73fd49c875e6da9898493c190f610e68b88ec72
---

diff --git a/doc/build/changelog/unreleased_14/6520.rst b/doc/build/changelog/unreleased_14/6520.rst
new file mode 100644
index 0000000000..88defe7660
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/6520.rst
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6520
+
+    Fixed issue where SQL queries using the
+    :meth:`_functions.FunctionElement.within_group` construct could not be
+    pickled, typically when using the ``sqlalchemy.ext.serializer`` extension
+    but also for general generic pickling.
diff --git a/doc/build/changelog/unreleased_14/7134.rst b/doc/build/changelog/unreleased_14/7134.rst
new file mode 100644
index 0000000000..e785db17d9
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/7134.rst
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 7134
+
+    Fixed regression where ORM loaded objects could not be pickled in cases
+    where loader options making use of ``"*"`` were used in certain
+    combinations, such as combining the :func:`_orm.joinedload` loader
+    strategy with ``raiseload('*')`` of sub-elements.
+
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index 0327605d52..47bec83c94 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -11,7 +11,7 @@
 from itertools import chain
 import logging
 
-from .base import class_mapper
+from . import base as orm_base
 from .. import exc
 from .. import inspection
 from .. import util
@@ -66,7 +66,7 @@ class PathRegistry(HasCacheKey):
 
     def __eq__(self, other):
         try:
-            return other is not None and self.path == other.path
+            return other is not None and self.path == other._path_for_compare
         except AttributeError:
             util.warn(
                 "Comparison of PathRegistry to %r is not supported"
@@ -76,7 +76,7 @@ class PathRegistry(HasCacheKey):
 
     def __ne__(self, other):
         try:
-            return other is None or self.path != other.path
+            return other is None or self.path != other._path_for_compare
         except AttributeError:
             util.warn(
                 "Comparison of PathRegistry to %r is not supported"
@@ -84,6 +84,10 @@
             )
             return True
 
+    @property
+    def _path_for_compare(self):
+        return self.path
+
     def set(self, attributes, key, value):
         log.debug("set '%s' on path '%s' to '%s'", key, self, value)
         attributes[(key, self.natural_path)] = value
@@ -131,21 +135,45 @@
     def _serialize_path(cls, path):
         return list(
             zip(
-                [m.class_ for m in [path[i] for i in range(0, len(path), 2)]],
-                [path[i].key for i in range(1, len(path), 2)] + [None],
+                [
+                    m.class_ if (m.is_mapper or m.is_aliased_class) else str(m)
+                    for m in [path[i] for i in range(0, len(path), 2)]
+                ],
+                [
+                    path[i].key if (path[i].is_property) else str(path[i])
+                    for i in range(1, len(path), 2)
+                ]
+                + [None],
             )
         )
 
     @classmethod
     def _deserialize_path(cls, path):
+        def _deserialize_mapper_token(mcls):
+            return (
+                # note: we likely dont want configure=True here however
+                # this is maintained at the moment for backwards compatibility
+                orm_base._inspect_mapped_class(mcls, configure=True)
+                if mcls not in PathToken._intern
+                else PathToken._intern[mcls]
+            )
+
+        def _deserialize_key_token(mcls, key):
+            if key is None:
+                return None
+            elif key in PathToken._intern:
+                return PathToken._intern[key]
+            else:
+                return orm_base._inspect_mapped_class(
+                    mcls, configure=True
+                ).attrs[key]
+
         p = tuple(
             chain(
                 *[
                     (
-                        class_mapper(mcls),
-                        class_mapper(mcls).attrs[key]
-                        if key is not None
-                        else None,
+                        _deserialize_mapper_token(mcls),
+                        _deserialize_key_token(mcls, key),
                     )
                     for mcls, key in path
                 ]
             )
         )
@@ -224,13 +252,16 @@ class RootRegistry(PathRegistry):
     is_root = True
 
     def __getitem__(self, entity):
-        return entity._path_registry
+        if entity in PathToken._intern:
+            return PathToken._intern[entity]
+        else:
+            return entity._path_registry
 
 
 PathRegistry.root = RootRegistry()
 
 
-class PathToken(HasCacheKey, str):
+class PathToken(orm_base.InspectionAttr, HasCacheKey, str):
     """cacheable string token"""
 
     _intern = {}
@@ -238,6 +269,10 @@ class PathToken(HasCacheKey, str):
     def _gen_cache_key(self, anon_map, bindparams):
         return (str(self),)
 
+    @property
+    def _path_for_compare(self):
+        return None
+
     @classmethod
     def intern(cls, strvalue):
         if strvalue in cls._intern:
@@ -445,6 +480,8 @@ class AbstractEntityRegistry(PathRegistry):
     def __getitem__(self, entity):
         if isinstance(entity, (int, slice)):
             return self.path[entity]
+        elif entity in PathToken._intern:
+            return TokenRegistry(self, PathToken._intern[entity])
         else:
             return PropRegistry(self, entity)
 
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 6f1756af34..e49665019a 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -4319,6 +4319,9 @@ class WithinGroup(ColumnElement):
             *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
         )
 
+    def __reduce__(self):
+        return self.__class__, (self.element,) + tuple(self.order_by)
+
     def over(self, partition_by=None, order_by=None, range_=None, rows=None):
         """Produce an OVER clause against this :class:`.WithinGroup`
         construct.
diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py
index 8f4683d15c..e33ec9a5c5 100644
--- a/test/orm/test_pickled.py
+++ b/test/orm/test_pickled.py
@@ -494,42 +494,50 @@ class PickleTest(fixtures.MappedTest):
         eq_(sa.inspect(u2).info["some_key"], "value")
 
     @testing.requires.non_broken_pickle
-    def test_unbound_options(self):
+    @testing.combinations(
+        lambda User: sa.orm.joinedload(User.addresses),
+        lambda: sa.orm.joinedload("addresses"),
+        lambda: sa.orm.defer("name"),
+        lambda User: sa.orm.defer(User.name),
+        lambda Address: sa.orm.joinedload("addresses").joinedload(
+            Address.dingaling
+        ),
+        lambda: sa.orm.joinedload("addresses").raiseload("*"),
+        lambda: sa.orm.raiseload("*"),
+    )
+    def test_unbound_options(self, test_case):
         sess, User, Address, Dingaling = self._option_test_fixture()
 
-        for opt in [
-            sa.orm.joinedload(User.addresses),
-            sa.orm.joinedload("addresses"),
-            sa.orm.defer("name"),
-            sa.orm.defer(User.name),
-            sa.orm.joinedload("addresses").joinedload(Address.dingaling),
-        ]:
-            opt2 = pickle.loads(pickle.dumps(opt))
-            eq_(opt.path, opt2.path)
+        opt = testing.resolve_lambda(test_case, User=User, Address=Address)
+        opt2 = pickle.loads(pickle.dumps(opt))
+        eq_(opt.path, opt2.path)
 
         u1 = sess.query(User).options(opt).first()
         pickle.loads(pickle.dumps(u1))
 
     @testing.requires.non_broken_pickle
-    def test_bound_options(self):
+    @testing.combinations(
+        lambda User: sa.orm.Load(User).joinedload(User.addresses),
+        lambda User: sa.orm.Load(User).joinedload("addresses"),
+        lambda User: sa.orm.Load(User).joinedload("addresses").raiseload("*"),
+        lambda User: sa.orm.Load(User).defer("name"),
+        lambda User: sa.orm.Load(User).defer(User.name),
+        lambda User, Address: sa.orm.Load(User)
+        .joinedload("addresses")
+        .joinedload(Address.dingaling),
+        lambda User, Address: sa.orm.Load(User)
+        .joinedload("addresses", innerjoin=True)
+        .joinedload(Address.dingaling),
+    )
+    def test_bound_options(self, test_case):
         sess, User, Address, Dingaling = self._option_test_fixture()
 
-        for opt in [
-            sa.orm.Load(User).joinedload(User.addresses),
-            sa.orm.Load(User).joinedload("addresses"),
-            sa.orm.Load(User).defer("name"),
-            sa.orm.Load(User).defer(User.name),
-            sa.orm.Load(User)
-            .joinedload("addresses")
-            .joinedload(Address.dingaling),
-            sa.orm.Load(User)
-            .joinedload("addresses", innerjoin=True)
-            .joinedload(Address.dingaling),
-        ]:
-            opt2 = pickle.loads(pickle.dumps(opt))
-            eq_(opt.path, opt2.path)
-            eq_(opt.context.keys(), opt2.context.keys())
-            eq_(opt.local_opts, opt2.local_opts)
+        opt = testing.resolve_lambda(test_case, User=User, Address=Address)
+
+        opt2 = pickle.loads(pickle.dumps(opt))
+        eq_(opt.path, opt2.path)
+        eq_(opt.context.keys(), opt2.context.keys())
+        eq_(opt.local_opts, opt2.local_opts)
 
         u1 = sess.query(User).options(opt).first()
         pickle.loads(pickle.dumps(u1))
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 43b505c997..f3fb724c07 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -542,6 +542,31 @@
             "row_number() OVER ()",
         )
 
+    def test_pickle_within_group(self):
+        """test #6520"""
+
+        # TODO: the test/sql package lacks a comprehensive pickling
+        # test suite even though there are __reduce__ methods in several
+        # places in sql/elements.py. likely as part of
+        # test/sql/test_compare.py might be a place this can happen but
+        # this still relies upon a strategy for table metadata as we have
+        # in serializer.
+
+        f1 = func.percentile_cont(literal(1)).within_group()
+
+        self.assert_compile(
+            util.pickle.loads(util.pickle.dumps(f1)),
+            "percentile_cont(:param_1) WITHIN GROUP (ORDER BY )",
+        )
+
+        f1 = func.percentile_cont(literal(1)).within_group(
+            column("q"), column("p").desc()
+        )
+        self.assert_compile(
+            util.pickle.loads(util.pickle.dumps(f1)),
+            "percentile_cont(:param_1) WITHIN GROUP (ORDER BY q, p DESC)",
+        )
+
     def test_functions_with_cols(self):
         users = table(
             "users", column("id"), column("name"), column("fullname")
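
Not part of the commit: a minimal standalone sketch of the ORM round trip that
the #7134 fix restores, i.e. pickling an object loaded with a joinedload()
option combined with raiseload("*"). The User/Address mapping, SQLite URL and
sample data below are illustrative assumptions only.

    # pickling an ORM-loaded object whose state carries "*" loader options
    import pickle

    from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base, joinedload, relationship

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        addresses = relationship("Address")


    class Address(Base):
        __tablename__ = "address"
        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user.id"))
        email = Column(String(50))


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(User(name="ed", addresses=[Address(email="ed@example.com")]))
        session.commit()

        # loader options making use of "*"; the loaded object records these paths
        user = (
            session.query(User)
            .options(joinedload(User.addresses).raiseload("*"))
            .first()
        )

        # prior to the fix, pickling the instance state failed on the "*" tokens
        restored = pickle.loads(pickle.dumps(user))
        print(restored.name, [a.email for a in restored.addresses])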
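
Likewise for the Core fix (#6520), a small sketch showing a within_group()
expression surviving plain pickle once WithinGroup gains __reduce__; the column
names and the 0.5 percentile argument are made up for illustration.

    import pickle

    from sqlalchemy import column, func, literal

    expr = func.percentile_cont(literal(0.5)).within_group(
        column("q"), column("p").desc()
    )

    # round trip through plain pickle, no sqlalchemy.ext.serializer needed
    restored = pickle.loads(pickle.dumps(expr))

    # compiles the same as the original expression, e.g.
    # percentile_cont(:param_1) WITHIN GROUP (ORDER BY q, p DESC)
    print(str(restored))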