]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement custom setstate to work around implicit type/comparator
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Aug 2023 14:17:35 +0000 (10:17 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Wed, 9 Aug 2023 22:14:37 +0000 (22:14 +0000)
Fixed issue where unpickling of a :class:`_schema.Column` or other
:class:`_sql.ColumnElement` would fail to restore the correct "comparator"
object, which is used to generate SQL expressions specific to the type
object.

Fixes: #10213
Change-Id: I74e805024bcc0d93d549bd94757c2865b3117d72
(cherry picked from commit 9d2b83740ad5c700b28cf4ca7807c09c7338c36a)

doc/build/changelog/unreleased_14/10213.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/type_api.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_14/10213.rst b/doc/build/changelog/unreleased_14/10213.rst
new file mode 100644 (file)
index 0000000..96c17b1
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 10213
+    :versions: 2.0.20
+
+    Fixed issue where unpickling of a :class:`_schema.Column` or other
+    :class:`_sql.ColumnElement` would fail to restore the correct "comparator"
+    object, which is used to generate SQL expressions specific to the type
+    object.
index a89273e4da71b9e6190f8c50acbdcd318c1d0eb8..4eac22628533114aa1f2a84bc3724c651288dbd8 100644 (file)
@@ -850,6 +850,9 @@ class ColumnElement(
         else:
             return comparator_factory(self)
 
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
     def __getattr__(self, key):
         try:
             return getattr(self.comparator, key)
index b404b41a5e1e8d731bad1f7b7b339cd32a2f198a..25ae7eabc23e9604f718511473b8b149522a792d 100644 (file)
@@ -111,8 +111,12 @@ class TypeEngine(Traversible):
 
             return op, self.type
 
+        # note: this reduce is needed for tests to pass under python 2.
+        # it does not appear to apply to python 3.  It has however been
+        # modified to accommodate issue #10213.  In SQLA 2 this reduce
+        # has been removed.
         def __reduce__(self):
-            return _reconstitute_comparator, (self.expr,)
+            return _reconstitute_comparator, (self.expr, self.expr.type)
 
     hashable = True
     """Flag, if False, means values from this type aren't hashable.
@@ -1945,8 +1949,14 @@ class Variant(TypeDecorator):
         return self.impl.comparator_factory
 
 
-def _reconstitute_comparator(expression):
-    return expression.comparator
+def _reconstitute_comparator(expression, type_=None):
+    # changed for #10213, added type_ argument.
+    # for previous pickles, keep type_ optional
+    if type_ is None:
+        return expression.comparator
+
+    comparator_factory = type_.comparator_factory
+    return comparator_factory(expression)
 
 
 def to_instance(typeobj, *arg, **kw):
index a03cb21fb30ca6af77d9fc0c850244df6f4d2466..fb0ecddb3829b6cc0e590f66e1ecfccd5510ee44 100644 (file)
@@ -1,5 +1,6 @@
 import datetime
 import operator
+import pickle
 
 from sqlalchemy import and_
 from sqlalchemy import between
@@ -68,6 +69,7 @@ from sqlalchemy.types import DateTime
 from sqlalchemy.types import Indexable
 from sqlalchemy.types import JSON
 from sqlalchemy.types import MatchType
+from sqlalchemy.types import NullType
 from sqlalchemy.types import TypeDecorator
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.types import UserDefinedType
@@ -2250,6 +2252,22 @@ class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         clause = tuple_(1, 2, 3)
         eq_(str(clause), str(util.pickle.loads(util.pickle.dumps(clause))))
 
+    @testing.combinations(Integer(), String(), JSON(), argnames="typ")
+    @testing.variation("eval_first", [True, False])
+    def test_pickle_comparator(self, typ, eval_first):
+        """test #10213"""
+
+        table1 = Table("t", MetaData(), Column("x", typ))
+        t1 = table1.c.x
+
+        if eval_first:
+            t1.comparator
+
+        t1p = pickle.loads(pickle.dumps(table1.c.x))
+
+        is_not(t1p.comparator.__class__, NullType.Comparator)
+        is_(t1.comparator.__class__, t1p.comparator.__class__)
+
     @testing.combinations(
         (operator.lt, "<", ">"),
         (operator.gt, ">", "<"),