From: Federico Caselli Date: Mon, 6 Mar 2023 23:25:59 +0000 (+0100) Subject: Fix regression when deserializing python rows into cython X-Git-Tag: rel_2_0_6~5^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fd9aa847920b9e4dff61ef7a5555c9fa6e362484;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix regression when deserializing python rows into cython Fixed regression involving pickling of Python rows between the cython and pure Python implementations of :class:`.Row`, which occurred as part of refactoring code for version 2.0 with typing. A particular constant were turned into a string based ``Enum`` for the pure Python version of :class:`.Row` whereas the cython version continued to use an integer constant, leading to deserialization failures. Regression occurred in a4bb502cf95ea3523e4d383c4377e50f402d7d52 Fixes: #9423 Change-Id: Icbd85cacb2d589cef7c246de7064249926146f2e --- diff --git a/doc/build/changelog/unreleased_20/9423.rst b/doc/build/changelog/unreleased_20/9423.rst new file mode 100644 index 0000000000..9d4e0ac3e9 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9423.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 9418 + + Fixed regression involving pickling of Python rows between the cython and + pure Python implementations of :class:`.Row`, which occurred as part of + refactoring code for version 2.0 with typing. A particular constant were + turned into a string based ``Enum`` for the pure Python version of + :class:`.Row` whereas the cython version continued to use an integer + constant, leading to deserialization failures. diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py index 7cbac552fd..1b952fe4c1 100644 --- a/lib/sqlalchemy/engine/_py_row.py +++ b/lib/sqlalchemy/engine/_py_row.py @@ -24,7 +24,7 @@ if typing.TYPE_CHECKING: MD_INDEX = 0 # integer index in cursor.description -class _KeyStyle(enum.Enum): +class _KeyStyle(enum.IntEnum): KEY_INTEGER_ONLY = 0 """__getitem__ only allows integer values and slices, raises TypeError otherwise""" @@ -121,6 +121,9 @@ class BaseRow: mdindex = rec[MD_INDEX] if mdindex is None: self._parent._raise_for_ambiguous_column_name(rec) + # NOTE: keep "== KEY_OBJECTS_ONLY" instead of "is KEY_OBJECTS_ONLY" + # since deserializing the class from cython will load an int in + # _key_style, not an instance of _KeyStyle elif self._key_style == KEY_OBJECTS_ONLY and isinstance(key, int): raise KeyError(key) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 153ce4d0ab..d061f26a28 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -59,7 +59,7 @@ def picklers(): # yes, this thing needs this much testing for pickle_ in picklers: - for protocol in range(-2, pickle.HIGHEST_PROTOCOL): + for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1): yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) diff --git a/test/base/test_result.py b/test/base/test_result.py index 3e6444daa1..6e7e53c21e 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -195,35 +195,48 @@ class ResultTupleTest(fixtures.TestBase): eq_(kt._asdict(), {"a": 1, "b": 3}) @testing.requires.cextensions - def test_serialize_cy_py_cy(self): - from sqlalchemy.engine._py_row import BaseRow as _PyRow - from sqlalchemy.cyextension.resultproxy import BaseRow as _CyRow + @testing.variation("direction", ["py_to_cy", "cy_to_py"]) + def test_serialize_cy_py_cy(self, direction: testing.Variation): + from sqlalchemy.engine import _py_row + from sqlalchemy.cyextension import resultproxy as _cy_row global Row - p = result.SimpleResultMetaData(["a", None, "b"]) + p = result.SimpleResultMetaData(["a", "w", "b"]) + + if direction.py_to_cy: + dump_cls = _py_row.BaseRow + num = _py_row.KEY_INTEGER_ONLY + load_cls = _cy_row.BaseRow + elif direction.cy_to_py: + dump_cls = _cy_row.BaseRow + num = _cy_row.KEY_INTEGER_ONLY + load_cls = _py_row.BaseRow + else: + direction.fail() for loads, dumps in picklers(): - class Row(_CyRow): + class Row(dump_cls): pass - row = Row(p, p._processors, p._keymap, 0, (1, 2, 3)) + row = Row(p, p._processors, p._keymap, num, (1, 2, 3)) state = dumps(row) - class Row(_PyRow): + class Row(load_cls): pass row2 = loads(state) - is_true(isinstance(row2, _PyRow)) + is_true(isinstance(row2, load_cls)) + is_false(isinstance(row2, dump_cls)) state2 = dumps(row2) - class Row(_CyRow): + class Row(dump_cls): pass row3 = loads(state2) - is_true(isinstance(row3, _CyRow)) + is_true(isinstance(row3, dump_cls)) class ResultTest(fixtures.TestBase): diff --git a/test/base/test_utils.py b/test/base/test_utils.py index b979d43bcc..01877f7766 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -2458,9 +2458,8 @@ class SymbolTest(fixtures.TestBase): s = pickle.dumps(sym1) pickle.loads(s) - for protocol in 0, 1, 2: - print(protocol) - serial = pickle.dumps(sym1) + for _, dumper in picklers(): + serial = dumper(sym1) rt = pickle.loads(serial) assert rt is sym1 assert rt is sym2 diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 41bb81200d..0537dc2281 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -4,7 +4,11 @@ from contextlib import contextmanager import csv from io import StringIO import operator +import os import pickle +import subprocess +import sys +from tempfile import mkstemp from unittest.mock import Mock from unittest.mock import patch @@ -502,25 +506,25 @@ class CursorResultTest(fixtures.TablesTest): lambda: result[0]._mapping[addresses.c.address_id], ) - # @testing.variation("use_labels", [True, False]) - # def _dont_test_pickle_rows_other_process(self, connection, use_labels): - # result = self._pickle_row_data(connection, use_labels) - - # f, name = mkstemp("pkl") - # with os.fdopen(f, "wb") as f: - # pickle.dump(result, f) - # name = name.replace(os.sep, "/") - # code = ( - # "import sqlalchemy; import pickle; print([" - # f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])" - # ) - # proc = subprocess.run( - # [sys.executable, "-c", code], stdout=subprocess.PIPE - # ) - # exp = str([r[0] for r in result]).encode() - # eq_(proc.returncode, 0) - # eq_(proc.stdout.strip(), exp) - # os.unlink(name) + @testing.variation("use_labels", [True, False]) + def test_pickle_rows_other_process(self, connection, use_labels): + result = self._pickle_row_data(connection, use_labels) + + f, name = mkstemp("pkl") + with os.fdopen(f, "wb") as f: + pickle.dump(result, f) + name = name.replace(os.sep, "/") + code = ( + "import sqlalchemy; import pickle; print([" + f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])" + ) + proc = subprocess.run( + [sys.executable, "-c", code], stdout=subprocess.PIPE + ) + exp = str([r[0] for r in result]).encode() + eq_(proc.returncode, 0) + eq_(proc.stdout.strip(), exp) + os.unlink(name) def test_column_error_printing(self, connection): result = connection.execute(select(1))