]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-84570: Send-Wait Fixes for _xxinterpchannels (gh-111006)
authorEric Snow <ericsnowcurrently@gmail.com>
Tue, 17 Oct 2023 22:32:00 +0000 (16:32 -0600)
committerGitHub <noreply@github.com>
Tue, 17 Oct 2023 22:32:00 +0000 (16:32 -0600)
There were a few things I did in gh-110565 that need to be fixed. I also forgot to add tests in that PR.

(Note that this PR exposes a refleak introduced by gh-110246. I'll take care of that separately.)

Include/internal/pycore_pythread.h
Lib/test/test__xxinterpchannels.py
Modules/_threadmodule.c
Modules/_xxinterpchannelsmodule.c
Python/thread.c

index 8ce5a79d066abfa34600cef3ed0d4989866400b9..ffd7398eaeee5a4c323f335dc6f71e95200d87bc 100644 (file)
@@ -86,6 +86,21 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock);
 #endif  /* HAVE_FORK */
 
 
+// unset: -1 seconds, in nanoseconds
+#define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000))
+
+/* Helper to acquire an interruptible lock with a timeout.  If the lock acquire
+ * is interrupted, signal handlers are run, and if they raise an exception,
+ * PY_LOCK_INTR is returned.  Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
+ * are returned, depending on whether the lock can be acquired within the
+ * timeout.
+ */
+// Exported for the _xxinterpchannels module.
+PyAPI_FUNC(PyLockStatus) PyThread_acquire_lock_timed_with_retries(
+    PyThread_type_lock,
+    PY_TIMEOUT_T microseconds);
+
+
 #ifdef __cplusplus
 }
 #endif
index ff01a339c0008e2313f87712f356ee3ee1eafd36..90a1224498fe6d1cb6dd536c81cea66dd404e4b0 100644 (file)
@@ -564,7 +564,62 @@ class ChannelTests(TestBase):
         with self.assertRaises(channels.ChannelClosedError):
             channels.list_interpreters(cid, send=False)
 
-    ####################
+    def test_allowed_types(self):
+        cid = channels.create()
+        objects = [
+            None,
+            'spam',
+            b'spam',
+            42,
+        ]
+        for obj in objects:
+            with self.subTest(obj):
+                channels.send(cid, obj, blocking=False)
+                got = channels.recv(cid)
+
+                self.assertEqual(got, obj)
+                self.assertIs(type(got), type(obj))
+                # XXX Check the following?
+                #self.assertIsNot(got, obj)
+                # XXX What about between interpreters?
+
+    def test_run_string_arg_unresolved(self):
+        cid = channels.create()
+        interp = interpreters.create()
+
+        out = _run_output(interp, dedent("""
+            import _xxinterpchannels as _channels
+            print(cid.end)
+            _channels.send(cid, b'spam', blocking=False)
+            """),
+            dict(cid=cid.send))
+        obj = channels.recv(cid)
+
+        self.assertEqual(obj, b'spam')
+        self.assertEqual(out.strip(), 'send')
+
+    # XXX For now there is no high-level channel into which the
+    # sent channel ID can be converted...
+    # Note: this test caused crashes on some buildbots (bpo-33615).
+    @unittest.skip('disabled until high-level channels exist')
+    def test_run_string_arg_resolved(self):
+        cid = channels.create()
+        cid = channels._channel_id(cid, _resolve=True)
+        interp = interpreters.create()
+
+        out = _run_output(interp, dedent("""
+            import _xxinterpchannels as _channels
+            print(chan.id.end)
+            _channels.send(chan.id, b'spam', blocking=False)
+            """),
+            dict(chan=cid.send))
+        obj = channels.recv(cid)
+
+        self.assertEqual(obj, b'spam')
+        self.assertEqual(out.strip(), 'send')
+
+    #-------------------
+    # send/recv
 
     def test_send_recv_main(self):
         cid = channels.create()
@@ -705,6 +760,9 @@ class ChannelTests(TestBase):
                 channels.recv(cid2)
             del cid2
 
+    #-------------------
+    # send_buffer
+
     def test_send_buffer(self):
         buf = bytearray(b'spamspamspam')
         cid = channels.create()
@@ -720,60 +778,131 @@ class ChannelTests(TestBase):
         obj[4:8] = b'ham.'
         self.assertEqual(obj, buf)
 
-    def test_allowed_types(self):
+    #-------------------
+    # send with waiting
+
+    def build_send_waiter(self, obj, *, buffer=False):
+        # We want a long enough sleep that send() actually has to wait.
+
+        if buffer:
+            send = channels.send_buffer
+        else:
+            send = channels.send
+
         cid = channels.create()
-        objects = [
-            None,
-            'spam',
-            b'spam',
-            42,
-        ]
-        for obj in objects:
-            with self.subTest(obj):
-                channels.send(cid, obj, blocking=False)
-                got = channels.recv(cid)
+        try:
+            started = time.monotonic()
+            send(cid, obj, blocking=False)
+            stopped = time.monotonic()
+            channels.recv(cid)
+        finally:
+            channels.destroy(cid)
+        delay = stopped - started  # seconds
+        delay *= 3
 
-                self.assertEqual(got, obj)
-                self.assertIs(type(got), type(obj))
-                # XXX Check the following?
-                #self.assertIsNot(got, obj)
-                # XXX What about between interpreters?
+        def wait():
+            time.sleep(delay)
+        return wait
 
-    def test_run_string_arg_unresolved(self):
+    def test_send_blocking_waiting(self):
+        received = None
+        obj = b'spam'
+        wait = self.build_send_waiter(obj)
         cid = channels.create()
-        interp = interpreters.create()
+        def f():
+            nonlocal received
+            wait()
+            received = recv_wait(cid)
+        t = threading.Thread(target=f)
+        t.start()
+        channels.send(cid, obj, blocking=True)
+        t.join()
 
-        out = _run_output(interp, dedent("""
-            import _xxinterpchannels as _channels
-            print(cid.end)
-            _channels.send(cid, b'spam', blocking=False)
-            """),
-            dict(cid=cid.send))
-        obj = channels.recv(cid)
+        self.assertEqual(received, obj)
 
-        self.assertEqual(obj, b'spam')
-        self.assertEqual(out.strip(), 'send')
+    def test_send_buffer_blocking_waiting(self):
+        received = None
+        obj = bytearray(b'spam')
+        wait = self.build_send_waiter(obj, buffer=True)
+        cid = channels.create()
+        def f():
+            nonlocal received
+            wait()
+            received = recv_wait(cid)
+        t = threading.Thread(target=f)
+        t.start()
+        channels.send_buffer(cid, obj, blocking=True)
+        t.join()
 
-    # XXX For now there is no high-level channel into which the
-    # sent channel ID can be converted...
-    # Note: this test caused crashes on some buildbots (bpo-33615).
-    @unittest.skip('disabled until high-level channels exist')
-    def test_run_string_arg_resolved(self):
+        self.assertEqual(received, obj)
+
+    def test_send_blocking_no_wait(self):
+        received = None
+        obj = b'spam'
         cid = channels.create()
-        cid = channels._channel_id(cid, _resolve=True)
-        interp = interpreters.create()
+        def f():
+            nonlocal received
+            received = recv_wait(cid)
+        t = threading.Thread(target=f)
+        t.start()
+        channels.send(cid, obj, blocking=True)
+        t.join()
 
-        out = _run_output(interp, dedent("""
-            import _xxinterpchannels as _channels
-            print(chan.id.end)
-            _channels.send(chan.id, b'spam', blocking=False)
-            """),
-            dict(chan=cid.send))
-        obj = channels.recv(cid)
+        self.assertEqual(received, obj)
 
-        self.assertEqual(obj, b'spam')
-        self.assertEqual(out.strip(), 'send')
+    def test_send_buffer_blocking_no_wait(self):
+        received = None
+        obj = bytearray(b'spam')
+        cid = channels.create()
+        def f():
+            nonlocal received
+            received = recv_wait(cid)
+        t = threading.Thread(target=f)
+        t.start()
+        channels.send_buffer(cid, obj, blocking=True)
+        t.join()
+
+        self.assertEqual(received, obj)
+
+    def test_send_closed_while_waiting(self):
+        obj = b'spam'
+        wait = self.build_send_waiter(obj)
+        cid = channels.create()
+        def f():
+            wait()
+            channels.close(cid, force=True)
+        t = threading.Thread(target=f)
+        t.start()
+        with self.assertRaises(channels.ChannelClosedError):
+            channels.send(cid, obj, blocking=True)
+        t.join()
+
+    def test_send_buffer_closed_while_waiting(self):
+        try:
+            self._has_run_once
+        except AttributeError:
+            # At the moment, this test leaks a few references.
+            # It looks like the leak originates with the addition
+            # of _channels.send_buffer() (gh-110246), whereas the
+            # tests were added afterward.  We want this test even
+            # if the refleak isn't fixed yet, so we skip here.
+            raise unittest.SkipTest('temporarily skipped due to refleaks')
+        else:
+            self._has_run_once = True
+
+        obj = bytearray(b'spam')
+        wait = self.build_send_waiter(obj, buffer=True)
+        cid = channels.create()
+        def f():
+            wait()
+            channels.close(cid, force=True)
+        t = threading.Thread(target=f)
+        t.start()
+        with self.assertRaises(channels.ChannelClosedError):
+            channels.send_buffer(cid, obj, blocking=True)
+        t.join()
 
+    #-------------------
     # close
 
     def test_close_single_user(self):
index 86bd560b92ba6be28a8cbb4503fd20c729888d37..7620511dd1d6eb5c5ca8aa4911a62e9e75d937f1 100644 (file)
@@ -3,7 +3,6 @@
 /* Interface to Sjoerd's portable C thread library */
 
 #include "Python.h"
-#include "pycore_ceval.h"         // _PyEval_MakePendingCalls()
 #include "pycore_dict.h"          // _PyDict_Pop()
 #include "pycore_interp.h"        // _PyInterpreterState.threads.count
 #include "pycore_moduleobject.h"  // _PyModule_GetState()
@@ -76,57 +75,10 @@ lock_dealloc(lockobject *self)
     Py_DECREF(tp);
 }
 
-/* Helper to acquire an interruptible lock with a timeout.  If the lock acquire
- * is interrupted, signal handlers are run, and if they raise an exception,
- * PY_LOCK_INTR is returned.  Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
- * are returned, depending on whether the lock can be acquired within the
- * timeout.
- */
-static PyLockStatus
+static inline PyLockStatus
 acquire_timed(PyThread_type_lock lock, _PyTime_t timeout)
 {
-    PyThreadState *tstate = _PyThreadState_GET();
-    _PyTime_t endtime = 0;
-    if (timeout > 0) {
-        endtime = _PyDeadline_Init(timeout);
-    }
-
-    PyLockStatus r;
-    do {
-        _PyTime_t microseconds;
-        microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING);
-
-        /* first a simple non-blocking try without releasing the GIL */
-        r = PyThread_acquire_lock_timed(lock, 0, 0);
-        if (r == PY_LOCK_FAILURE && microseconds != 0) {
-            Py_BEGIN_ALLOW_THREADS
-            r = PyThread_acquire_lock_timed(lock, microseconds, 1);
-            Py_END_ALLOW_THREADS
-        }
-
-        if (r == PY_LOCK_INTR) {
-            /* Run signal handlers if we were interrupted.  Propagate
-             * exceptions from signal handlers, such as KeyboardInterrupt, by
-             * passing up PY_LOCK_INTR.  */
-            if (_PyEval_MakePendingCalls(tstate) < 0) {
-                return PY_LOCK_INTR;
-            }
-
-            /* If we're using a timeout, recompute the timeout after processing
-             * signals, since those can take time.  */
-            if (timeout > 0) {
-                timeout = _PyDeadline_Get(endtime);
-
-                /* Check for negative values, since those mean block forever.
-                 */
-                if (timeout < 0) {
-                    r = PY_LOCK_FAILURE;
-                }
-            }
-        }
-    } while (r == PY_LOCK_INTR);  /* Retry if we were interrupted. */
-
-    return r;
+    return PyThread_acquire_lock_timed_with_retries(lock, timeout);
 }
 
 static int
index 34efe9d6d1bfa685f0982000199c0141eca31aba..be53cbfc39b4ddb0ce5405b2420274d34a25f901 100644 (file)
 #include "pycore_pybuffer.h"      // _PyBuffer_ReleaseInInterpreterAndRawFree()
 #include "pycore_interp.h"        // _PyInterpreterState_LookUpID()
 
+#ifdef MS_WINDOWS
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>        // SwitchToThread()
+#elif defined(HAVE_SCHED_H)
+#include <sched.h>          // sched_yield()
+#endif
+
 
 /*
 This module has the following process-global state:
@@ -234,15 +241,25 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
     return cls;
 }
 
-static void
+static int
 wait_for_lock(PyThread_type_lock mutex)
 {
-    Py_BEGIN_ALLOW_THREADS
-    // XXX Handle eintr, etc.
-    PyThread_acquire_lock(mutex, WAIT_LOCK);
-    Py_END_ALLOW_THREADS
-
+    PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
+    PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
+    if (res == PY_LOCK_INTR) {
+        /* KeyboardInterrupt, etc. */
+        assert(PyErr_Occurred());
+        return -1;
+    }
+    else if (res == PY_LOCK_FAILURE) {
+        assert(!PyErr_Occurred());
+        assert(timeout > 0);
+        PyErr_SetString(PyExc_TimeoutError, "timed out");
+        return -1;
+    }
+    assert(res == PY_LOCK_ACQUIRED);
     PyThread_release_lock(mutex);
+    return 0;
 }
 
 
@@ -489,6 +506,7 @@ _get_current_xibufferview_type(void)
 #define ERR_CHANNEL_MUTEX_INIT -7
 #define ERR_CHANNELS_MUTEX_INIT -8
 #define ERR_NO_NEXT_CHANNEL_ID -9
+#define ERR_CHANNEL_CLOSED_WAITING -10
 
 static int
 exceptions_init(PyObject *mod)
@@ -540,6 +558,10 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
         PyErr_Format(state->ChannelClosedError,
                      "channel %" PRId64 " is closed", cid);
     }
+    else if (err == ERR_CHANNEL_CLOSED_WAITING) {
+        PyErr_Format(state->ChannelClosedError,
+                     "channel %" PRId64 " has closed", cid);
+    }
     else if (err == ERR_CHANNEL_INTERP_CLOSED) {
         PyErr_Format(state->ChannelClosedError,
                      "channel %" PRId64 " is already closed", cid);
@@ -574,36 +596,145 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
 
 /* the channel queue */
 
+typedef uintptr_t _channelitem_id_t;
+
+typedef struct wait_info {
+    PyThread_type_lock mutex;
+    enum {
+        WAITING_NO_STATUS = 0,
+        WAITING_ACQUIRED = 1,
+        WAITING_RELEASING = 2,
+        WAITING_RELEASED = 3,
+    } status;
+    int received;
+    _channelitem_id_t itemid;
+} _waiting_t;
+
+static int
+_waiting_init(_waiting_t *waiting)
+{
+    PyThread_type_lock mutex = PyThread_allocate_lock();
+    if (mutex == NULL) {
+        PyErr_NoMemory();
+        return -1;
+    }
+
+    *waiting = (_waiting_t){
+        .mutex = mutex,
+        .status = WAITING_NO_STATUS,
+    };
+    return 0;
+}
+
+static void
+_waiting_clear(_waiting_t *waiting)
+{
+    assert(waiting->status != WAITING_ACQUIRED
+           && waiting->status != WAITING_RELEASING);
+    if (waiting->mutex != NULL) {
+        PyThread_free_lock(waiting->mutex);
+        waiting->mutex = NULL;
+    }
+}
+
+static _channelitem_id_t
+_waiting_get_itemid(_waiting_t *waiting)
+{
+    return waiting->itemid;
+}
+
+static void
+_waiting_acquire(_waiting_t *waiting)
+{
+    assert(waiting->status == WAITING_NO_STATUS);
+    PyThread_acquire_lock(waiting->mutex, NOWAIT_LOCK);
+    waiting->status = WAITING_ACQUIRED;
+}
+
+static void
+_waiting_release(_waiting_t *waiting, int received)
+{
+    assert(waiting->mutex != NULL);
+    assert(waiting->status == WAITING_ACQUIRED);
+    assert(!waiting->received);
+
+    waiting->status = WAITING_RELEASING;
+    PyThread_release_lock(waiting->mutex);
+    if (waiting->received != received) {
+        assert(received == 1);
+        waiting->received = received;
+    }
+    waiting->status = WAITING_RELEASED;
+}
+
+static void
+_waiting_finish_releasing(_waiting_t *waiting)
+{
+    while (waiting->status == WAITING_RELEASING) {
+#ifdef MS_WINDOWS
+        SwitchToThread();
+#elif defined(HAVE_SCHED_H)
+        sched_yield();
+#endif
+    }
+}
+
 struct _channelitem;
 
 typedef struct _channelitem {
     _PyCrossInterpreterData *data;
-    PyThread_type_lock recv_mutex;
+    _waiting_t *waiting;
     struct _channelitem *next;
 } _channelitem;
 
-static _channelitem *
-_channelitem_new(void)
+static inline _channelitem_id_t
+_channelitem_ID(_channelitem *item)
 {
-    _channelitem *item = GLOBAL_MALLOC(_channelitem);
-    if (item == NULL) {
-        PyErr_NoMemory();
-        return NULL;
+    return (_channelitem_id_t)item;
+}
+
+static void
+_channelitem_init(_channelitem *item,
+                  _PyCrossInterpreterData *data, _waiting_t *waiting)
+{
+    *item = (_channelitem){
+        .data = data,
+        .waiting = waiting,
+    };
+    if (waiting != NULL) {
+        waiting->itemid = _channelitem_ID(item);
     }
-    item->data = NULL;
-    item->next = NULL;
-    return item;
 }
 
 static void
 _channelitem_clear(_channelitem *item)
 {
+    item->next = NULL;
+
     if (item->data != NULL) {
         // It was allocated in _channel_send().
         (void)_release_xid_data(item->data, XID_IGNORE_EXC & XID_FREE);
         item->data = NULL;
     }
-    item->next = NULL;
+
+    if (item->waiting != NULL) {
+        if (item->waiting->status == WAITING_ACQUIRED) {
+            _waiting_release(item->waiting, 0);
+        }
+        item->waiting = NULL;
+    }
+}
+
+static _channelitem *
+_channelitem_new(_PyCrossInterpreterData *data, _waiting_t *waiting)
+{
+    _channelitem *item = GLOBAL_MALLOC(_channelitem);
+    if (item == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
+    _channelitem_init(item, data, waiting);
+    return item;
 }
 
 static void
@@ -623,14 +754,17 @@ _channelitem_free_all(_channelitem *item)
     }
 }
 
-static _PyCrossInterpreterData *
-_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
+static void
+_channelitem_popped(_channelitem *item,
+                    _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
 {
-    _PyCrossInterpreterData *data = item->data;
+    assert(item->waiting == NULL || item->waiting->status == WAITING_ACQUIRED);
+    *p_data = item->data;
+    *p_waiting = item->waiting;
+    // We clear them here, so they won't be released in _channelitem_clear().
     item->data = NULL;
-    *recv_mutex = item->recv_mutex;
+    item->waiting = NULL;
     _channelitem_free(item);
-    return data;
 }
 
 typedef struct _channelqueue {
@@ -670,15 +804,13 @@ _channelqueue_free(_channelqueue *queue)
 }
 
 static int
-_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
-                  PyThread_type_lock recv_mutex)
+_channelqueue_put(_channelqueue *queue,
+                  _PyCrossInterpreterData *data, _waiting_t *waiting)
 {
-    _channelitem *item = _channelitem_new();
+    _channelitem *item = _channelitem_new(data, waiting);
     if (item == NULL) {
         return -1;
     }
-    item->data = data;
-    item->recv_mutex = recv_mutex;
 
     queue->count += 1;
     if (queue->first == NULL) {
@@ -688,15 +820,21 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
         queue->last->next = item;
     }
     queue->last = item;
+
+    if (waiting != NULL) {
+        _waiting_acquire(waiting);
+    }
+
     return 0;
 }
 
-static _PyCrossInterpreterData *
-_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
+static int
+_channelqueue_get(_channelqueue *queue,
+                  _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
 {
     _channelitem *item = queue->first;
     if (item == NULL) {
-        return NULL;
+        return ERR_CHANNEL_EMPTY;
     }
     queue->first = item->next;
     if (queue->last == item) {
@@ -704,7 +842,73 @@ _channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
     }
     queue->count -= 1;
 
-    return _channelitem_popped(item, recv_mutex);
+    _channelitem_popped(item, p_data, p_waiting);
+    return 0;
+}
+
+static int
+_channelqueue_find(_channelqueue *queue, _channelitem_id_t itemid,
+                   _channelitem **p_item, _channelitem **p_prev)
+{
+    _channelitem *prev = NULL;
+    _channelitem *item = NULL;
+    if (queue->first != NULL) {
+        if (_channelitem_ID(queue->first) == itemid) {
+            item = queue->first;
+        }
+        else {
+            prev = queue->first;
+            while (prev->next != NULL) {
+                if (_channelitem_ID(prev->next) == itemid) {
+                    item = prev->next;
+                    break;
+                }
+                prev = prev->next;
+            }
+            if (item == NULL) {
+                prev = NULL;
+            }
+        }
+    }
+    if (p_item != NULL) {
+        *p_item = item;
+    }
+    if (p_prev != NULL) {
+        *p_prev = prev;
+    }
+    return (item != NULL);
+}
+
+static void
+_channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid,
+                     _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
+{
+    _channelitem *prev = NULL;
+    _channelitem *item = NULL;
+    int found = _channelqueue_find(queue, itemid, &item, &prev);
+    if (!found) {
+        return;
+    }
+
+    assert(item->waiting != NULL);
+    assert(!item->waiting->received);
+    if (prev == NULL) {
+        assert(queue->first == item);
+        queue->first = item->next;
+    }
+    else {
+        assert(queue->first != item);
+        assert(prev->next == item);
+        prev->next = item->next;
+    }
+    item->next = NULL;
+
+    if (queue->last == item) {
+        queue->last = prev;
+    }
+    queue->count -= 1;
+
+    _channelitem_popped(item, p_data, p_waiting);
 }
 
 static void
@@ -1021,7 +1225,7 @@ _channel_free(_PyChannelState *chan)
 
 static int
 _channel_add(_PyChannelState *chan, int64_t interp,
-             _PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
+             _PyCrossInterpreterData *data, _waiting_t *waiting)
 {
     int res = -1;
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1035,9 +1239,10 @@ _channel_add(_PyChannelState *chan, int64_t interp,
         goto done;
     }
 
-    if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
+    if (_channelqueue_put(chan->queue, data, waiting) != 0) {
         goto done;
     }
+    // Any errors past this point must cause a _waiting_release() call.
 
     res = 0;
 done:
@@ -1047,7 +1252,7 @@ done:
 
 static int
 _channel_next(_PyChannelState *chan, int64_t interp,
-              _PyCrossInterpreterData **res)
+              _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
 {
     int err = 0;
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1061,16 +1266,12 @@ _channel_next(_PyChannelState *chan, int64_t interp,
         goto done;
     }
 
-    PyThread_type_lock recv_mutex = NULL;
-    _PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
-    if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
+    int empty = _channelqueue_get(chan->queue, p_data, p_waiting);
+    assert(empty == 0 || empty == ERR_CHANNEL_EMPTY);
+    assert(!PyErr_Occurred());
+    if (empty && chan->closing != NULL) {
         chan->open = 0;
     }
-    *res = data;
-
-    if (recv_mutex != NULL) {
-        PyThread_release_lock(recv_mutex);
-    }
 
 done:
     PyThread_release_lock(chan->mutex);
@@ -1080,6 +1281,26 @@ done:
     return err;
 }
 
+static void
+_channel_remove(_PyChannelState *chan, _channelitem_id_t itemid)
+{
+    _PyCrossInterpreterData *data = NULL;
+    _waiting_t *waiting = NULL;
+
+    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+    _channelqueue_remove(chan->queue, itemid, &data, &waiting);
+    PyThread_release_lock(chan->mutex);
+
+    (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
+    if (waiting != NULL) {
+        _waiting_release(waiting, 0);
+    }
+
+    if (chan->queue->count == 0) {
+        _channel_finish_closing(chan);
+    }
+}
+
 static int
 _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
 {
@@ -1592,7 +1813,7 @@ _channel_destroy(_channels *channels, int64_t id)
 
 static int
 _channel_send(_channels *channels, int64_t id, PyObject *obj,
-              PyThread_type_lock recv_mutex)
+              _waiting_t *waiting)
 {
     PyInterpreterState *interp = _get_current_interp();
     if (interp == NULL) {
@@ -1627,8 +1848,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
     }
 
     // Add the data to the channel.
-    int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
-                           recv_mutex);
+    int res = _channel_add(chan, PyInterpreterState_GetID(interp),
+                           data, waiting);
     PyThread_release_lock(mutex);
     if (res != 0) {
         // We may chain an exception here:
@@ -1640,31 +1861,74 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
     return 0;
 }
 
+static void
+_channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
+{
+    // Look up the channel.
+    PyThread_type_lock mutex = NULL;
+    _PyChannelState *chan = NULL;
+    int err = _channels_lookup(channels, cid, &mutex, &chan);
+    if (err != 0) {
+        // The channel was already closed, etc.
+        assert(waiting->status == WAITING_RELEASED);
+        return;  // Ignore the error.
+    }
+    assert(chan != NULL);
+    // Past this point we are responsible for releasing the mutex.
+
+    _channelitem_id_t itemid = _waiting_get_itemid(waiting);
+    _channel_remove(chan, itemid);
+
+    PyThread_release_lock(mutex);
+}
+
 static int
 _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
 {
-    PyThread_type_lock mutex = PyThread_allocate_lock();
-    if (mutex == NULL) {
-        PyErr_NoMemory();
+    // We use a stack variable here, so we must ensure that &waiting
+    // is not held by any channel item at the point this function exits.
+    _waiting_t waiting;
+    if (_waiting_init(&waiting) < 0) {
+        assert(PyErr_Occurred());
         return -1;
     }
-    PyThread_acquire_lock(mutex, NOWAIT_LOCK);
 
     /* Queue up the object. */
-    int res = _channel_send(channels, cid, obj, mutex);
+    int res = _channel_send(channels, cid, obj, &waiting);
     if (res < 0) {
-        PyThread_release_lock(mutex);
+        assert(waiting.status == WAITING_NO_STATUS);
         goto finally;
     }
 
     /* Wait until the object is received. */
-    wait_for_lock(mutex);
+    if (wait_for_lock(waiting.mutex) < 0) {
+        assert(PyErr_Occurred());
+        _waiting_finish_releasing(&waiting);
+        /* The send() call is failing now, so make sure the item
+           won't be received. */
+        _channel_clear_sent(channels, cid, &waiting);
+        assert(waiting.status == WAITING_RELEASED);
+        if (!waiting.received) {
+            res = -1;
+            goto finally;
+        }
+        // XXX Emit a warning if not a TimeoutError?
+        PyErr_Clear();
+    }
+    else {
+        _waiting_finish_releasing(&waiting);
+        assert(waiting.status == WAITING_RELEASED);
+        if (!waiting.received) {
+            res = ERR_CHANNEL_CLOSED_WAITING;
+            goto finally;
+        }
+    }
 
     /* success! */
     res = 0;
 
 finally:
-    // XXX Delete the lock.
+    _waiting_clear(&waiting);
     return res;
 }
 
@@ -1695,7 +1959,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
 
     // Pop off the next item from the channel.
     _PyCrossInterpreterData *data = NULL;
-    err = _channel_next(chan, PyInterpreterState_GetID(interp), &data);
+    _waiting_t *waiting = NULL;
+    err = _channel_next(chan, PyInterpreterState_GetID(interp), &data,
+                        &waiting);
     PyThread_release_lock(mutex);
     if (err != 0) {
         return err;
@@ -1711,6 +1977,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
         assert(PyErr_Occurred());
         // It was allocated in _channel_send(), so we free it.
         (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
+        if (waiting != NULL) {
+            _waiting_release(waiting, 0);
+        }
         return -1;
     }
     // It was allocated in _channel_send(), so we free it.
@@ -1719,9 +1988,17 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
         // The source interpreter has been destroyed already.
         assert(PyErr_Occurred());
         Py_DECREF(obj);
+        if (waiting != NULL) {
+            _waiting_release(waiting, 0);
+        }
         return -1;
     }
 
+    // Notify the sender.
+    if (waiting != NULL) {
+        _waiting_release(waiting, 1);
+    }
+
     *res = obj;
     return 0;
 }
index bf207cecb90505dd0b85c4dadde2f4c15d26d5f2..7185dd43d965b98fd4ce16d21056e24b48cb3ca9 100644 (file)
@@ -6,6 +6,7 @@
    Stuff shared by all thread_*.h files is collected here. */
 
 #include "Python.h"
+#include "pycore_ceval.h"         // _PyEval_MakePendingCalls()
 #include "pycore_pystate.h"       // _PyInterpreterState_GET()
 #include "pycore_structseq.h"     // _PyStructSequence_FiniBuiltin()
 #include "pycore_pythread.h"      // _POSIX_THREADS
@@ -92,6 +93,55 @@ PyThread_set_stacksize(size_t size)
 }
 
 
+PyLockStatus
+PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock,
+                                         PY_TIMEOUT_T timeout)
+{
+    PyThreadState *tstate = _PyThreadState_GET();
+    _PyTime_t endtime = 0;
+    if (timeout > 0) {
+        endtime = _PyDeadline_Init(timeout);
+    }
+
+    PyLockStatus r;
+    do {
+        _PyTime_t microseconds;
+        microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING);
+
+        /* first a simple non-blocking try without releasing the GIL */
+        r = PyThread_acquire_lock_timed(lock, 0, 0);
+        if (r == PY_LOCK_FAILURE && microseconds != 0) {
+            Py_BEGIN_ALLOW_THREADS
+            r = PyThread_acquire_lock_timed(lock, microseconds, 1);
+            Py_END_ALLOW_THREADS
+        }
+
+        if (r == PY_LOCK_INTR) {
+            /* Run signal handlers if we were interrupted.  Propagate
+             * exceptions from signal handlers, such as KeyboardInterrupt, by
+             * passing up PY_LOCK_INTR.  */
+            if (_PyEval_MakePendingCalls(tstate) < 0) {
+                return PY_LOCK_INTR;
+            }
+
+            /* If we're using a timeout, recompute the timeout after processing
+             * signals, since those can take time.  */
+            if (timeout > 0) {
+                timeout = _PyDeadline_Get(endtime);
+
+                /* Check for negative values, since those mean block forever.
+                 */
+                if (timeout < 0) {
+                    r = PY_LOCK_FAILURE;
+                }
+            }
+        }
+    } while (r == PY_LOCK_INTR);  /* Retry if we were interrupted. */
+
+    return r;
+}
+
+
 /* Thread Specific Storage (TSS) API
 
    Cross-platform components of TSS API implementation.