From: Kumar Aditya Date: Mon, 13 Jan 2025 15:36:55 +0000 (+0530) Subject: gh-128002: add more thread safety tests for asyncio (#128480) X-Git-Tag: v3.14.0a4~13 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=3efe28a40b136164f0d33c4f84dfcef7e123d1a0;p=thirdparty%2FPython%2Fcpython.git gh-128002: add more thread safety tests for asyncio (#128480) --- diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py index 90bddbf3a9dd..8f4bba5f3b97 100644 --- a/Lib/test/test_asyncio/test_free_threading.py +++ b/Lib/test/test_asyncio/test_free_threading.py @@ -7,6 +7,11 @@ from test.support import threading_helper threading_helper.requires_working_threading(module=True) + +class MyException(Exception): + pass + + def tearDownModule(): asyncio._set_event_loop_policy(None) @@ -53,6 +58,55 @@ class TestFreeThreading: with threading_helper.start_threads(threads): pass + def test_run_coroutine_threadsafe(self) -> None: + results = [] + + def in_thread(loop: asyncio.AbstractEventLoop): + coro = asyncio.sleep(0.1, result=42) + fut = asyncio.run_coroutine_threadsafe(coro, loop) + result = fut.result() + self.assertEqual(result, 42) + results.append(result) + + async def main(): + loop = asyncio.get_running_loop() + async with asyncio.TaskGroup() as tg: + for _ in range(10): + tg.create_task(asyncio.to_thread(in_thread, loop)) + self.assertEqual(results, [42] * 10) + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + + def test_run_coroutine_threadsafe_exception(self) -> None: + async def coro(): + await asyncio.sleep(0) + raise MyException("test") + + def in_thread(loop: asyncio.AbstractEventLoop): + fut = asyncio.run_coroutine_threadsafe(coro(), loop) + return fut.result() + + async def main(): + loop = asyncio.get_running_loop() + tasks = [] + for _ in range(10): + task = loop.create_task(asyncio.to_thread(in_thread, loop)) + tasks.append(task) + results = await asyncio.gather(*tasks, return_exceptions=True) + + self.assertEqual(len(results), 10) + for result in results: + self.assertIsInstance(result, MyException) + self.assertEqual(str(result), "test") + + with asyncio.Runner() as r: + loop = r.get_loop() + loop.set_task_factory(self.factory) + r.run(main()) + class TestPyFreeThreading(TestFreeThreading, TestCase): all_tasks = staticmethod(asyncio.tasks._py_all_tasks)