]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-84570: Add Timeouts to SendChannel.send() and RecvChannel.recv() (gh-110567)
authorEric Snow <ericsnowcurrently@gmail.com>
Tue, 17 Oct 2023 23:05:49 +0000 (17:05 -0600)
committerGitHub <noreply@github.com>
Tue, 17 Oct 2023 23:05:49 +0000 (23:05 +0000)
Include/internal/pycore_pythread.h
Lib/test/support/interpreters.py
Lib/test/test__xxinterpchannels.py
Lib/test/test_interpreters.py
Modules/_queuemodule.c
Modules/_threadmodule.c
Modules/_xxinterpchannelsmodule.c
Python/thread.c

index ffd7398eaeee5a4c323f335dc6f71e95200d87bc..d31ffc781305349ff019f686fb8108a8fca9422a 100644 (file)
@@ -89,6 +89,12 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock);
 // unset: -1 seconds, in nanoseconds
 #define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000))
 
+// Exported for the _xxinterpchannels module.
+PyAPI_FUNC(int) PyThread_ParseTimeoutArg(
+    PyObject *arg,
+    int blocking,
+    PY_TIMEOUT_T *timeout);
+
 /* Helper to acquire an interruptible lock with a timeout.  If the lock acquire
  * is interrupted, signal handlers are run, and if they raise an exception,
  * PY_LOCK_INTR is returned.  Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
index 9ba6862a9ee01a1777ce07a806c533e82d78f64a..f8f42c0e02479cb792f7b1cd1295327b08d0db15 100644 (file)
@@ -170,15 +170,25 @@ class RecvChannel(_ChannelEnd):
 
     _end = 'recv'
 
-    def recv(self, *, _sentinel=object(), _delay=10 / 1000):  # 10 milliseconds
+    def recv(self, timeout=None, *,
+             _sentinel=object(),
+             _delay=10 / 1000,  # 10 milliseconds
+             ):
         """Return the next object from the channel.
 
         This blocks until an object has been sent, if none have been
         sent already.
         """
+        if timeout is not None:
+            timeout = int(timeout)
+            if timeout < 0:
+                raise ValueError(f'timeout value must be non-negative')
+            end = time.time() + timeout
         obj = _channels.recv(self._id, _sentinel)
         while obj is _sentinel:
             time.sleep(_delay)
+            if timeout is not None and time.time() >= end:
+                raise TimeoutError
             obj = _channels.recv(self._id, _sentinel)
         return obj
 
@@ -203,12 +213,12 @@ class SendChannel(_ChannelEnd):
 
     _end = 'send'
 
-    def send(self, obj):
+    def send(self, obj, timeout=None):
         """Send the object (i.e. its data) to the channel's receiving end.
 
         This blocks until the object is received.
         """
-        _channels.send(self._id, obj, blocking=True)
+        _channels.send(self._id, obj, timeout=timeout, blocking=True)
 
     def send_nowait(self, obj):
         """Send the object to the channel's receiving end.
@@ -221,12 +231,12 @@ class SendChannel(_ChannelEnd):
         # See bpo-32604 and gh-19829.
         return _channels.send(self._id, obj, blocking=False)
 
-    def send_buffer(self, obj):
+    def send_buffer(self, obj, timeout=None):
         """Send the object's buffer to the channel's receiving end.
 
         This blocks until the object is received.
         """
-        _channels.send_buffer(self._id, obj, blocking=True)
+        _channels.send_buffer(self._id, obj, timeout=timeout, blocking=True)
 
     def send_buffer_nowait(self, obj):
         """Send the object's buffer to the channel's receiving end.
index 90a1224498fe6d1cb6dd536c81cea66dd404e4b0..1c1ef3fac9d65f5cad7c60d7327efe658406a55d 100644 (file)
@@ -864,22 +864,97 @@ class ChannelTests(TestBase):
 
         self.assertEqual(received, obj)
 
+    def test_send_timeout(self):
+        obj = b'spam'
+
+        with self.subTest('non-blocking with timeout'):
+            cid = channels.create()
+            with self.assertRaises(ValueError):
+                channels.send(cid, obj, blocking=False, timeout=0.1)
+
+        with self.subTest('timeout hit'):
+            cid = channels.create()
+            with self.assertRaises(TimeoutError):
+                channels.send(cid, obj, blocking=True, timeout=0.1)
+            with self.assertRaises(channels.ChannelEmptyError):
+                received = channels.recv(cid)
+                print(repr(received))
+
+        with self.subTest('timeout not hit'):
+            cid = channels.create()
+            def f():
+                recv_wait(cid)
+            t = threading.Thread(target=f)
+            t.start()
+            channels.send(cid, obj, blocking=True, timeout=10)
+            t.join()
+
+    def test_send_buffer_timeout(self):
+        try:
+            self._has_run_once_timeout
+        except AttributeError:
+            # At the moment, this test leaks a few references.
+            # It looks like the leak originates with the addition
+            # of _channels.send_buffer() (gh-110246), whereas the
+            # tests were added afterward.  We want this test even
+            # if the refleak isn't fixed yet, so we skip here.
+            raise unittest.SkipTest('temporarily skipped due to refleaks')
+        else:
+            self._has_run_once_timeout = True
+
+        obj = bytearray(b'spam')
+
+        with self.subTest('non-blocking with timeout'):
+            cid = channels.create()
+            with self.assertRaises(ValueError):
+                channels.send_buffer(cid, obj, blocking=False, timeout=0.1)
+
+        with self.subTest('timeout hit'):
+            cid = channels.create()
+            with self.assertRaises(TimeoutError):
+                channels.send_buffer(cid, obj, blocking=True, timeout=0.1)
+            with self.assertRaises(channels.ChannelEmptyError):
+                received = channels.recv(cid)
+                print(repr(received))
+
+        with self.subTest('timeout not hit'):
+            cid = channels.create()
+            def f():
+                recv_wait(cid)
+            t = threading.Thread(target=f)
+            t.start()
+            channels.send_buffer(cid, obj, blocking=True, timeout=10)
+            t.join()
+
     def test_send_closed_while_waiting(self):
         obj = b'spam'
         wait = self.build_send_waiter(obj)
-        cid = channels.create()
-        def f():
-            wait()
-            channels.close(cid, force=True)
-        t = threading.Thread(target=f)
-        t.start()
-        with self.assertRaises(channels.ChannelClosedError):
-            channels.send(cid, obj, blocking=True)
-        t.join()
+
+        with self.subTest('without timeout'):
+            cid = channels.create()
+            def f():
+                wait()
+                channels.close(cid, force=True)
+            t = threading.Thread(target=f)
+            t.start()
+            with self.assertRaises(channels.ChannelClosedError):
+                channels.send(cid, obj, blocking=True)
+            t.join()
+
+        with self.subTest('with timeout'):
+            cid = channels.create()
+            def f():
+                wait()
+                channels.close(cid, force=True)
+            t = threading.Thread(target=f)
+            t.start()
+            with self.assertRaises(channels.ChannelClosedError):
+                channels.send(cid, obj, blocking=True, timeout=30)
+            t.join()
 
     def test_send_buffer_closed_while_waiting(self):
         try:
-            self._has_run_once
+            self._has_run_once_closed
         except AttributeError:
             # At the moment, this test leaks a few references.
             # It looks like the leak originates with the addition
@@ -888,19 +963,32 @@ class ChannelTests(TestBase):
             # if the refleak isn't fixed yet, so we skip here.
             raise unittest.SkipTest('temporarily skipped due to refleaks')
         else:
-            self._has_run_once = True
+            self._has_run_once_closed = True
 
         obj = bytearray(b'spam')
         wait = self.build_send_waiter(obj, buffer=True)
-        cid = channels.create()
-        def f():
-            wait()
-            channels.close(cid, force=True)
-        t = threading.Thread(target=f)
-        t.start()
-        with self.assertRaises(channels.ChannelClosedError):
-            channels.send_buffer(cid, obj, blocking=True)
-        t.join()
+
+        with self.subTest('without timeout'):
+            cid = channels.create()
+            def f():
+                wait()
+                channels.close(cid, force=True)
+            t = threading.Thread(target=f)
+            t.start()
+            with self.assertRaises(channels.ChannelClosedError):
+                channels.send_buffer(cid, obj, blocking=True)
+            t.join()
+
+        with self.subTest('with timeout'):
+            cid = channels.create()
+            def f():
+                wait()
+                channels.close(cid, force=True)
+            t = threading.Thread(target=f)
+            t.start()
+            with self.assertRaises(channels.ChannelClosedError):
+                channels.send_buffer(cid, obj, blocking=True, timeout=30)
+            t.join()
 
     #-------------------
     # close
index 0910b51bfe5dbd5a4a7ee3aaa6cfed732399d024..d2d52ec9a7808fda6f16bf7ffba740d0e6aa3fbb 100644 (file)
@@ -1022,6 +1022,11 @@ class TestSendRecv(TestBase):
         self.assertEqual(obj2, b'eggs')
         self.assertNotEqual(id(obj2), int(out))
 
+    def test_recv_timeout(self):
+        r, _ = interpreters.create_channel()
+        with self.assertRaises(TimeoutError):
+            r.recv(timeout=1)
+
     def test_recv_channel_does_not_exist(self):
         ch = interpreters.RecvChannel(1_000_000)
         with self.assertRaises(interpreters.ChannelNotFoundError):
index b4bafb375c999dfdf44bf3f86192c412c2284e60..81a06cdb79a4f25f5d4525be3a01db77c68d356b 100644 (file)
@@ -214,6 +214,8 @@ _queue_SimpleQueue_get_impl(simplequeueobject *self, PyTypeObject *cls,
     PY_TIMEOUT_T microseconds;
     PyThreadState *tstate = PyThreadState_Get();
 
+    // XXX Use PyThread_ParseTimeoutArg().
+
     if (block == 0) {
         /* Non-blocking */
         microseconds = 0;
index 7620511dd1d6eb5c5ca8aa4911a62e9e75d937f1..4d4530405036437035fcab99fd9a2b913d528f1a 100644 (file)
@@ -88,14 +88,15 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
     char *kwlist[] = {"blocking", "timeout", NULL};
     int blocking = 1;
     PyObject *timeout_obj = NULL;
-    const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
-
-    *timeout = unset_timeout ;
-
     if (!PyArg_ParseTupleAndKeywords(args, kwds, "|pO:acquire", kwlist,
                                      &blocking, &timeout_obj))
         return -1;
 
+    // XXX Use PyThread_ParseTimeoutArg().
+
+    const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
+    *timeout = unset_timeout;
+
     if (timeout_obj
         && _PyTime_FromSecondsObject(timeout,
                                      timeout_obj, _PyTime_ROUND_TIMEOUT) < 0)
@@ -108,7 +109,7 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
     }
     if (*timeout < 0 && *timeout != unset_timeout) {
         PyErr_SetString(PyExc_ValueError,
-                        "timeout value must be positive");
+                        "timeout value must be a non-negative number");
         return -1;
     }
     if (!blocking)
index be53cbfc39b4ddb0ce5405b2420274d34a25f901..2e2878d5c205cf68628a3c61f2b812a76ee6f453 100644 (file)
@@ -242,9 +242,8 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
 }
 
 static int
-wait_for_lock(PyThread_type_lock mutex)
+wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
 {
-    PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
     PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
     if (res == PY_LOCK_INTR) {
         /* KeyboardInterrupt, etc. */
@@ -1883,7 +1882,8 @@ _channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
 }
 
 static int
-_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
+_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj,
+                   PY_TIMEOUT_T timeout)
 {
     // We use a stack variable here, so we must ensure that &waiting
     // is not held by any channel item at the point this function exits.
@@ -1901,7 +1901,7 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
     }
 
     /* Wait until the object is received. */
-    if (wait_for_lock(waiting.mutex) < 0) {
+    if (wait_for_lock(waiting.mutex, timeout) < 0) {
         assert(PyErr_Occurred());
         _waiting_finish_releasing(&waiting);
         /* The send() call is failing now, so make sure the item
@@ -2816,25 +2816,29 @@ receive end.");
 static PyObject *
 channel_send(PyObject *self, PyObject *args, PyObject *kwds)
 {
-    // XXX Add a timeout arg.
-    static char *kwlist[] = {"cid", "obj", "blocking", NULL};
-    int64_t cid;
+    static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
     struct channel_id_converter_data cid_data = {
         .module = self,
     };
     PyObject *obj;
     int blocking = 1;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
+    PyObject *timeout_obj = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$pO:channel_send", kwlist,
                                      channel_id_converter, &cid_data, &obj,
-                                     &blocking)) {
+                                     &blocking, &timeout_obj)) {
+        return NULL;
+    }
+
+    int64_t cid = cid_data.cid;
+    PY_TIMEOUT_T timeout;
+    if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
         return NULL;
     }
-    cid = cid_data.cid;
 
     /* Queue up the object. */
     int err = 0;
     if (blocking) {
-        err = _channel_send_wait(&_globals.channels, cid, obj);
+        err = _channel_send_wait(&_globals.channels, cid, obj, timeout);
     }
     else {
         err = _channel_send(&_globals.channels, cid, obj, NULL);
@@ -2855,20 +2859,25 @@ By default this waits for the object to be received.");
 static PyObject *
 channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
 {
-    static char *kwlist[] = {"cid", "obj", "blocking", NULL};
-    int64_t cid;
+    static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
     struct channel_id_converter_data cid_data = {
         .module = self,
     };
     PyObject *obj;
     int blocking = 1;
+    PyObject *timeout_obj = NULL;
     if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O&O|$p:channel_send_buffer", kwlist,
+                                     "O&O|$pO:channel_send_buffer", kwlist,
                                      channel_id_converter, &cid_data, &obj,
-                                     &blocking)) {
+                                     &blocking, &timeout_obj)) {
+        return NULL;
+    }
+
+    int64_t cid = cid_data.cid;
+    PY_TIMEOUT_T timeout;
+    if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
         return NULL;
     }
-    cid = cid_data.cid;
 
     PyObject *tempobj = PyMemoryView_FromObject(obj);
     if (tempobj == NULL) {
@@ -2878,7 +2887,7 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
     /* Queue up the object. */
     int err = 0;
     if (blocking) {
-        err = _channel_send_wait(&_globals.channels, cid, tempobj);
+        err = _channel_send_wait(&_globals.channels, cid, tempobj, timeout);
     }
     else {
         err = _channel_send(&_globals.channels, cid, tempobj, NULL);
index 7185dd43d965b98fd4ce16d21056e24b48cb3ca9..fefae8391617f7d789c057d15e981935ecf42865 100644 (file)
@@ -93,6 +93,40 @@ PyThread_set_stacksize(size_t size)
 }
 
 
+int
+PyThread_ParseTimeoutArg(PyObject *arg, int blocking, PY_TIMEOUT_T *timeout_p)
+{
+    assert(_PyTime_FromSeconds(-1) == PyThread_UNSET_TIMEOUT);
+    if (arg == NULL || arg == Py_None) {
+        *timeout_p = blocking ? PyThread_UNSET_TIMEOUT : 0;
+        return 0;
+    }
+    if (!blocking) {
+        PyErr_SetString(PyExc_ValueError,
+                        "can't specify a timeout for a non-blocking call");
+        return -1;
+    }
+
+    _PyTime_t timeout;
+    if (_PyTime_FromSecondsObject(&timeout, arg, _PyTime_ROUND_TIMEOUT) < 0) {
+        return -1;
+    }
+    if (timeout < 0) {
+        PyErr_SetString(PyExc_ValueError,
+                        "timeout value must be a non-negative number");
+        return -1;
+    }
+
+    if (_PyTime_AsMicroseconds(timeout,
+                               _PyTime_ROUND_TIMEOUT) > PY_TIMEOUT_MAX) {
+        PyErr_SetString(PyExc_OverflowError,
+                        "timeout value is too large");
+        return -1;
+    }
+    *timeout_p = timeout;
+    return 0;
+}
+
 PyLockStatus
 PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock,
                                          PY_TIMEOUT_T timeout)