]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-129874: improve asyncio tests to use correct internal functions (#129887)
authorKumar Aditya <kumaraditya@python.org>
Sun, 9 Feb 2025 12:05:39 +0000 (17:35 +0530)
committerGitHub <noreply@github.com>
Sun, 9 Feb 2025 12:05:39 +0000 (12:05 +0000)
Lib/test/test_asyncio/test_eager_task_factory.py
Lib/test/test_asyncio/test_free_threading.py
Lib/test/test_asyncio/test_graph.py

index 10450c11b68279d2fd6db9e999c1fcdcee4cff52..bb0760a6967dbacc82b0d9d0f85ed51d44abcea0 100644 (file)
@@ -267,12 +267,33 @@ class EagerTaskFactoryLoopTests:
 class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
     Task = tasks._PyTask
 
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
+
+
 
 @unittest.skipUnless(hasattr(tasks, '_CTask'),
                      'requires the C _asyncio module')
 class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
     Task = getattr(tasks, '_CTask', None)
 
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
+
+
+    @unittest.skip("skip")
     def test_issue105987(self):
         code = """if 1:
         from _asyncio import _swap_current_task
@@ -400,31 +421,83 @@ class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
 
 
 class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
-    Task = asyncio.Task
+    Task = asyncio.tasks._CTask
+
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
+        return super().setUp()
 
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
 
 class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
-    Task = asyncio.Task
+    Task = asyncio.tasks._CTask
+
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
 
 
 class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
     Task = tasks._PyTask
 
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
+
 
 class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
     Task = tasks._PyTask
 
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
 
 @unittest.skipUnless(hasattr(tasks, '_CTask'),
                      'requires the C _asyncio module')
 class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
     Task = getattr(tasks, '_CTask', None)
 
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
+
 
 @unittest.skipUnless(hasattr(tasks, '_CTask'),
                      'requires the C _asyncio module')
 class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
     Task = getattr(tasks, '_CTask', None)
 
+    def setUp(self):
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+        return super().tearDown()
+
 if __name__ == '__main__':
     unittest.main()
index d0221d87062c5b4687afcab11a60a511dcb95281..199dbbdda5e8a6f14b6297c35b1690ccc6794250 100644 (file)
@@ -40,7 +40,7 @@ class TestFreeThreading:
                     self.assertEqual(task.get_loop(), loop)
                     self.assertFalse(task.done())
 
-                current = self.current_task()
+                current = asyncio.current_task()
                 self.assertEqual(current.get_loop(), loop)
                 self.assertSetEqual(all_tasks, tasks | {current})
                 future.set_result(None)
@@ -101,8 +101,12 @@ class TestFreeThreading:
         async def func():
             nonlocal task
             task = asyncio.current_task()
-
-        thread = Thread(target=lambda: asyncio.run(func()))
+        def runner():
+            with asyncio.Runner() as runner:
+                loop = runner.get_loop()
+                loop.set_task_factory(self.factory)
+                runner.run(func())
+        thread = Thread(target=runner)
         thread.start()
         thread.join()
         wr = weakref.ref(task)
@@ -164,7 +168,15 @@ class TestFreeThreading:
 
 class TestPyFreeThreading(TestFreeThreading, TestCase):
     all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
-    current_task = staticmethod(asyncio.tasks._py_current_task)
+
+    def setUp(self):
+        self._old_current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._py_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
+        return super().tearDown()
 
     def factory(self, loop, coro, **kwargs):
         return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)
@@ -173,7 +185,16 @@ class TestPyFreeThreading(TestFreeThreading, TestCase):
 @unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
 class TestCFreeThreading(TestFreeThreading, TestCase):
     all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
-    current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
+
+    def setUp(self):
+        self._old_current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = asyncio.tasks._c_current_task
+        return super().setUp()
+
+    def tearDown(self):
+        asyncio.current_task = asyncio.tasks.current_task = self._old_current_task
+        return super().tearDown()
+
 
     def factory(self, loop, coro, **kwargs):
         return asyncio.tasks._CTask(coro, loop=loop, **kwargs)
index fd2160d4ca31374ad27bee979b78fc93e4617098..62f6593c31d2d1e988539854424bb74408a84b35 100644 (file)
@@ -369,6 +369,8 @@ class TestCallStackC(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
         futures.future_discard_from_awaited_by = futures._c_future_discard_from_awaited_by
         asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by
 
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = tasks._c_current_task
 
     def tearDown(self):
         futures = asyncio.futures
@@ -390,6 +392,8 @@ class TestCallStackC(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
         futures.Future = self._Future
         del self._Future
 
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task
+
 
 @unittest.skipIf(
     not hasattr(asyncio.futures, "_py_future_add_to_awaited_by"),
@@ -414,6 +418,9 @@ class TestCallStackPy(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
         futures.future_discard_from_awaited_by = futures._py_future_discard_from_awaited_by
         asyncio.future_discard_from_awaited_by = futures.future_discard_from_awaited_by
 
+        self._current_task = asyncio.current_task
+        asyncio.current_task = asyncio.tasks.current_task = tasks._py_current_task
+
 
     def tearDown(self):
         futures = asyncio.futures
@@ -434,3 +441,5 @@ class TestCallStackPy(CallStackTestBase, unittest.IsolatedAsyncioTestCase):
         asyncio.Future = self._Future
         futures.Future = self._Future
         del self._Future
+
+        asyncio.current_task = asyncio.tasks.current_task = self._current_task