git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-132775: Add _PyMarshal_GetXIData() (gh-133108)
author Eric Snow <ericsnowcurrently@gmail.com>
Mon, 28 Apr 2025 23:23:46 +0000 (17:23 -0600)
committer GitHub <noreply@github.com>
Mon, 28 Apr 2025 23:23:46 +0000 (17:23 -0600)
Note that the bulk of this change is tests.

Include/internal/pycore_crossinterp.h
Lib/test/_crossinterp_definitions.py
Lib/test/test_crossinterp.py
Modules/_testinternalcapi.c
Python/crossinterp.c

index 5cf9f8fb5a03882bbbf2ba49ea6292f54b4d9923..4b7446a1f40ccf3e8ff85233fc3f6cdf1cd3895e 100644 (file)
@@ -171,6 +171,13 @@ PyAPI_FUNC(_PyBytes_data_t *) _PyBytes_GetXIDataWrapped(
         xid_newobjfunc,
         _PyXIData_t *);
 
+// _PyObject_GetXIData() for marshal
+PyAPI_FUNC(PyObject *) _PyMarshal_ReadObjectFromXIData(_PyXIData_t *);
+PyAPI_FUNC(int) _PyMarshal_GetXIData(
+        PyThreadState *,
+        PyObject *,
+        _PyXIData_t *);
+
 
 /* using cross-interpreter data */
 
index 9b52aea39522f56e2e6d7a023af440327eb87d57..0d5f6c7db064d3b839b54a86fd17275d14ed35f9 100644 (file)
@@ -100,7 +100,7 @@ ham_C_nested, *_ = eggs_closure_N(2)
 ham_C_closure, *_ = eggs_closure_C(2)
 
 
-FUNCTIONS = [
+TOP_FUNCTIONS = [
     # shallow
     spam_minimal,
     spam_full,
@@ -112,6 +112,8 @@ FUNCTIONS = [
     spam_NC,
     spam_CN,
     spam_CC,
+]
+NESTED_FUNCTIONS = [
     # inner func
     eggs_nested,
     eggs_closure,
@@ -125,6 +127,10 @@ FUNCTIONS = [
     ham_C_nested,
     ham_C_closure,
 ]
+FUNCTIONS = [
+    *TOP_FUNCTIONS,
+    *NESTED_FUNCTIONS,
+]
 
 
 #######################################
@@ -157,8 +163,10 @@ FUNCTION_LIKE = [
     gen_spam_1,
     gen_spam_2,
     async_spam,
-    coro_spam,  # actually FunctionType?
     asyncgen_spam,
+]
+FUNCTION_LIKE_APPLIED = [
+    coro_spam,  # actually FunctionType?
     asynccoro_spam,  # actually FunctionType?
 ]
 
@@ -202,6 +210,13 @@ class SpamFull:
     # __str__
     # ...
 
+    def __eq__(self, other):
+        if not isinstance(other, SpamFull):
+            return NotImplemented
+        return (self.a == other.a and
+                self.b == other.b and
+                self.c == other.c)
+
     @property
     def prop(self):
         return True
@@ -222,9 +237,47 @@ def class_eggs_inner():
 EggsNested = class_eggs_inner()
 
 
+TOP_CLASSES = {
+    Spam: (),
+    SpamOkay: (),
+    SpamFull: (1, 2, 3),
+    SubSpamFull: (1, 2, 3),
+    SubTuple: ([1, 2, 3],),
+}
+CLASSES_WITHOUT_EQUALITY = [
+    Spam,
+    SpamOkay,
+]
+BUILTIN_SUBCLASSES = [
+    SubTuple,
+]
+NESTED_CLASSES = {
+    EggsNested: (),
+}
+CLASSES = {
+    **TOP_CLASSES,
+    **NESTED_CLASSES,
+}
+
 
 #######################################
 # exceptions
 
 class MimimalError(Exception):
     pass
+
+
+class RichError(Exception):
+    def __init__(self, msg, value=None):
+        super().__init__(msg, value)
+        self.msg = msg
+        self.value = value
+
+    def __eq__(self, other):
+        if not isinstance(other, RichError):
+            return NotImplemented
+        if self.msg != other.msg:
+            return False
+        if self.value != other.value:
+            return False
+        return True
index e1d1998fefc7fbd5eece86cc5d08c2948e13361d..5ebb78b0ea9e3b4460d351b7b3da2de9c6b9f005 100644 (file)
@@ -17,6 +17,9 @@ BUILTIN_TYPES = [o for _, o in __builtins__.items()
                  if isinstance(o, type)]
 EXCEPTION_TYPES = [cls for cls in BUILTIN_TYPES
                    if issubclass(cls, BaseException)]
+OTHER_TYPES = [o for n, o in vars(types).items()
+               if (isinstance(o, type) and
+                  n not in ('DynamicClassAttribute', '_GeneratorWrapper'))]
 
 
 class _GetXIDataTests(unittest.TestCase):
@@ -40,16 +43,42 @@ class _GetXIDataTests(unittest.TestCase):
                 got = _testinternalcapi.restore_crossinterp_data(xid)
                 yield obj, got
 
-    def assert_roundtrip_equal(self, values, *, mode=None):
-        for obj, got in self.iter_roundtrip_values(values, mode=mode):
-             self.assertEqual(got, obj)
-             self.assertIs(type(got), type(obj))
-
     def assert_roundtrip_identical(self, values, *, mode=None):
         for obj, got in self.iter_roundtrip_values(values, mode=mode):
             # XXX What about between interpreters?
             self.assertIs(got, obj)
 
+    def assert_roundtrip_equal(self, values, *, mode=None, expecttype=None):
+        for obj, got in self.iter_roundtrip_values(values, mode=mode):
+            self.assertEqual(got, obj)
+            self.assertIs(type(got),
+                          type(obj) if expecttype is None else expecttype)
+
+#    def assert_roundtrip_equal_not_identical(self, values, *,
+#                                            mode=None, expecttype=None):
+#        mode = self._resolve_mode(mode)
+#        for obj in values:
+#            cls = type(obj)
+#            with self.subTest(obj):
+#                got = self._get_roundtrip(obj, mode)
+#                self.assertIsNot(got, obj)
+#                self.assertIs(type(got), type(obj))
+#                self.assertEqual(got, obj)
+#                self.assertIs(type(got),
+#                              cls if expecttype is None else expecttype)
+#
+#    def assert_roundtrip_not_equal(self, values, *, mode=None, expecttype=None):
+#        mode = self._resolve_mode(mode)
+#        for obj in values:
+#            cls = type(obj)
+#            with self.subTest(obj):
+#                got = self._get_roundtrip(obj, mode)
+#                self.assertIsNot(got, obj)
+#                self.assertIs(type(got), type(obj))
+#                self.assertNotEqual(got, obj)
+#                self.assertIs(type(got),
+#                              cls if expecttype is None else expecttype)
+
     def assert_not_shareable(self, values, exctype=None, *, mode=None):
         mode = self._resolve_mode(mode)
         for obj in values:
@@ -66,6 +95,197 @@ class _GetXIDataTests(unittest.TestCase):
         return mode
 
 
+class MarshalTests(_GetXIDataTests):
+
+    MODE = 'marshal'
+
+    def test_simple_builtin_singletons(self):
+        self.assert_roundtrip_identical([
+            True,
+            False,
+            None,
+            Ellipsis,
+        ])
+        self.assert_not_shareable([
+            NotImplemented,
+        ])
+
+    def test_simple_builtin_objects(self):
+        self.assert_roundtrip_equal([
+            # int
+            *range(-1, 258),
+            sys.maxsize + 1,
+            sys.maxsize,
+            -sys.maxsize - 1,
+            -sys.maxsize - 2,
+            2**1000,
+            # complex
+            1+2j,
+            # float
+            0.0,
+            1.1,
+            -1.0,
+            0.12345678,
+            -0.12345678,
+            # bytes
+            *(i.to_bytes(2, 'little', signed=True)
+              for i in range(-1, 258)),
+            b'hello world',
+            # str
+            'hello world',
+            '你好世界',
+            '',
+        ])
+        self.assert_not_shareable([
+            object(),
+            types.SimpleNamespace(),
+        ])
+
+    def test_bytearray(self):
+        # bytearray is special because it unmarshals to bytes, not bytearray.
+        self.assert_roundtrip_equal([
+            bytearray(),
+            bytearray(b'hello world'),
+        ], expecttype=bytes)
+
+    def test_compound_immutable_builtin_objects(self):
+        self.assert_roundtrip_equal([
+            # tuple
+            (),
+            (1,),
+            ("hello", "world"),
+            (1, True, "hello"),
+            # frozenset
+            frozenset([1, 2, 3]),
+        ])
+        # nested
+        self.assert_roundtrip_equal([
+            # tuple
+            ((1,),),
+            ((1, 2), (3, 4)),
+            ((1, 2), (3, 4), (5, 6)),
+            # frozenset
+            frozenset([frozenset([1]), frozenset([2]), frozenset([3])]),
+        ])
+
+    def test_compound_mutable_builtin_objects(self):
+        self.assert_roundtrip_equal([
+            # list
+            [],
+            [1, 2, 3],
+            # dict
+            {},
+            {1: 7, 2: 8, 3: 9},
+            # set
+            set(),
+            {1, 2, 3},
+        ])
+        # nested
+        self.assert_roundtrip_equal([
+            [[1], [2], [3]],
+            {1: {'a': True}, 2: {'b': False}},
+            {(1, 2, 3,)},
+        ])
+
+    def test_compound_builtin_objects_with_bad_items(self):
+        bogus = object()
+        self.assert_not_shareable([
+            (bogus,),
+            frozenset([bogus]),
+            [bogus],
+            {bogus: True},
+            {True: bogus},
+            {bogus},
+        ])
+
+    def test_builtin_code(self):
+        self.assert_roundtrip_equal([
+            *(f.__code__ for f in defs.FUNCTIONS),
+            *(f.__code__ for f in defs.FUNCTION_LIKE),
+        ])
+
+    def test_builtin_type(self):
+        shareable = [
+            StopIteration,
+        ]
+        types = [
+            *BUILTIN_TYPES,
+            *OTHER_TYPES,
+        ]
+        self.assert_not_shareable(cls for cls in types
+                                  if cls not in shareable)
+        self.assert_roundtrip_identical(cls for cls in types
+                                        if cls in shareable)
+
+    def test_builtin_function(self):
+        functions = [
+            len,
+            sys.is_finalizing,
+            sys.exit,
+            _testinternalcapi.get_crossinterp_data,
+        ]
+        for func in functions:
+            assert type(func) is types.BuiltinFunctionType, func
+
+        self.assert_not_shareable(functions)
+
+    def test_builtin_exception(self):
+        msg = 'error!'
+        try:
+            raise Exception
+        except Exception as exc:
+            caught = exc
+        special = {
+            BaseExceptionGroup: (msg, [caught]),
+            ExceptionGroup: (msg, [caught]),
+#            UnicodeError: (None, msg, None, None, None),
+            UnicodeEncodeError: ('utf-8', '', 1, 3, msg),
+            UnicodeDecodeError: ('utf-8', b'', 1, 3, msg),
+            UnicodeTranslateError: ('', 1, 3, msg),
+        }
+        exceptions = []
+        for cls in EXCEPTION_TYPES:
+            args = special.get(cls) or (msg,)
+            exceptions.append(cls(*args))
+
+        self.assert_not_shareable(exceptions)
+        # Note that StopIteration (the type) can be marshalled,
+        # but its instances cannot.
+
+    def test_module(self):
+        assert type(sys) is types.ModuleType, type(sys)
+        assert type(defs) is types.ModuleType, type(defs)
+        assert type(unittest) is types.ModuleType, type(unittest)
+
+        assert 'emptymod' not in sys.modules
+        with import_helper.ready_to_import('emptymod', ''):
+            import emptymod
+
+        self.assert_not_shareable([
+            sys,
+            defs,
+            unittest,
+            emptymod,
+        ])
+
+
+    def test_user_class(self):
+        self.assert_not_shareable(defs.TOP_CLASSES)
+
+        instances = []
+        for cls, args in defs.TOP_CLASSES.items():
+            instances.append(cls(*args))
+        self.assert_not_shareable(instances)
+
+    def test_user_function(self):
+        self.assert_not_shareable(defs.TOP_FUNCTIONS)
+
+    def test_user_exception(self):
+        self.assert_not_shareable([
+            defs.MimimalError('error!'),
+            defs.RichError('error!', 42),
+        ])
+
+
 class ShareableTypeTests(_GetXIDataTests):
 
     MODE = 'xidata'
@@ -184,6 +404,7 @@ class ShareableTypeTests(_GetXIDataTests):
 
     def test_function_like(self):
         self.assert_not_shareable(defs.FUNCTION_LIKE)
+        self.assert_not_shareable(defs.FUNCTION_LIKE_APPLIED)
 
     def test_builtin_wrapper(self):
         _wrappers = {
@@ -243,9 +464,7 @@ class ShareableTypeTests(_GetXIDataTests):
     def test_builtin_type(self):
         self.assert_not_shareable([
             *BUILTIN_TYPES,
-            *(o for n, o in vars(types).items()
-              if (isinstance(o, type) and
-                  n not in ('DynamicClassAttribute', '_GeneratorWrapper'))),
+            *OTHER_TYPES,
         ])
 
     def test_exception(self):
index 353cb630513abc322fca6fb579d54bfc8d4850c9..0ef064fe80d17351d4a8da329d9dd4a0c79332b6 100644 (file)
@@ -1730,6 +1730,11 @@ get_crossinterp_data(PyObject *self, PyObject *args, PyObject *kwargs)
             goto error;
         }
     }
+    else if (strcmp(mode, "marshal") == 0) {
+        if (_PyMarshal_GetXIData(tstate, obj, xidata) != 0) {
+            goto error;
+        }
+    }
     else {
         PyErr_Format(PyExc_ValueError, "unsupported mode %R", modeobj);
         goto error;
index 662c9c72b15eb713320cbf1f9aa1589d982dc8a1..753d784a503467ca3d29b00cd212cf3e1d0c2919 100644 (file)
@@ -2,6 +2,7 @@
 /* API for managing interactions between isolated interpreters */
 
 #include "Python.h"
+#include "marshal.h"              // PyMarshal_WriteObjectToString()
 #include "pycore_ceval.h"         // _Py_simple_func
 #include "pycore_crossinterp.h"   // _PyXIData_t
 #include "pycore_initconfig.h"    // _PyStatus_OK()
@@ -286,6 +287,48 @@ _PyObject_GetXIData(PyThreadState *tstate,
 }
 
 
+/* marshal wrapper */
+
+PyObject *
+_PyMarshal_ReadObjectFromXIData(_PyXIData_t *xidata)
+{
+    PyThreadState *tstate = _PyThreadState_GET();
+    _PyBytes_data_t *shared = (_PyBytes_data_t *)xidata->data;
+    PyObject *obj = PyMarshal_ReadObjectFromString(shared->bytes, shared->len);
+    if (obj == NULL) {
+        PyObject *cause = _PyErr_GetRaisedException(tstate);
+        assert(cause != NULL);
+        _set_xid_lookup_failure(
+                    tstate, NULL, "object could not be unmarshalled", cause);
+        Py_DECREF(cause);
+        return NULL;
+    }
+    return obj;
+}
+
+int
+_PyMarshal_GetXIData(PyThreadState *tstate, PyObject *obj, _PyXIData_t *xidata)
+{
+    PyObject *bytes = PyMarshal_WriteObjectToString(obj, Py_MARSHAL_VERSION);
+    if (bytes == NULL) {
+        PyObject *cause = _PyErr_GetRaisedException(tstate);
+        assert(cause != NULL);
+        _set_xid_lookup_failure(
+                    tstate, NULL, "object could not be marshalled", cause);
+        Py_DECREF(cause);
+        return -1;
+    }
+    size_t size = sizeof(_PyBytes_data_t);
+    _PyBytes_data_t *shared = _PyBytes_GetXIDataWrapped(
+            tstate, bytes, size, _PyMarshal_ReadObjectFromXIData, xidata);
+    Py_DECREF(bytes);
+    if (shared == NULL) {
+        return -1;
+    }
+    return 0;
+}
+
+
 /* using cross-interpreter data */
 
 PyObject *