]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix regression when deserializing python rows into cython
authorFederico Caselli <cfederico87@gmail.com>
Mon, 6 Mar 2023 23:25:59 +0000 (00:25 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Mar 2023 15:26:26 +0000 (10:26 -0500)
Fixed regression involving pickling of Python rows between the cython and
pure Python implementations of :class:`.Row`, which occurred as part of
refactoring code for version 2.0 with typing. A particular constant were
turned into a string based ``Enum`` for the pure Python version of
:class:`.Row` whereas the cython version continued to use an integer
constant, leading to deserialization failures.

Regression occurred in a4bb502cf95ea3523e4d383c4377e50f402d7d52

Fixes: #9423
Change-Id: Icbd85cacb2d589cef7c246de7064249926146f2e

doc/build/changelog/unreleased_20/9423.rst [new file with mode: 0644]
lib/sqlalchemy/engine/_py_row.py
lib/sqlalchemy/testing/util.py
test/base/test_result.py
test/base/test_utils.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_20/9423.rst b/doc/build/changelog/unreleased_20/9423.rst
new file mode 100644 (file)
index 0000000..9d4e0ac
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9418
+
+    Fixed regression involving pickling of Python rows between the cython and
+    pure Python implementations of :class:`.Row`, which occurred as part of
+    refactoring code for version 2.0 with typing. A particular constant were
+    turned into a string based ``Enum`` for the pure Python version of
+    :class:`.Row` whereas the cython version continued to use an integer
+    constant, leading to deserialization failures.
index 7cbac552fd8816607e5fcf013dd8ab3f39ee3148..1b952fe4c167feb68e9a4b3b6603b9dc2946d726 100644 (file)
@@ -24,7 +24,7 @@ if typing.TYPE_CHECKING:
 MD_INDEX = 0  # integer index in cursor.description
 
 
-class _KeyStyle(enum.Enum):
+class _KeyStyle(enum.IntEnum):
     KEY_INTEGER_ONLY = 0
     """__getitem__ only allows integer values and slices, raises TypeError
     otherwise"""
@@ -121,6 +121,9 @@ class BaseRow:
         mdindex = rec[MD_INDEX]
         if mdindex is None:
             self._parent._raise_for_ambiguous_column_name(rec)
+        # NOTE: keep "== KEY_OBJECTS_ONLY" instead of "is KEY_OBJECTS_ONLY"
+        # since deserializing the class from cython will load an int in
+        # _key_style, not an instance of _KeyStyle
         elif self._key_style == KEY_OBJECTS_ONLY and isinstance(key, int):
             raise KeyError(key)
 
index 153ce4d0ab9feb79505634f600f72ed96f2c7996..d061f26a289112953b1a68a8565aaea307c05b51 100644 (file)
@@ -59,7 +59,7 @@ def picklers():
 
     # yes, this thing needs this much testing
     for pickle_ in picklers:
-        for protocol in range(-2, pickle.HIGHEST_PROTOCOL):
+        for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
             yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
 
 
index 3e6444daa1ffe1e840c23ae1085dc90cbe0f99f7..6e7e53c21e7cb8197bf36d3dd5d31a7f6c12c3fd 100644 (file)
@@ -195,35 +195,48 @@ class ResultTupleTest(fixtures.TestBase):
             eq_(kt._asdict(), {"a": 1, "b": 3})
 
     @testing.requires.cextensions
-    def test_serialize_cy_py_cy(self):
-        from sqlalchemy.engine._py_row import BaseRow as _PyRow
-        from sqlalchemy.cyextension.resultproxy import BaseRow as _CyRow
+    @testing.variation("direction", ["py_to_cy", "cy_to_py"])
+    def test_serialize_cy_py_cy(self, direction: testing.Variation):
+        from sqlalchemy.engine import _py_row
+        from sqlalchemy.cyextension import resultproxy as _cy_row
 
         global Row
 
-        p = result.SimpleResultMetaData(["a", None, "b"])
+        p = result.SimpleResultMetaData(["a", "w", "b"])
+
+        if direction.py_to_cy:
+            dump_cls = _py_row.BaseRow
+            num = _py_row.KEY_INTEGER_ONLY
+            load_cls = _cy_row.BaseRow
+        elif direction.cy_to_py:
+            dump_cls = _cy_row.BaseRow
+            num = _cy_row.KEY_INTEGER_ONLY
+            load_cls = _py_row.BaseRow
+        else:
+            direction.fail()
 
         for loads, dumps in picklers():
 
-            class Row(_CyRow):
+            class Row(dump_cls):
                 pass
 
-            row = Row(p, p._processors, p._keymap, 0, (1, 2, 3))
+            row = Row(p, p._processors, p._keymap, num, (1, 2, 3))
 
             state = dumps(row)
 
-            class Row(_PyRow):
+            class Row(load_cls):
                 pass
 
             row2 = loads(state)
-            is_true(isinstance(row2, _PyRow))
+            is_true(isinstance(row2, load_cls))
+            is_false(isinstance(row2, dump_cls))
             state2 = dumps(row2)
 
-            class Row(_CyRow):
+            class Row(dump_cls):
                 pass
 
             row3 = loads(state2)
-            is_true(isinstance(row3, _CyRow))
+            is_true(isinstance(row3, dump_cls))
 
 
 class ResultTest(fixtures.TestBase):
index b979d43bcc096b822ea135360658c3de4d88f619..01877f776633891db9431fd35e8be0cd3e558c37 100644 (file)
@@ -2458,9 +2458,8 @@ class SymbolTest(fixtures.TestBase):
         s = pickle.dumps(sym1)
         pickle.loads(s)
 
-        for protocol in 0, 1, 2:
-            print(protocol)
-            serial = pickle.dumps(sym1)
+        for _, dumper in picklers():
+            serial = dumper(sym1)
             rt = pickle.loads(serial)
             assert rt is sym1
             assert rt is sym2
index 41bb81200d6ec0d0850bc4867bc9f504aae3340c..0537dc22819961a52d10d0889686e4433c053c10 100644 (file)
@@ -4,7 +4,11 @@ from contextlib import contextmanager
 import csv
 from io import StringIO
 import operator
+import os
 import pickle
+import subprocess
+import sys
+from tempfile import mkstemp
 from unittest.mock import Mock
 from unittest.mock import patch
 
@@ -502,25 +506,25 @@ class CursorResultTest(fixtures.TablesTest):
             lambda: result[0]._mapping[addresses.c.address_id],
         )
 
-    @testing.variation("use_labels", [True, False])
-    # def _dont_test_pickle_rows_other_process(self, connection, use_labels):
-        result = self._pickle_row_data(connection, use_labels)
-
-        f, name = mkstemp("pkl")
-        with os.fdopen(f, "wb") as f:
-            pickle.dump(result, f)
-        name = name.replace(os.sep, "/")
-        code = (
-            "import sqlalchemy; import pickle; print(["
-            f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])"
-        )
-        proc = subprocess.run(
-            [sys.executable, "-c", code], stdout=subprocess.PIPE
-        )
-        exp = str([r[0] for r in result]).encode()
-        eq_(proc.returncode, 0)
-        eq_(proc.stdout.strip(), exp)
-        os.unlink(name)
+    @testing.variation("use_labels", [True, False])
+    def test_pickle_rows_other_process(self, connection, use_labels):
+        result = self._pickle_row_data(connection, use_labels)
+
+        f, name = mkstemp("pkl")
+        with os.fdopen(f, "wb") as f:
+            pickle.dump(result, f)
+        name = name.replace(os.sep, "/")
+        code = (
+            "import sqlalchemy; import pickle; print(["
+            f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])"
+        )
+        proc = subprocess.run(
+            [sys.executable, "-c", code], stdout=subprocess.PIPE
+        )
+        exp = str([r[0] for r in result]).encode()
+        eq_(proc.returncode, 0)
+        eq_(proc.stdout.strip(), exp)
+        os.unlink(name)
 
     def test_column_error_printing(self, connection):
         result = connection.execute(select(1))