]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-76785: Module-level Fixes for test.support.interpreters (gh-110236)
authorEric Snow <ericsnowcurrently@gmail.com>
Mon, 2 Oct 2023 20:47:41 +0000 (14:47 -0600)
committerGitHub <noreply@github.com>
Mon, 2 Oct 2023 20:47:41 +0000 (20:47 +0000)
* add RecvChannel.close() and SendChannel.close()
* make RecvChannel and SendChannel shareable
* expose ChannelEmptyError and ChannelNotEmptyError

Lib/test/support/interpreters.py
Lib/test/test_interpreters.py
Modules/_xxinterpchannelsmodule.c

index eeff3abe0324e58c6d2fc67245eb1ab4968ce068..d2beba31e80283ab35fcfc31fc08e585a48c0740 100644 (file)
@@ -7,7 +7,8 @@ import _xxinterpchannels as _channels
 # aliases:
 from _xxsubinterpreters import is_shareable
 from _xxinterpchannels import (
-    ChannelError, ChannelNotFoundError, ChannelEmptyError,
+    ChannelError, ChannelNotFoundError, ChannelClosedError,
+    ChannelEmptyError, ChannelNotEmptyError,
 )
 
 
@@ -117,10 +118,16 @@ def list_all_channels():
 class _ChannelEnd:
     """The base class for RecvChannel and SendChannel."""
 
-    def __init__(self, id):
-        if not isinstance(id, (int, _channels.ChannelID)):
-            raise TypeError(f'id must be an int, got {id!r}')
-        self._id = id
+    _end = None
+
+    def __init__(self, cid):
+        if self._end == 'send':
+            cid = _channels._channel_id(cid, send=True, force=True)
+        elif self._end == 'recv':
+            cid = _channels._channel_id(cid, recv=True, force=True)
+        else:
+            raise NotImplementedError(self._end)
+        self._id = cid
 
     def __repr__(self):
         return f'{type(self).__name__}(id={int(self._id)})'
@@ -147,6 +154,8 @@ _NOT_SET = object()
 class RecvChannel(_ChannelEnd):
     """The receiving end of a cross-interpreter channel."""
 
+    _end = 'recv'
+
     def recv(self, *, _sentinel=object(), _delay=10 / 1000):  # 10 milliseconds
         """Return the next object from the channel.
 
@@ -171,10 +180,15 @@ class RecvChannel(_ChannelEnd):
         else:
             return _channels.recv(self._id, default)
 
+    def close(self):
+        _channels.close(self._id, recv=True)
+
 
 class SendChannel(_ChannelEnd):
     """The sending end of a cross-interpreter channel."""
 
+    _end = 'send'
+
     def send(self, obj):
         """Send the object (i.e. its data) to the channel's receiving end.
 
@@ -196,3 +210,9 @@ class SendChannel(_ChannelEnd):
         # None.  This should be fixed when channel_send_wait() is added.
         # See bpo-32604 and gh-19829.
         return _channels.send(self._id, obj)
+
+    def close(self):
+        _channels.close(self._id, send=True)
+
+
+_channels._register_end_types(SendChannel, RecvChannel)
index e62859a9c2b08ec97390081912ff708a3a3b25ff..ffdd8a1276939731841f873bade0cdfa1c98878f 100644 (file)
@@ -822,6 +822,22 @@ class TestChannels(TestBase):
         after = set(interpreters.list_all_channels())
         self.assertEqual(after, created)
 
+    def test_shareable(self):
+        rch, sch = interpreters.create_channel()
+
+        self.assertTrue(
+            interpreters.is_shareable(rch))
+        self.assertTrue(
+            interpreters.is_shareable(sch))
+
+        sch.send_nowait(rch)
+        sch.send_nowait(sch)
+        rch2 = rch.recv()
+        sch2 = rch.recv()
+
+        self.assertEqual(rch2, rch)
+        self.assertEqual(sch2, sch)
+
 
 class TestRecvChannelAttrs(TestBase):
 
index 6096f88421a73a81ca034611fe3270f748567282..d5be76f1f0e38e2e31ced00c429f290fffd0925b 100644 (file)
@@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
 /* module state *************************************************************/
 
 typedef struct {
+    PyTypeObject *send_channel_type;
+    PyTypeObject *recv_channel_type;
+
     /* heap types */
     PyTypeObject *ChannelIDType;
 
@@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
     return state;
 }
 
+static module_state *
+_get_current_module_state(void)
+{
+    PyObject *mod = _get_current_module();
+    if (mod == NULL) {
+        // XXX import it?
+        PyErr_SetString(PyExc_RuntimeError,
+                        MODULE_NAME " module not imported yet");
+        return NULL;
+    }
+    module_state *state = get_module_state(mod);
+    Py_DECREF(mod);
+    return state;
+}
+
 static int
 traverse_module_state(module_state *state, visitproc visit, void *arg)
 {
@@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
 static int
 clear_module_state(module_state *state)
 {
+    Py_CLEAR(state->send_channel_type);
+    Py_CLEAR(state->recv_channel_type);
+
     /* heap types */
     if (state->ChannelIDType != NULL) {
         (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
@@ -1529,17 +1550,20 @@ typedef struct channelid {
 struct channel_id_converter_data {
     PyObject *module;
     int64_t cid;
+    int end;
 };
 
 static int
 channel_id_converter(PyObject *arg, void *ptr)
 {
     int64_t cid;
+    int end = 0;
     struct channel_id_converter_data *data = ptr;
     module_state *state = get_module_state(data->module);
     assert(state != NULL);
     if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
         cid = ((channelid *)arg)->id;
+        end = ((channelid *)arg)->end;
     }
     else if (PyIndex_Check(arg)) {
         cid = PyLong_AsLongLong(arg);
@@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
         return 0;
     }
     data->cid = cid;
+    data->end = end;
     return 1;
 }
 
@@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
 {
     static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
     int64_t cid;
+    int end;
     struct channel_id_converter_data cid_data = {
         .module = mod,
     };
@@ -1614,6 +1640,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
         return NULL;
     }
     cid = cid_data.cid;
+    end = cid_data.end;
 
     // Handle "send" and "recv".
     if (send == 0 && recv == 0) {
@@ -1621,14 +1648,17 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
                         "'send' and 'recv' cannot both be False");
         return NULL;
     }
-
-    int end = 0;
-    if (send == 1) {
+    else if (send == 1) {
         if (recv == 0 || recv == -1) {
             end = CHANNEL_SEND;
         }
+        else {
+            assert(recv == 1);
+            end = 0;
+        }
     }
     else if (recv == 1) {
+        assert(send == 0 || send == -1);
         end = CHANNEL_RECV;
     }
 
@@ -1773,21 +1803,12 @@ done:
     return res;
 }
 
+static PyTypeObject * _get_current_channel_end_type(int end);
+
 static PyObject *
 _channel_from_cid(PyObject *cid, int end)
 {
-    PyObject *highlevel = PyImport_ImportModule("interpreters");
-    if (highlevel == NULL) {
-        PyErr_Clear();
-        highlevel = PyImport_ImportModule("test.support.interpreters");
-        if (highlevel == NULL) {
-            return NULL;
-        }
-    }
-    const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
-                                                  "SendChannel";
-    PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
-    Py_DECREF(highlevel);
+    PyObject *cls = (PyObject *)_get_current_channel_end_type(end);
     if (cls == NULL) {
         return NULL;
     }
@@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
 };
 
 
+/* SendChannel and RecvChannel classes */
+
+// XXX Use a new __xid__ protocol instead?
+
+static PyTypeObject *
+_get_current_channel_end_type(int end)
+{
+    module_state *state = _get_current_module_state();
+    if (state == NULL) {
+        return NULL;
+    }
+    PyTypeObject *cls;
+    if (end == CHANNEL_SEND) {
+        cls = state->send_channel_type;
+    }
+    else {
+        assert(end == CHANNEL_RECV);
+        cls = state->recv_channel_type;
+    }
+    if (cls == NULL) {
+        PyObject *highlevel = PyImport_ImportModule("interpreters");
+        if (highlevel == NULL) {
+            PyErr_Clear();
+            highlevel = PyImport_ImportModule("test.support.interpreters");
+            if (highlevel == NULL) {
+                return NULL;
+            }
+        }
+        if (end == CHANNEL_SEND) {
+            cls = state->send_channel_type;
+        }
+        else {
+            cls = state->recv_channel_type;
+        }
+        assert(cls != NULL);
+    }
+    return cls;
+}
+
+static PyObject *
+_channel_end_from_xid(_PyCrossInterpreterData *data)
+{
+    channelid *cid = (channelid *)_channelid_from_xid(data);
+    if (cid == NULL) {
+        return NULL;
+    }
+    PyTypeObject *cls = _get_current_channel_end_type(cid->end);
+    if (cls == NULL) {
+        return NULL;
+    }
+    PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid);
+    Py_DECREF(cid);
+    return obj;
+}
+
+static int
+_channel_end_shared(PyThreadState *tstate, PyObject *obj,
+                    _PyCrossInterpreterData *data)
+{
+    PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
+    if (cidobj == NULL) {
+        return -1;
+    }
+    if (_channelid_shared(tstate, cidobj, data) < 0) {
+        return -1;
+    }
+    data->new_object = _channel_end_from_xid;
+    return 0;
+}
+
+static int
+set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
+{
+    module_state *state = get_module_state(mod);
+    if (state == NULL) {
+        return -1;
+    }
+
+    if (state->send_channel_type != NULL
+        || state->recv_channel_type != NULL)
+    {
+        PyErr_SetString(PyExc_TypeError, "already registered");
+        return -1;
+    }
+    state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
+    state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);
+
+    if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) {
+        return -1;
+    }
+    if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) {
+        return -1;
+    }
+
+    return 0;
+}
+
 /* module level code ********************************************************/
 
 /* globals is the process-global state for the module.  It holds all
@@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
         return NULL;
     }
     PyTypeObject *cls = state->ChannelIDType;
-    PyObject *mod = get_module_from_owned_type(cls);
-    if (mod == NULL) {
+    assert(get_module_from_owned_type(cls) == self);
+
+    return _channelid_new(self, cls, args, kwds);
+}
+
+static PyObject *
+channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
+{
+    static char *kwlist[] = {"send", "recv", NULL};
+    PyObject *send;
+    PyObject *recv;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds,
+                                     "OO:_register_end_types", kwlist,
+                                     &send, &recv)) {
         return NULL;
     }
-    PyObject *cid = _channelid_new(mod, cls, args, kwds);
-    Py_DECREF(mod);
-    return cid;
+    if (!PyType_Check(send)) {
+        PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
+        return NULL;
+    }
+    if (!PyType_Check(recv)) {
+        PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
+        return NULL;
+    }
+    PyTypeObject *cls_send = (PyTypeObject *)send;
+    PyTypeObject *cls_recv = (PyTypeObject *)recv;
+
+    if (set_channel_end_types(self, cls_send, cls_recv) < 0) {
+        return NULL;
+    }
+
+    Py_RETURN_NONE;
 }
 
 static PyMethodDef module_functions[] = {
@@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
      METH_VARARGS | METH_KEYWORDS, channel_release_doc},
     {"_channel_id",               _PyCFunction_CAST(channel__channel_id),
      METH_VARARGS | METH_KEYWORDS, NULL},
+    {"_register_end_types",       _PyCFunction_CAST(channel__register_end_types),
+     METH_VARARGS | METH_KEYWORDS, NULL},
 
     {NULL,                        NULL}           /* sentinel */
 };