]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add tests for pickling types inside an expression, some reduce methods
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Oct 2024 12:20:25 +0000 (08:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Oct 2024 16:00:44 +0000 (12:00 -0400)
Fixed regression from 1.4 where some datatypes such as those derived from
:class:`.TypeDecorator` could not be pickled when they were part of a
larger SQL expression composition due to internal supporting structures
themselves not being pickleable.

Fixes: #12002
Change-Id: I016e37b0c62071413f24c9aac35f6ecf475becaa

doc/build/changelog/unreleased_20/12002.rst [new file with mode: 0644]
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/12002.rst b/doc/build/changelog/unreleased_20/12002.rst
new file mode 100644 (file)
index 0000000..49ac701
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 12002
+
+    Fixed regression from 1.4 where some datatypes such as those derived from
+    :class:`.TypeDecorator` could not be pickled when they were part of a
+    larger SQL expression composition due to internal supporting structures
+    themselves not being pickleable.
index 3367aab64c9568fa0469896944aaeafd9cae66cd..9f40905fa62ca361a5c623997460b350452f8405 100644 (file)
@@ -183,6 +183,9 @@ class TypeEngine(Visitable, Generic[_T]):
             self.expr = expr
             self.type = expr.type
 
+        def __reduce__(self) -> Any:
+            return self.__class__, (self.expr,)
+
         @util.preload_module("sqlalchemy.sql.default_comparator")
         def operate(
             self, op: OperatorType, *other: Any, **kwargs: Any
@@ -1721,20 +1724,38 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
             kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
             return super().reverse_operate(op, other, **kwargs)
 
+    @staticmethod
+    def _reduce_td_comparator(
+        impl: TypeEngine[Any], expr: ColumnElement[_T]
+    ) -> Any:
+        return TypeDecorator._create_td_comparator_type(impl)(expr)
+
+    @staticmethod
+    def _create_td_comparator_type(
+        impl: TypeEngine[Any],
+    ) -> _ComparatorFactory[Any]:
+
+        def __reduce__(self: TypeDecorator.Comparator[Any]) -> Any:
+            return (TypeDecorator._reduce_td_comparator, (impl, self.expr))
+
+        return type(
+            "TDComparator",
+            (TypeDecorator.Comparator, impl.comparator_factory),  # type: ignore # noqa: E501
+            {"__reduce__": __reduce__},
+        )
+
     @property
     def comparator_factory(  # type: ignore  # mypy properties bug
         self,
     ) -> _ComparatorFactory[Any]:
         if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:  # type: ignore # noqa: E501
-            return self.impl.comparator_factory
+            return self.impl_instance.comparator_factory
         else:
             # reconcile the Comparator class on the impl with that
-            # of TypeDecorator
-            return type(
-                "TDComparator",
-                (TypeDecorator.Comparator, self.impl.comparator_factory),  # type: ignore # noqa: E501
-                {},
-            )
+            # of TypeDecorator.
+            # the use of multiple staticmethods is to support repeated
+            # pickling of the Comparator itself
+            return TypeDecorator._create_td_comparator_type(self.impl_instance)
 
     def _copy_with_check(self) -> Self:
         tt = self.copy()
index e47b85029aceed2117c3185b6f09a3e00c9b274b..f5a042e32a4650ddb41f7e0cbfb1f16a47b162a2 100644 (file)
@@ -512,6 +512,11 @@ class AsGenericTest(fixtures.TestBase):
         assert isinstance(gentype, TypeEngine)
 
 
+class SomeTypeDecorator(TypeDecorator):
+    impl = String()
+    cache_ok = True
+
+
 class PickleTypesTest(fixtures.TestBase):
     @testing.combinations(
         ("Boo", Boolean()),
@@ -530,6 +535,7 @@ class PickleTypesTest(fixtures.TestBase):
         ("Lar", LargeBinary()),
         ("Pic", PickleType()),
         ("Int", Interval()),
+        ("Dec", SomeTypeDecorator()),
         argnames="name,type_",
         id_="ar",
     )
@@ -543,10 +549,37 @@ class PickleTypesTest(fixtures.TestBase):
         meta = MetaData()
         Table("foo", meta, column_type)
 
+        expr = select(1).where(column_type == bindparam("q"))
+
         for loads, dumps in picklers():
             loads(dumps(column_type))
             loads(dumps(meta))
 
+            expr_str_one = str(expr)
+            ne = loads(dumps(expr))
+
+            eq_(str(ne), expr_str_one)
+
+            re_pickle_it = loads(dumps(ne))
+            eq_(str(re_pickle_it), expr_str_one)
+
+    def test_pickle_td_comparator(self):
+        comparator = SomeTypeDecorator().comparator_factory(column("q"))
+
+        expected_mro = (
+            TypeDecorator.Comparator,
+            sqltypes.Concatenable.Comparator,
+            TypeEngine.Comparator,
+        )
+        eq_(comparator.__class__.__mro__[1:4], expected_mro)
+
+        for loads, dumps in picklers():
+            unpickled = loads(dumps(comparator))
+            eq_(unpickled.__class__.__mro__[1:4], expected_mro)
+
+            reunpickled = loads(dumps(unpickled))
+            eq_(reunpickled.__class__.__mro__[1:4], expected_mro)
+
     @testing.combinations(
         ("Str", String()),
         ("Tex", Text()),