]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-128002: add more thread safety tests for asyncio (#128480)
authorKumar Aditya <kumaraditya@python.org>
Mon, 13 Jan 2025 15:36:55 +0000 (21:06 +0530)
committerGitHub <noreply@github.com>
Mon, 13 Jan 2025 15:36:55 +0000 (15:36 +0000)
Lib/test/test_asyncio/test_free_threading.py

index 90bddbf3a9dda16fb9f8b3072928df8f2942ff7b..8f4bba5f3b97d9bdadf4c848a06cbca039bda999 100644 (file)
@@ -7,6 +7,11 @@ from test.support import threading_helper
 
 threading_helper.requires_working_threading(module=True)
 
+
+class MyException(Exception):
+    pass
+
+
 def tearDownModule():
     asyncio._set_event_loop_policy(None)
 
@@ -53,6 +58,55 @@ class TestFreeThreading:
         with threading_helper.start_threads(threads):
             pass
 
+    def test_run_coroutine_threadsafe(self) -> None:
+        results = []
+
+        def in_thread(loop: asyncio.AbstractEventLoop):
+            coro = asyncio.sleep(0.1, result=42)
+            fut = asyncio.run_coroutine_threadsafe(coro, loop)
+            result = fut.result()
+            self.assertEqual(result, 42)
+            results.append(result)
+
+        async def main():
+            loop = asyncio.get_running_loop()
+            async with asyncio.TaskGroup() as tg:
+                for _ in range(10):
+                    tg.create_task(asyncio.to_thread(in_thread, loop))
+            self.assertEqual(results, [42] * 10)
+
+        with asyncio.Runner() as r:
+            loop = r.get_loop()
+            loop.set_task_factory(self.factory)
+            r.run(main())
+
+    def test_run_coroutine_threadsafe_exception(self) -> None:
+        async def coro():
+            await asyncio.sleep(0)
+            raise MyException("test")
+
+        def in_thread(loop: asyncio.AbstractEventLoop):
+            fut = asyncio.run_coroutine_threadsafe(coro(), loop)
+            return fut.result()
+
+        async def main():
+            loop = asyncio.get_running_loop()
+            tasks = []
+            for _ in range(10):
+                task = loop.create_task(asyncio.to_thread(in_thread, loop))
+                tasks.append(task)
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            self.assertEqual(len(results), 10)
+            for result in results:
+                self.assertIsInstance(result, MyException)
+                self.assertEqual(str(result), "test")
+
+        with asyncio.Runner() as r:
+            loop = r.get_loop()
+            loop.set_task_factory(self.factory)
+            r.run(main())
+
 
 class TestPyFreeThreading(TestFreeThreading, TestCase):
     all_tasks = staticmethod(asyncio.tasks._py_all_tasks)