git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed bug where :class:`_engine.Row` objects could not be
author: Federico Caselli <cfederico87@gmail.com>
Sat, 4 Mar 2023 22:33:02 +0000 (23:33 +0100)
committer: Federico Caselli <cfederico87@gmail.com>
Sat, 4 Mar 2023 22:33:02 +0000 (23:33 +0100)
unpickled by other processes.

Fixes: #9423
Change-Id: Ie496e31158caff5f72e0a9069dddd55f3116e0b8

doc/build/changelog/unreleased_20/9423.rst [new file with mode: 0644]
lib/sqlalchemy/engine/cursor.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_20/9423.rst b/doc/build/changelog/unreleased_20/9423.rst
new file mode 100644 (file)
index 0000000..ddb6f9f
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 9423
+
+    Fixed bug where :class:`_engine.Row` objects could not be
+    unpickled by other processes.
index c65c4a058ce2b4746fc00c0964b13d3ff76bebb8..0eea3398d5914868c84c0d214f69178b9dccea63 100644 (file)
@@ -53,8 +53,6 @@ from ..util import compat
 from ..util.typing import Literal
 from ..util.typing import Self
 
-_UNPICKLED = util.symbol("unpickled")
-
 
 if typing.TYPE_CHECKING:
     from .base import Connection
@@ -445,7 +443,12 @@ class CursorResultMetaData(ResultMetaData):
 
                 # then for the dupe keys, put the "ambiguous column"
                 # record into by_key.
-                by_key.update({key: (None, None, (), key) for key in dupes})
+                by_key.update(
+                    {
+                        key: (None, None, [], key, key, None, None)
+                        for key in dupes
+                    }
+                )
 
             else:
 
@@ -886,7 +889,7 @@ class CursorResultMetaData(ResultMetaData):
                 # ensure it raises
                 CursorResultMetaData._key_fallback(self, ke.args[0], ke)
 
-            index = rec[0]
+            index = rec[MD_INDEX]
 
             if index is None:
                 self._raise_for_ambiguous_column_name(rec)
@@ -894,15 +897,23 @@ class CursorResultMetaData(ResultMetaData):
             yield cast(_NonAmbigCursorKeyMapRecType, rec)
 
     def __getstate__(self):
+        # TODO: consider serializing this as SimpleResultMetaData
         return {
             "_keymap": {
-                key: (rec[MD_INDEX], rec[MD_RESULT_MAP_INDEX], _UNPICKLED, key)
+                key: (
+                    rec[MD_INDEX],
+                    rec[MD_RESULT_MAP_INDEX],
+                    [],
+                    key,
+                    rec[MD_RENDERED_NAME],
+                    None,
+                    None,
+                )
                 for key, rec in self._keymap.items()
                 if isinstance(key, (str, int))
             },
             "_keys": self._keys,
             "_translated_indexes": self._translated_indexes,
-            "_tuplefilter": self._tuplefilter,
         }
 
     def __setstate__(self, state):
index 66584c96e1d17961d79975530b26ecd959772fcf..97f35164afe996ab2791f42892825437776c65eb 100644 (file)
@@ -4,7 +4,11 @@ from contextlib import contextmanager
 import csv
 from io import StringIO
 import operator
+import os
 import pickle
+import subprocess
+import sys
+from tempfile import mkstemp
 from unittest.mock import Mock
 from unittest.mock import patch
 
@@ -420,9 +424,8 @@ class CursorResultTest(fixtures.TablesTest):
         # in 1.x, would warn for string match, but return a result
         not_in(sql.column("content_type"), row._mapping)
 
-    def test_pickled_rows(self, connection):
+    def _pickle_row_data(self, connection, use_labels):
         users = self.tables.users
-        addresses = self.tables.addresses
 
         connection.execute(
             users.insert(),
@@ -433,70 +436,95 @@ class CursorResultTest(fixtures.TablesTest):
             ],
         )
 
-        for use_pickle in False, True:
-            for use_labels in False, True:
-                result = connection.execute(
-                    users.select()
-                    .order_by(users.c.user_id)
-                    .set_label_style(
-                        LABEL_STYLE_TABLENAME_PLUS_COL
-                        if use_labels
-                        else LABEL_STYLE_NONE
-                    )
-                ).fetchall()
+        result = connection.execute(
+            users.select()
+            .order_by(users.c.user_id)
+            .set_label_style(
+                LABEL_STYLE_TABLENAME_PLUS_COL
+                if use_labels
+                else LABEL_STYLE_NONE
+            )
+        ).all()
+        return result
 
-                if use_pickle:
-                    result = pickle.loads(pickle.dumps(result))
+    @testing.variation("use_pickle", [True, False])
+    @testing.variation("use_labels", [True, False])
+    def test_pickled_rows(self, connection, use_pickle, use_labels):
+        users = self.tables.users
+        addresses = self.tables.addresses
 
-                eq_(result, [(7, "jack"), (8, "ed"), (9, "fred")])
-                if use_labels:
-                    eq_(result[0]._mapping["users_user_id"], 7)
-                    eq_(
-                        list(result[0]._fields),
-                        ["users_user_id", "users_user_name"],
-                    )
-                else:
-                    eq_(result[0]._mapping["user_id"], 7)
-                    eq_(list(result[0]._fields), ["user_id", "user_name"])
+        result = self._pickle_row_data(connection, use_labels)
 
-                eq_(result[0][0], 7)
+        if use_pickle:
+            result = pickle.loads(pickle.dumps(result))
 
-                assert_raises(
-                    exc.NoSuchColumnError,
-                    lambda: result[0]._mapping["fake key"],
-                )
+        eq_(result, [(7, "jack"), (8, "ed"), (9, "fred")])
+        if use_labels:
+            eq_(result[0]._mapping["users_user_id"], 7)
+            eq_(
+                list(result[0]._fields),
+                ["users_user_id", "users_user_name"],
+            )
+        else:
+            eq_(result[0]._mapping["user_id"], 7)
+            eq_(list(result[0]._fields), ["user_id", "user_name"])
 
-                # previously would warn
+        eq_(result[0][0], 7)
 
-                if use_pickle:
-                    with expect_raises_message(
-                        exc.NoSuchColumnError,
-                        "Row was unpickled; lookup by ColumnElement is "
-                        "unsupported",
-                    ):
-                        result[0]._mapping[users.c.user_id]
-                else:
-                    eq_(result[0]._mapping[users.c.user_id], 7)
+        assert_raises(
+            exc.NoSuchColumnError,
+            lambda: result[0]._mapping["fake key"],
+        )
 
-                if use_pickle:
-                    with expect_raises_message(
-                        exc.NoSuchColumnError,
-                        "Row was unpickled; lookup by ColumnElement is "
-                        "unsupported",
-                    ):
-                        result[0]._mapping[users.c.user_name]
-                else:
-                    eq_(result[0]._mapping[users.c.user_name], "jack")
+        # previously would warn
 
-                assert_raises(
-                    exc.NoSuchColumnError,
-                    lambda: result[0]._mapping[addresses.c.user_id],
-                )
+        if use_pickle:
+            with expect_raises_message(
+                exc.NoSuchColumnError,
+                "Row was unpickled; lookup by ColumnElement is " "unsupported",
+            ):
+                result[0]._mapping[users.c.user_id]
+        else:
+            eq_(result[0]._mapping[users.c.user_id], 7)
 
-                assert_raises(
-                    exc.NoSuchColumnError,
-                    lambda: result[0]._mapping[addresses.c.address_id],
-                )
+        if use_pickle:
+            with expect_raises_message(
+                exc.NoSuchColumnError,
+                "Row was unpickled; lookup by ColumnElement is " "unsupported",
+            ):
+                result[0]._mapping[users.c.user_name]
+        else:
+            eq_(result[0]._mapping[users.c.user_name], "jack")
+
+        assert_raises(
+            exc.NoSuchColumnError,
+            lambda: result[0]._mapping[addresses.c.user_id],
+        )
+
+        assert_raises(
+            exc.NoSuchColumnError,
+            lambda: result[0]._mapping[addresses.c.address_id],
+        )
+
+    @testing.variation("use_labels", [True, False])
+    def test_pickle_rows_other_process(self, connection, use_labels):
+        result = self._pickle_row_data(connection, use_labels)
+
+        f, name = mkstemp("pkl")
+        with os.fdopen(f, "wb") as f:
+            pickle.dump(result, f)
+        name = name.replace(os.sep, "/")
+        code = (
+            "import sqlalchemy; import pickle; print(["
+            f"r[0] for r in pickle.load(open('''{name}''', 'rb'))])"
+        )
+        proc = subprocess.run(
+            [sys.executable, "-c", code], stdout=subprocess.PIPE
+        )
+        exp = str([r[0] for r in result]).encode()
+        eq_(proc.returncode, 0)
+        eq_(proc.stdout.strip(), exp)
+        os.unlink(name)
 
     def test_column_error_printing(self, connection):
         result = connection.execute(select(1))