]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-128340: add thread safe handle for `loop.call_soon_threadsafe` (#128369)
authorKumar Aditya <kumaraditya@python.org>
Mon, 6 Jan 2025 12:35:11 +0000 (18:05 +0530)
committerGitHub <noreply@github.com>
Mon, 6 Jan 2025 12:35:11 +0000 (18:05 +0530)
Adds `_ThreadSafeHandle` to be used for callbacks scheduled with `loop.call_soon_threadsafe`.

Lib/asyncio/base_events.py
Lib/asyncio/events.py
Lib/test/test_asyncio/test_events.py
Misc/NEWS.d/next/Library/2025-01-05-11-46-14.gh-issue-128340.gKI0uU.rst [new file with mode: 0644]

index 5dbe4b28d236d3c2a4cc2b6697b32ecfa2c40497..9e6f6e3ee7e3ecf7a05485066c141e558bb8571b 100644 (file)
@@ -873,7 +873,10 @@ class BaseEventLoop(events.AbstractEventLoop):
         self._check_closed()
         if self._debug:
             self._check_callback(callback, 'call_soon_threadsafe')
-        handle = self._call_soon(callback, args, context)
+        handle = events._ThreadSafeHandle(callback, args, self, context)
+        self._ready.append(handle)
+        if handle._source_traceback:
+            del handle._source_traceback[-1]
         if handle._source_traceback:
             del handle._source_traceback[-1]
         self._write_to_self()
index 6e291d28ec81aea421430b6fb3c2534d8bf10745..2ee9870e80f20b3aa099869fbd358f48af7d2466 100644 (file)
@@ -113,6 +113,34 @@ class Handle:
             self._loop.call_exception_handler(context)
         self = None  # Needed to break cycles when an exception occurs.
 
+# _ThreadSafeHandle is used for callbacks scheduled with call_soon_threadsafe
+# and is thread safe unlike Handle which is not thread safe.
+class _ThreadSafeHandle(Handle):
+
+    __slots__ = ('_lock',)
+
+    def __init__(self, callback, args, loop, context=None):
+        super().__init__(callback, args, loop, context)
+        self._lock = threading.RLock()
+
+    def cancel(self):
+        with self._lock:
+            return super().cancel()
+
+    def cancelled(self):
+        with self._lock:
+            return super().cancelled()
+
+    def _run(self):
+        # The event loop checks for cancellation without holding the lock
+        # It is possible that the handle is cancelled after the check
+        # but before the callback is called so check it again after acquiring
+        # the lock and return without calling the callback if it is cancelled.
+        with self._lock:
+            if self._cancelled:
+                return
+            return super()._run()
+
 
 class TimerHandle(Handle):
     """Object returned by timed callback registration methods."""
index c8439c9af5e6bafd8e61e100f7d21f7f9985052c..ed75b909317357aafce171533fcd55983b67f45d 100644 (file)
@@ -353,6 +353,124 @@ class EventLoopTestsMixin:
         t.join()
         self.assertEqual(results, ['hello', 'world'])
 
+    def test_call_soon_threadsafe_handle_block_check_cancelled(self):
+        results = []
+
+        callback_started = threading.Event()
+        callback_finished = threading.Event()
+        def callback(arg):
+            callback_started.set()
+            results.append(arg)
+            time.sleep(1)
+            callback_finished.set()
+
+        def run_in_thread():
+            handle = self.loop.call_soon_threadsafe(callback, 'hello')
+            self.assertIsInstance(handle, events._ThreadSafeHandle)
+            callback_started.wait()
+            # callback started so it should block checking for cancellation
+            # until it finishes
+            self.assertFalse(handle.cancelled())
+            self.assertTrue(callback_finished.is_set())
+            self.loop.call_soon_threadsafe(self.loop.stop)
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+
+        self.loop.run_forever()
+        t.join()
+        self.assertEqual(results, ['hello'])
+
+    def test_call_soon_threadsafe_handle_block_cancellation(self):
+        results = []
+
+        callback_started = threading.Event()
+        callback_finished = threading.Event()
+        def callback(arg):
+            callback_started.set()
+            results.append(arg)
+            time.sleep(1)
+            callback_finished.set()
+
+        def run_in_thread():
+            handle = self.loop.call_soon_threadsafe(callback, 'hello')
+            self.assertIsInstance(handle, events._ThreadSafeHandle)
+            callback_started.wait()
+            # callback started so it cannot be cancelled from other thread until
+            # it finishes
+            handle.cancel()
+            self.assertTrue(callback_finished.is_set())
+            self.loop.call_soon_threadsafe(self.loop.stop)
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+
+        self.loop.run_forever()
+        t.join()
+        self.assertEqual(results, ['hello'])
+
+    def test_call_soon_threadsafe_handle_cancel_same_thread(self):
+        results = []
+        callback_started = threading.Event()
+        callback_finished = threading.Event()
+
+        fut = concurrent.futures.Future()
+        def callback(arg):
+            callback_started.set()
+            handle = fut.result()
+            handle.cancel()
+            results.append(arg)
+            callback_finished.set()
+            self.loop.stop()
+
+        def run_in_thread():
+            handle = self.loop.call_soon_threadsafe(callback, 'hello')
+            fut.set_result(handle)
+            self.assertIsInstance(handle, events._ThreadSafeHandle)
+            callback_started.wait()
+            # callback cancels itself from same thread so it has no effect
+            # it runs to completion
+            self.assertTrue(handle.cancelled())
+            self.assertTrue(callback_finished.is_set())
+            self.loop.call_soon_threadsafe(self.loop.stop)
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+
+        self.loop.run_forever()
+        t.join()
+        self.assertEqual(results, ['hello'])
+
+    def test_call_soon_threadsafe_handle_cancel_other_thread(self):
+        results = []
+        ev = threading.Event()
+
+        callback_finished = threading.Event()
+        def callback(arg):
+            results.append(arg)
+            callback_finished.set()
+            self.loop.stop()
+
+        def run_in_thread():
+            handle = self.loop.call_soon_threadsafe(callback, 'hello')
+            # handle can be cancelled from other thread if not started yet
+            self.assertIsInstance(handle, events._ThreadSafeHandle)
+            handle.cancel()
+            self.assertTrue(handle.cancelled())
+            self.assertFalse(callback_finished.is_set())
+            ev.set()
+            self.loop.call_soon_threadsafe(self.loop.stop)
+
+        # block the main loop until the callback is added and cancelled in the
+        # other thread
+        self.loop.call_soon(ev.wait)
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        self.loop.run_forever()
+        t.join()
+        self.assertEqual(results, [])
+        self.assertFalse(callback_finished.is_set())
+
     def test_call_soon_threadsafe_same_thread(self):
         results = []
 
diff --git a/Misc/NEWS.d/next/Library/2025-01-05-11-46-14.gh-issue-128340.gKI0uU.rst b/Misc/NEWS.d/next/Library/2025-01-05-11-46-14.gh-issue-128340.gKI0uU.rst
new file mode 100644 (file)
index 0000000..790400a
--- /dev/null
@@ -0,0 +1 @@
+Add internal thread safe handle to be used in :meth:`asyncio.loop.call_soon_threadsafe` for thread safe cancellation.