]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-84570: Implement Waiting in SendChannel.send() (gh-110565)
authorEric Snow <ericsnowcurrently@gmail.com>
Tue, 10 Oct 2023 09:35:14 +0000 (03:35 -0600)
committerGitHub <noreply@github.com>
Tue, 10 Oct 2023 09:35:14 +0000 (09:35 +0000)
We had been faking it (poorly).

We will add timeouts separately.

Lib/test/support/interpreters.py
Lib/test/test__xxinterpchannels.py
Lib/test/test_interpreters.py
Modules/_xxinterpchannelsmodule.c

index d61724ca86b66c86ae41cdd2f1dc3d81842e2d05..9ba6862a9ee01a1777ce07a806c533e82d78f64a 100644 (file)
@@ -208,11 +208,7 @@ class SendChannel(_ChannelEnd):
 
         This blocks until the object is received.
         """
-        _channels.send(self._id, obj)
-        # XXX We are missing a low-level channel_send_wait().
-        # See bpo-32604 and gh-19829.
-        # Until that shows up we fake it:
-        time.sleep(2)
+        _channels.send(self._id, obj, blocking=True)
 
     def send_nowait(self, obj):
         """Send the object to the channel's receiving end.
@@ -223,14 +219,14 @@ class SendChannel(_ChannelEnd):
         # XXX Note that at the moment channel_send() only ever returns
         # None.  This should be fixed when channel_send_wait() is added.
         # See bpo-32604 and gh-19829.
-        return _channels.send(self._id, obj)
+        return _channels.send(self._id, obj, blocking=False)
 
     def send_buffer(self, obj):
         """Send the object's buffer to the channel's receiving end.
 
         This blocks until the object is received.
         """
-        _channels.send_buffer(self._id, obj)
+        _channels.send_buffer(self._id, obj, blocking=True)
 
     def send_buffer_nowait(self, obj):
         """Send the object's buffer to the channel's receiving end.
@@ -238,7 +234,7 @@ class SendChannel(_ChannelEnd):
         If the object is immediately received then return True
         (else False).  Otherwise this is the same as send().
         """
-        return _channels.send_buffer(self._id, obj)
+        return _channels.send_buffer(self._id, obj, blocking=False)
 
     def close(self):
         _channels.close(self._id, send=True)
index cb69f73c4348d49218ecd3f2f584f789a1dfd03e..ff01a339c0008e2313f87712f356ee3ee1eafd36 100644 (file)
@@ -21,6 +21,13 @@ channels = import_helper.import_module('_xxinterpchannels')
 ##################################
 # helpers
 
+def recv_wait(cid):
+    while True:
+        try:
+            return channels.recv(cid)
+        except channels.ChannelEmptyError:
+            time.sleep(0.1)
+
 #@contextmanager
 #def run_threaded(id, source, **shared):
 #    def run():
@@ -189,7 +196,7 @@ def run_action(cid, action, end, state, *, hideclosed=True):
 def _run_action(cid, action, end, state):
     if action == 'use':
         if end == 'send':
-            channels.send(cid, b'spam')
+            channels.send(cid, b'spam', blocking=False)
             return state.incr()
         elif end == 'recv':
             if not state.pending:
@@ -332,7 +339,7 @@ class ChannelIDTests(TestBase):
         chan = channels.create()
 
         obj = channels.create()
-        channels.send(chan, obj)
+        channels.send(chan, obj, blocking=False)
         got = channels.recv(chan)
 
         self.assertEqual(got, obj)
@@ -390,7 +397,7 @@ class ChannelTests(TestBase):
         """Test basic listing channel interpreters."""
         interp0 = interpreters.get_main()
         cid = channels.create()
-        channels.send(cid, "send")
+        channels.send(cid, "send", blocking=False)
         # Test for a channel that has one end associated to an interpreter.
         send_interps = channels.list_interpreters(cid, send=True)
         recv_interps = channels.list_interpreters(cid, send=False)
@@ -416,10 +423,10 @@ class ChannelTests(TestBase):
         interp3 = interpreters.create()
         cid = channels.create()
 
-        channels.send(cid, "send")
+        channels.send(cid, "send", blocking=False)
         _run_output(interp1, dedent(f"""
             import _xxinterpchannels as _channels
-            _channels.send({cid}, "send")
+            _channels.send({cid}, "send", blocking=False)
             """))
         _run_output(interp2, dedent(f"""
             import _xxinterpchannels as _channels
@@ -439,7 +446,7 @@ class ChannelTests(TestBase):
         interp0 = interpreters.get_main()
         interp1 = interpreters.create()
         cid = channels.create()
-        channels.send(cid, "send")
+        channels.send(cid, "send", blocking=False)
         _run_output(interp1, dedent(f"""
             import _xxinterpchannels as _channels
             obj = _channels.recv({cid})
@@ -465,12 +472,12 @@ class ChannelTests(TestBase):
         interp1 = interpreters.create()
         interp2 = interpreters.create()
         cid = channels.create()
-        channels.send(cid, "data")
+        channels.send(cid, "data", blocking=False)
         _run_output(interp1, dedent(f"""
             import _xxinterpchannels as _channels
             obj = _channels.recv({cid})
             """))
-        channels.send(cid, "data")
+        channels.send(cid, "data", blocking=False)
         _run_output(interp2, dedent(f"""
             import _xxinterpchannels as _channels
             obj = _channels.recv({cid})
@@ -506,7 +513,7 @@ class ChannelTests(TestBase):
         interp1 = interpreters.create()
         cid = channels.create()
         # Put something in the channel so that it's not empty.
-        channels.send(cid, "send")
+        channels.send(cid, "send", blocking=False)
 
         # Check initial state.
         send_interps = channels.list_interpreters(cid, send=True)
@@ -528,7 +535,7 @@ class ChannelTests(TestBase):
         interp1 = interpreters.create()
         cid = channels.create()
         # Put something in the channel so that it's not empty.
-        channels.send(cid, "send")
+        channels.send(cid, "send", blocking=False)
 
         # Check initial state.
         send_interps = channels.list_interpreters(cid, send=True)
@@ -562,7 +569,7 @@ class ChannelTests(TestBase):
     def test_send_recv_main(self):
         cid = channels.create()
         orig = b'spam'
-        channels.send(cid, orig)
+        channels.send(cid, orig, blocking=False)
         obj = channels.recv(cid)
 
         self.assertEqual(obj, orig)
@@ -574,7 +581,7 @@ class ChannelTests(TestBase):
             import _xxinterpchannels as _channels
             cid = _channels.create()
             orig = b'spam'
-            _channels.send(cid, orig)
+            _channels.send(cid, orig, blocking=False)
             obj = _channels.recv(cid)
             assert obj is not orig
             assert obj == orig
@@ -585,7 +592,7 @@ class ChannelTests(TestBase):
         id1 = interpreters.create()
         out = _run_output(id1, dedent(f"""
             import _xxinterpchannels as _channels
-            _channels.send({cid}, b'spam')
+            _channels.send({cid}, b'spam', blocking=False)
             """))
         obj = channels.recv(cid)
 
@@ -595,19 +602,14 @@ class ChannelTests(TestBase):
         cid = channels.create()
 
         def f():
-            while True:
-                try:
-                    obj = channels.recv(cid)
-                    break
-                except channels.ChannelEmptyError:
-                    time.sleep(0.1)
+            obj = recv_wait(cid)
             channels.send(cid, obj)
         t = threading.Thread(target=f)
         t.start()
 
         channels.send(cid, b'spam')
+        obj = recv_wait(cid)
         t.join()
-        obj = channels.recv(cid)
 
         self.assertEqual(obj, b'spam')
 
@@ -634,8 +636,8 @@ class ChannelTests(TestBase):
         t.start()
 
         channels.send(cid, b'spam')
+        obj = recv_wait(cid)
         t.join()
-        obj = channels.recv(cid)
 
         self.assertEqual(obj, b'eggs')
 
@@ -656,10 +658,10 @@ class ChannelTests(TestBase):
         default = object()
         cid = channels.create()
         obj1 = channels.recv(cid, default)
-        channels.send(cid, None)
-        channels.send(cid, 1)
-        channels.send(cid, b'spam')
-        channels.send(cid, b'eggs')
+        channels.send(cid, None, blocking=False)
+        channels.send(cid, 1, blocking=False)
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'eggs', blocking=False)
         obj2 = channels.recv(cid, default)
         obj3 = channels.recv(cid, default)
         obj4 = channels.recv(cid)
@@ -679,7 +681,7 @@ class ChannelTests(TestBase):
             interp = interpreters.create()
             interpreters.run_string(interp, dedent(f"""
                 import _xxinterpchannels as _channels
-                _channels.send({cid1}, b'spam')
+                _channels.send({cid1}, b'spam', blocking=False)
                 """))
             interpreters.destroy(interp)
 
@@ -692,9 +694,9 @@ class ChannelTests(TestBase):
             interp = interpreters.create()
             interpreters.run_string(interp, dedent(f"""
                 import _xxinterpchannels as _channels
-                _channels.send({cid2}, b'spam')
+                _channels.send({cid2}, b'spam', blocking=False)
                 """))
-            channels.send(cid2, b'eggs')
+            channels.send(cid2, b'eggs', blocking=False)
             interpreters.destroy(interp)
 
             channels.recv(cid2)
@@ -706,7 +708,7 @@ class ChannelTests(TestBase):
     def test_send_buffer(self):
         buf = bytearray(b'spamspamspam')
         cid = channels.create()
-        channels.send_buffer(cid, buf)
+        channels.send_buffer(cid, buf, blocking=False)
         obj = channels.recv(cid)
 
         self.assertIsNot(obj, buf)
@@ -728,7 +730,7 @@ class ChannelTests(TestBase):
         ]
         for obj in objects:
             with self.subTest(obj):
-                channels.send(cid, obj)
+                channels.send(cid, obj, blocking=False)
                 got = channels.recv(cid)
 
                 self.assertEqual(got, obj)
@@ -744,7 +746,7 @@ class ChannelTests(TestBase):
         out = _run_output(interp, dedent("""
             import _xxinterpchannels as _channels
             print(cid.end)
-            _channels.send(cid, b'spam')
+            _channels.send(cid, b'spam', blocking=False)
             """),
             dict(cid=cid.send))
         obj = channels.recv(cid)
@@ -764,7 +766,7 @@ class ChannelTests(TestBase):
         out = _run_output(interp, dedent("""
             import _xxinterpchannels as _channels
             print(chan.id.end)
-            _channels.send(chan.id, b'spam')
+            _channels.send(chan.id, b'spam', blocking=False)
             """),
             dict(chan=cid.send))
         obj = channels.recv(cid)
@@ -776,7 +778,7 @@ class ChannelTests(TestBase):
 
     def test_close_single_user(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.close(cid)
 
@@ -791,7 +793,7 @@ class ChannelTests(TestBase):
         id2 = interpreters.create()
         interpreters.run_string(id1, dedent(f"""
             import _xxinterpchannels as _channels
-            _channels.send({cid}, b'spam')
+            _channels.send({cid}, b'spam', blocking=False)
             """))
         interpreters.run_string(id2, dedent(f"""
             import _xxinterpchannels as _channels
@@ -811,7 +813,7 @@ class ChannelTests(TestBase):
 
     def test_close_multiple_times(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.close(cid)
 
@@ -828,7 +830,7 @@ class ChannelTests(TestBase):
         for send, recv in tests:
             with self.subTest((send, recv)):
                 cid = channels.create()
-                channels.send(cid, b'spam')
+                channels.send(cid, b'spam', blocking=False)
                 channels.recv(cid)
                 channels.close(cid, send=send, recv=recv)
 
@@ -839,31 +841,31 @@ class ChannelTests(TestBase):
 
     def test_close_defaults_with_unused_items(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
 
         with self.assertRaises(channels.ChannelNotEmptyError):
             channels.close(cid)
         channels.recv(cid)
-        channels.send(cid, b'eggs')
+        channels.send(cid, b'eggs', blocking=False)
 
     def test_close_recv_with_unused_items_unforced(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
 
         with self.assertRaises(channels.ChannelNotEmptyError):
             channels.close(cid, recv=True)
         channels.recv(cid)
-        channels.send(cid, b'eggs')
+        channels.send(cid, b'eggs', blocking=False)
         channels.recv(cid)
         channels.recv(cid)
         channels.close(cid, recv=True)
 
     def test_close_send_with_unused_items_unforced(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
         channels.close(cid, send=True)
 
         with self.assertRaises(channels.ChannelClosedError):
@@ -875,21 +877,21 @@ class ChannelTests(TestBase):
 
     def test_close_both_with_unused_items_unforced(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
 
         with self.assertRaises(channels.ChannelNotEmptyError):
             channels.close(cid, recv=True, send=True)
         channels.recv(cid)
-        channels.send(cid, b'eggs')
+        channels.send(cid, b'eggs', blocking=False)
         channels.recv(cid)
         channels.recv(cid)
         channels.close(cid, recv=True)
 
     def test_close_recv_with_unused_items_forced(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
         channels.close(cid, recv=True, force=True)
 
         with self.assertRaises(channels.ChannelClosedError):
@@ -899,8 +901,8 @@ class ChannelTests(TestBase):
 
     def test_close_send_with_unused_items_forced(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
         channels.close(cid, send=True, force=True)
 
         with self.assertRaises(channels.ChannelClosedError):
@@ -910,8 +912,8 @@ class ChannelTests(TestBase):
 
     def test_close_both_with_unused_items_forced(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
         channels.close(cid, send=True, recv=True, force=True)
 
         with self.assertRaises(channels.ChannelClosedError):
@@ -930,7 +932,7 @@ class ChannelTests(TestBase):
 
     def test_close_by_unassociated_interp(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         interp = interpreters.create()
         interpreters.run_string(interp, dedent(f"""
             import _xxinterpchannels as _channels
@@ -943,9 +945,9 @@ class ChannelTests(TestBase):
 
     def test_close_used_multiple_times_by_single_user(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'spam')
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.close(cid, force=True)
 
@@ -1017,7 +1019,7 @@ class ChannelReleaseTests(TestBase):
 
     def test_single_user(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.release(cid, send=True, recv=True)
 
@@ -1032,7 +1034,7 @@ class ChannelReleaseTests(TestBase):
         id2 = interpreters.create()
         interpreters.run_string(id1, dedent(f"""
             import _xxinterpchannels as _channels
-            _channels.send({cid}, b'spam')
+            _channels.send({cid}, b'spam', blocking=False)
             """))
         out = _run_output(id2, dedent(f"""
             import _xxinterpchannels as _channels
@@ -1048,7 +1050,7 @@ class ChannelReleaseTests(TestBase):
 
     def test_no_kwargs(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.release(cid)
 
@@ -1059,7 +1061,7 @@ class ChannelReleaseTests(TestBase):
 
     def test_multiple_times(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.release(cid, send=True, recv=True)
 
@@ -1068,8 +1070,8 @@ class ChannelReleaseTests(TestBase):
 
     def test_with_unused_items(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'ham')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'ham', blocking=False)
         channels.release(cid, send=True, recv=True)
 
         with self.assertRaises(channels.ChannelClosedError):
@@ -1086,7 +1088,7 @@ class ChannelReleaseTests(TestBase):
 
     def test_by_unassociated_interp(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         interp = interpreters.create()
         interpreters.run_string(interp, dedent(f"""
             import _xxinterpchannels as _channels
@@ -1105,7 +1107,7 @@ class ChannelReleaseTests(TestBase):
         interp = interpreters.create()
         interpreters.run_string(interp, dedent(f"""
             import _xxinterpchannels as _channels
-            obj = _channels.send({cid}, b'spam')
+            obj = _channels.send({cid}, b'spam', blocking=False)
             _channels.release({cid})
             """))
 
@@ -1115,9 +1117,9 @@ class ChannelReleaseTests(TestBase):
     def test_partially(self):
         # XXX Is partial close too weird/confusing?
         cid = channels.create()
-        channels.send(cid, None)
+        channels.send(cid, None, blocking=False)
         channels.recv(cid)
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
         channels.release(cid, send=True)
         obj = channels.recv(cid)
 
@@ -1125,9 +1127,9 @@ class ChannelReleaseTests(TestBase):
 
     def test_used_multiple_times_by_single_user(self):
         cid = channels.create()
-        channels.send(cid, b'spam')
-        channels.send(cid, b'spam')
-        channels.send(cid, b'spam')
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'spam', blocking=False)
+        channels.send(cid, b'spam', blocking=False)
         channels.recv(cid)
         channels.release(cid, send=True, recv=True)
 
@@ -1212,7 +1214,7 @@ class ChannelCloseFixture(namedtuple('ChannelCloseFixture',
                 cid = _xxsubchannels.create()
                 # We purposefully send back an int to avoid tying the
                 # channel to the other interpreter.
-                _xxsubchannels.send({ch}, int(cid))
+                _xxsubchannels.send({ch}, int(cid), blocking=False)
                 del _xxsubinterpreters
                 """)
             self._cid = channels.recv(ch)
@@ -1442,8 +1444,8 @@ class ExhaustiveChannelTests(TestBase):
                     {repr(fix.state)},
                     hideclosed={hideclosed},
                     )
-                channels.send({_cid}, result.pending.to_bytes(1, 'little'))
-                channels.send({_cid}, b'X' if result.closed else b'')
+                channels.send({_cid}, result.pending.to_bytes(1, 'little'), blocking=False)
+                channels.send({_cid}, b'X' if result.closed else b'', blocking=False)
                 """)
             result = ChannelState(
                 pending=int.from_bytes(channels.recv(_cid), 'little'),
@@ -1490,7 +1492,7 @@ class ExhaustiveChannelTests(TestBase):
                 """)
             run_interp(interp.id, """
                 with helpers.expect_channel_closed():
-                    channels.send(cid, b'spam')
+                    channels.send(cid, b'spam', blocking=False)
                 """)
             run_interp(interp.id, """
                 with helpers.expect_channel_closed():
index fe7b14de459a7d36078dc0cf666d6c043e33fc73..0910b51bfe5dbd5a4a7ee3aaa6cfed732399d024 100644 (file)
@@ -964,8 +964,8 @@ class TestSendRecv(TestBase):
 
         orig = b'spam'
         s.send(orig)
-        t.join()
         obj = r.recv()
+        t.join()
 
         self.assertEqual(obj, orig)
         self.assertIsNot(obj, orig)
index a1531c5c3db34dc58ef44268ff8d925504c8e44c..bc8cd0e2cff4c13085f18a945fdab0bfd0238e5b 100644 (file)
@@ -234,6 +234,17 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
     return cls;
 }
 
+static void
+wait_for_lock(PyThread_type_lock mutex)
+{
+    Py_BEGIN_ALLOW_THREADS
+    // XXX Handle eintr, etc.
+    PyThread_acquire_lock(mutex, WAIT_LOCK);
+    Py_END_ALLOW_THREADS
+
+    PyThread_release_lock(mutex);
+}
+
 
 /* Cross-interpreter Buffer Views *******************************************/
 
@@ -567,6 +578,7 @@ struct _channelitem;
 
 typedef struct _channelitem {
     _PyCrossInterpreterData *data;
+    PyThread_type_lock recv_mutex;
     struct _channelitem *next;
 } _channelitem;
 
@@ -612,10 +624,11 @@ _channelitem_free_all(_channelitem *item)
 }
 
 static _PyCrossInterpreterData *
-_channelitem_popped(_channelitem *item)
+_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
 {
     _PyCrossInterpreterData *data = item->data;
     item->data = NULL;
+    *recv_mutex = item->recv_mutex;
     _channelitem_free(item);
     return data;
 }
@@ -657,13 +670,15 @@ _channelqueue_free(_channelqueue *queue)
 }
 
 static int
-_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
+_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
+                  PyThread_type_lock recv_mutex)
 {
     _channelitem *item = _channelitem_new();
     if (item == NULL) {
         return -1;
     }
     item->data = data;
+    item->recv_mutex = recv_mutex;
 
     queue->count += 1;
     if (queue->first == NULL) {
@@ -677,7 +692,7 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
 }
 
 static _PyCrossInterpreterData *
-_channelqueue_get(_channelqueue *queue)
+_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
 {
     _channelitem *item = queue->first;
     if (item == NULL) {
@@ -689,7 +704,7 @@ _channelqueue_get(_channelqueue *queue)
     }
     queue->count -= 1;
 
-    return _channelitem_popped(item);
+    return _channelitem_popped(item, recv_mutex);
 }
 
 static void
@@ -1006,7 +1021,7 @@ _channel_free(_PyChannelState *chan)
 
 static int
 _channel_add(_PyChannelState *chan, int64_t interp,
-             _PyCrossInterpreterData *data)
+             _PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
 {
     int res = -1;
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1020,7 +1035,7 @@ _channel_add(_PyChannelState *chan, int64_t interp,
         goto done;
     }
 
-    if (_channelqueue_put(chan->queue, data) != 0) {
+    if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
         goto done;
     }
 
@@ -1046,12 +1061,17 @@ _channel_next(_PyChannelState *chan, int64_t interp,
         goto done;
     }
 
-    _PyCrossInterpreterData *data = _channelqueue_get(chan->queue);
+    PyThread_type_lock recv_mutex = NULL;
+    _PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
     if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
         chan->open = 0;
     }
     *res = data;
 
+    if (recv_mutex != NULL) {
+        PyThread_release_lock(recv_mutex);
+    }
+
 done:
     PyThread_release_lock(chan->mutex);
     if (chan->queue->count == 0) {
@@ -1571,7 +1591,8 @@ _channel_destroy(_channels *channels, int64_t id)
 }
 
 static int
-_channel_send(_channels *channels, int64_t id, PyObject *obj)
+_channel_send(_channels *channels, int64_t id, PyObject *obj,
+              PyThread_type_lock recv_mutex)
 {
     PyInterpreterState *interp = _get_current_interp();
     if (interp == NULL) {
@@ -1606,7 +1627,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
     }
 
     // Add the data to the channel.
-    int res = _channel_add(chan, PyInterpreterState_GetID(interp), data);
+    int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
+                           recv_mutex);
     PyThread_release_lock(mutex);
     if (res != 0) {
         // We may chain an exception here:
@@ -2489,42 +2511,70 @@ receive end.");
 static PyObject *
 channel_send(PyObject *self, PyObject *args, PyObject *kwds)
 {
-    static char *kwlist[] = {"cid", "obj", NULL};
+    // XXX Add a timeout arg.
+    static char *kwlist[] = {"cid", "obj", "blocking", NULL};
     int64_t cid;
     struct channel_id_converter_data cid_data = {
         .module = self,
     };
     PyObject *obj;
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
-                                     channel_id_converter, &cid_data, &obj)) {
+    int blocking = 1;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
+                                     channel_id_converter, &cid_data, &obj,
+                                     &blocking)) {
         return NULL;
     }
     cid = cid_data.cid;
 
-    int err = _channel_send(&_globals.channels, cid, obj);
-    if (handle_channel_error(err, self, cid)) {
-        return NULL;
+    if (blocking) {
+        PyThread_type_lock mutex = PyThread_allocate_lock();
+        if (mutex == NULL) {
+            PyErr_NoMemory();
+            return NULL;
+        }
+        PyThread_acquire_lock(mutex, WAIT_LOCK);
+
+        /* Queue up the object. */
+        int err = _channel_send(&_globals.channels, cid, obj, mutex);
+        if (handle_channel_error(err, self, cid)) {
+            PyThread_release_lock(mutex);
+            return NULL;
+        }
+
+        /* Wait until the object is received. */
+        wait_for_lock(mutex);
+    }
+    else {
+        /* Queue up the object. */
+        int err = _channel_send(&_globals.channels, cid, obj, NULL);
+        if (handle_channel_error(err, self, cid)) {
+            return NULL;
+        }
     }
+
     Py_RETURN_NONE;
 }
 
 PyDoc_STRVAR(channel_send_doc,
-"channel_send(cid, obj)\n\
+"channel_send(cid, obj, blocking=True)\n\
 \n\
-Add the object's data to the channel's queue.");
+Add the object's data to the channel's queue.\n\
+By default this waits for the object to be received.");
 
 static PyObject *
 channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
 {
-    static char *kwlist[] = {"cid", "obj", NULL};
+    static char *kwlist[] = {"cid", "obj", "blocking", NULL};
     int64_t cid;
     struct channel_id_converter_data cid_data = {
         .module = self,
     };
     PyObject *obj;
+    int blocking = 1;
     if (!PyArg_ParseTupleAndKeywords(args, kwds,
-                                     "O&O:channel_send_buffer", kwlist,
-                                     channel_id_converter, &cid_data, &obj)) {
+                                     "O&O|$p:channel_send_buffer", kwlist,
+                                     channel_id_converter, &cid_data, &obj,
+                                     &blocking)) {
         return NULL;
     }
     cid = cid_data.cid;
@@ -2534,18 +2584,43 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
         return NULL;
     }
 
-    int err = _channel_send(&_globals.channels, cid, tempobj);
-    Py_DECREF(tempobj);
-    if (handle_channel_error(err, self, cid)) {
-        return NULL;
+    if (blocking) {
+        PyThread_type_lock mutex = PyThread_allocate_lock();
+        if (mutex == NULL) {
+            Py_DECREF(tempobj);
+            PyErr_NoMemory();
+            return NULL;
+        }
+        PyThread_acquire_lock(mutex, WAIT_LOCK);
+
+        /* Queue up the buffer. */
+        int err = _channel_send(&_globals.channels, cid, tempobj, mutex);
+        Py_DECREF(tempobj);
+        if (handle_channel_error(err, self, cid)) {
+            PyThread_acquire_lock(mutex, WAIT_LOCK);
+            return NULL;
+        }
+
+        /* Wait until the buffer is received. */
+        wait_for_lock(mutex);
     }
+    else {
+        /* Queue up the buffer. */
+        int err = _channel_send(&_globals.channels, cid, tempobj, NULL);
+        Py_DECREF(tempobj);
+        if (handle_channel_error(err, self, cid)) {
+            return NULL;
+        }
+    }
+
     Py_RETURN_NONE;
 }
 
 PyDoc_STRVAR(channel_send_buffer_doc,
-"channel_send_buffer(cid, obj)\n\
+"channel_send_buffer(cid, obj, blocking=True)\n\
 \n\
-Add the object's buffer to the channel's queue.");
+Add the object's buffer to the channel's queue.\n\
+By default this waits for the object to be received.");
 
 static PyObject *
 channel_recv(PyObject *self, PyObject *args, PyObject *kwds)