From: Federico Caselli Date: Sat, 4 Mar 2023 22:33:02 +0000 (+0100) Subject: Fixed bug where :meth:`_engine.Row`s could not be X-Git-Tag: rel_2_0_5~3^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4dc2ce0fb15d72e508c72b0cb0e7fd622e6cc34f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixed bug where :meth:`_engine.Row`s could not be unpickled by other processes. Fixes: #9423 Change-Id: Ie496e31158caff5f72e0a9069dddd55f3116e0b8 --- diff --git a/doc/build/changelog/unreleased_20/9423.rst b/doc/build/changelog/unreleased_20/9423.rst new file mode 100644 index 0000000000..ddb6f9f4f7 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9423.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, engine + :tickets: 9423 + + Fixed bug where :meth:`_engine.Row`s could not be + unpickled by other processes. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index c65c4a058c..0eea3398d5 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -53,8 +53,6 @@ from ..util import compat from ..util.typing import Literal from ..util.typing import Self -_UNPICKLED = util.symbol("unpickled") - if typing.TYPE_CHECKING: from .base import Connection @@ -445,7 +443,12 @@ class CursorResultMetaData(ResultMetaData): # then for the dupe keys, put the "ambiguous column" # record into by_key. - by_key.update({key: (None, None, (), key) for key in dupes}) + by_key.update( + { + key: (None, None, [], key, key, None, None) + for key in dupes + } + ) else: @@ -886,7 +889,7 @@ class CursorResultMetaData(ResultMetaData): # ensure it raises CursorResultMetaData._key_fallback(self, ke.args[0], ke) - index = rec[0] + index = rec[MD_INDEX] if index is None: self._raise_for_ambiguous_column_name(rec) @@ -894,15 +897,23 @@ class CursorResultMetaData(ResultMetaData): yield cast(_NonAmbigCursorKeyMapRecType, rec) def __getstate__(self): + # TODO: consider serializing this as SimpleResultMetaData return { "_keymap": { - key: (rec[MD_INDEX], rec[MD_RESULT_MAP_INDEX], _UNPICKLED, key) + key: ( + rec[MD_INDEX], + rec[MD_RESULT_MAP_INDEX], + [], + key, + rec[MD_RENDERED_NAME], + None, + None, + ) for key, rec in self._keymap.items() if isinstance(key, (str, int)) }, "_keys": self._keys, "_translated_indexes": self._translated_indexes, - "_tuplefilter": self._tuplefilter, } def __setstate__(self, state): diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 66584c96e1..97f35164af 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -4,7 +4,11 @@ from contextlib import contextmanager import csv from io import StringIO import operator +import os import pickle +import subprocess +import sys +from tempfile import mkstemp from unittest.mock import Mock from unittest.mock import patch @@ -420,9 +424,8 @@ class CursorResultTest(fixtures.TablesTest): # in 1.x, would warn for string match, but return a result not_in(sql.column("content_type"), row._mapping) - def test_pickled_rows(self, connection): + def _pickle_row_data(self, connection, use_labels): users = self.tables.users - addresses = self.tables.addresses connection.execute( users.insert(), @@ -433,70 +436,95 @@ class CursorResultTest(fixtures.TablesTest): ], ) - for use_pickle in False, True: - for use_labels in False, True: - result = connection.execute( - users.select() - .order_by(users.c.user_id) - .set_label_style( - LABEL_STYLE_TABLENAME_PLUS_COL - if use_labels - else LABEL_STYLE_NONE - ) - ).fetchall() + result = connection.execute( + users.select() + .order_by(users.c.user_id) + .set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + if use_labels + else LABEL_STYLE_NONE + ) + ).all() + return result - if use_pickle: - result = pickle.loads(pickle.dumps(result)) + @testing.variation("use_pickle", [True, False]) + @testing.variation("use_labels", [True, False]) + def test_pickled_rows(self, connection, use_pickle, use_labels): + users = self.tables.users + addresses = self.tables.addresses - eq_(result, [(7, "jack"), (8, "ed"), (9, "fred")]) - if use_labels: - eq_(result[0]._mapping["users_user_id"], 7) - eq_( - list(result[0]._fields), - ["users_user_id", "users_user_name"], - ) - else: - eq_(result[0]._mapping["user_id"], 7) - eq_(list(result[0]._fields), ["user_id", "user_name"]) + result = self._pickle_row_data(connection, use_labels) - eq_(result[0][0], 7) + if use_pickle: + result = pickle.loads(pickle.dumps(result)) - assert_raises( - exc.NoSuchColumnError, - lambda: result[0]._mapping["fake key"], - ) + eq_(result, [(7, "jack"), (8, "ed"), (9, "fred")]) + if use_labels: + eq_(result[0]._mapping["users_user_id"], 7) + eq_( + list(result[0]._fields), + ["users_user_id", "users_user_name"], + ) + else: + eq_(result[0]._mapping["user_id"], 7) + eq_(list(result[0]._fields), ["user_id", "user_name"]) - # previously would warn + eq_(result[0][0], 7) - if use_pickle: - with expect_raises_message( - exc.NoSuchColumnError, - "Row was unpickled; lookup by ColumnElement is " - "unsupported", - ): - result[0]._mapping[users.c.user_id] - else: - eq_(result[0]._mapping[users.c.user_id], 7) + assert_raises( + exc.NoSuchColumnError, + lambda: result[0]._mapping["fake key"], + ) - if use_pickle: - with expect_raises_message( - exc.NoSuchColumnError, - "Row was unpickled; lookup by ColumnElement is " - "unsupported", - ): - result[0]._mapping[users.c.user_name] - else: - eq_(result[0]._mapping[users.c.user_name], "jack") + # previously would warn - assert_raises( - exc.NoSuchColumnError, - lambda: result[0]._mapping[addresses.c.user_id], - ) + if use_pickle: + with expect_raises_message( + exc.NoSuchColumnError, + "Row was unpickled; lookup by ColumnElement is " "unsupported", + ): + result[0]._mapping[users.c.user_id] + else: + eq_(result[0]._mapping[users.c.user_id], 7) - assert_raises( - exc.NoSuchColumnError, - lambda: result[0]._mapping[addresses.c.address_id], - ) + if use_pickle: + with expect_raises_message( + exc.NoSuchColumnError, + "Row was unpickled; lookup by ColumnElement is " "unsupported", + ): + result[0]._mapping[users.c.user_name] + else: + eq_(result[0]._mapping[users.c.user_name], "jack") + + assert_raises( + exc.NoSuchColumnError, + lambda: result[0]._mapping[addresses.c.user_id], + ) + + assert_raises( + exc.NoSuchColumnError, + lambda: result[0]._mapping[addresses.c.address_id], + ) + + @testing.variation("use_labels", [True, False]) + def test_pickle_rows_other_process(self, connection, use_labels): + result = self._pickle_row_data(connection, use_labels) + + f, name = mkstemp("pkl") + with os.fdopen(f, "wb") as f: + pickle.dump(result, f) + name = name.replace(os.sep, "/") + code = ( + "import sqlalchemy; import pickle; print([" + f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])" + ) + proc = subprocess.run( + [sys.executable, "-c", code], stdout=subprocess.PIPE + ) + exp = str([r[0] for r in result]).encode() + eq_(proc.returncode, 0) + eq_(proc.stdout.strip(), exp) + os.unlink(name) def test_column_error_printing(self, connection): result = connection.execute(select(1))