From: Mike Bayer Date: Tue, 15 Oct 2024 12:20:25 +0000 (-0400) Subject: add tests for pickling types inside an expression, some reduce methods X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fa568215788c274eb2d178b6eb180ab1f7955c01;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add tests for pickling types inside an expression, some reduce methods Fixed regression from 1.4 where some datatypes such as those derived from :class:`.TypeDecorator` could not be pickled when they were part of a larger SQL expression composition due to internal supporting structures themselves not being pickleable. Fixes: #12002 Change-Id: I016e37b0c62071413f24c9aac35f6ecf475becaa --- diff --git a/doc/build/changelog/unreleased_20/12002.rst b/doc/build/changelog/unreleased_20/12002.rst new file mode 100644 index 0000000000..49ac701759 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12002.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 12002 + + Fixed regression from 1.4 where some datatypes such as those derived from + :class:`.TypeDecorator` could not be pickled when they were part of a + larger SQL expression composition due to internal supporting structures + themselves not being pickleable. diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 3367aab64c..9f40905fa6 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -183,6 +183,9 @@ class TypeEngine(Visitable, Generic[_T]): self.expr = expr self.type = expr.type + def __reduce__(self) -> Any: + return self.__class__, (self.expr,) + @util.preload_module("sqlalchemy.sql.default_comparator") def operate( self, op: OperatorType, *other: Any, **kwargs: Any @@ -1721,20 +1724,38 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types return super().reverse_operate(op, other, **kwargs) + @staticmethod + def _reduce_td_comparator( + impl: TypeEngine[Any], expr: ColumnElement[_T] + ) -> Any: + return TypeDecorator._create_td_comparator_type(impl)(expr) + + @staticmethod + def _create_td_comparator_type( + impl: TypeEngine[Any], + ) -> _ComparatorFactory[Any]: + + def __reduce__(self: TypeDecorator.Comparator[Any]) -> Any: + return (TypeDecorator._reduce_td_comparator, (impl, self.expr)) + + return type( + "TDComparator", + (TypeDecorator.Comparator, impl.comparator_factory), # type: ignore # noqa: E501 + {"__reduce__": __reduce__}, + ) + @property def comparator_factory( # type: ignore # mypy properties bug self, ) -> _ComparatorFactory[Any]: if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__: # type: ignore # noqa: E501 - return self.impl.comparator_factory + return self.impl_instance.comparator_factory else: # reconcile the Comparator class on the impl with that - # of TypeDecorator - return type( - "TDComparator", - (TypeDecorator.Comparator, self.impl.comparator_factory), # type: ignore # noqa: E501 - {}, - ) + # of TypeDecorator. + # the use of multiple staticmethods is to support repeated + # pickling of the Comparator itself + return TypeDecorator._create_td_comparator_type(self.impl_instance) def _copy_with_check(self) -> Self: tt = self.copy() diff --git a/test/sql/test_types.py b/test/sql/test_types.py index e47b85029a..f5a042e32a 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -512,6 +512,11 @@ class AsGenericTest(fixtures.TestBase): assert isinstance(gentype, TypeEngine) +class SomeTypeDecorator(TypeDecorator): + impl = String() + cache_ok = True + + class PickleTypesTest(fixtures.TestBase): @testing.combinations( ("Boo", Boolean()), @@ -530,6 +535,7 @@ class PickleTypesTest(fixtures.TestBase): ("Lar", LargeBinary()), ("Pic", PickleType()), ("Int", Interval()), + ("Dec", SomeTypeDecorator()), argnames="name,type_", id_="ar", ) @@ -543,10 +549,37 @@ class PickleTypesTest(fixtures.TestBase): meta = MetaData() Table("foo", meta, column_type) + expr = select(1).where(column_type == bindparam("q")) + for loads, dumps in picklers(): loads(dumps(column_type)) loads(dumps(meta)) + expr_str_one = str(expr) + ne = loads(dumps(expr)) + + eq_(str(ne), expr_str_one) + + re_pickle_it = loads(dumps(ne)) + eq_(str(re_pickle_it), expr_str_one) + + def test_pickle_td_comparator(self): + comparator = SomeTypeDecorator().comparator_factory(column("q")) + + expected_mro = ( + TypeDecorator.Comparator, + sqltypes.Concatenable.Comparator, + TypeEngine.Comparator, + ) + eq_(comparator.__class__.__mro__[1:4], expected_mro) + + for loads, dumps in picklers(): + unpickled = loads(dumps(comparator)) + eq_(unpickled.__class__.__mro__[1:4], expected_mro) + + reunpickled = loads(dumps(unpickled)) + eq_(reunpickled.__class__.__mro__[1:4], expected_mro) + @testing.combinations( ("Str", String()), ("Tex", Text()),