]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-96764: rewrite `asyncio.wait_for` to use `asyncio.timeout` (#98518)
authorKumar Aditya <59607654+kumaraditya303@users.noreply.github.com>
Thu, 16 Feb 2023 18:48:21 +0000 (00:18 +0530)
committerGitHub <noreply@github.com>
Thu, 16 Feb 2023 18:48:21 +0000 (00:18 +0530)
Changes `asyncio.wait_for` to use `asyncio.timeout` as its underlying implementation.

Lib/asyncio/tasks.py
Lib/test/test_asyncio/test_futures2.py
Lib/test/test_asyncio/test_waitfor.py
Misc/NEWS.d/next/Library/2022-10-22-09-26-43.gh-issue-96764.Dh9Y5L.rst [new file with mode: 0644]

index e78719de216fd00479bac8d01d73ae6b82343cfb..a2e06d5ef72f42f7227cc4f337d819c7f1efcb51 100644 (file)
@@ -24,6 +24,7 @@ from . import coroutines
 from . import events
 from . import exceptions
 from . import futures
+from . import timeouts
 from .coroutines import _is_coroutine
 
 # Helper to generate new task names
@@ -437,65 +438,44 @@ async def wait_for(fut, timeout):
 
     If the wait is cancelled, the task is also cancelled.
 
+    If the task supresses the cancellation and returns a value instead,
+    that value is returned.
+
     This function is a coroutine.
     """
-    loop = events.get_running_loop()
+    # The special case for timeout <= 0 is for the following case:
+    #
+    # async def test_waitfor():
+    #     func_started = False
+    #
+    #     async def func():
+    #         nonlocal func_started
+    #         func_started = True
+    #
+    #     try:
+    #         await asyncio.wait_for(func(), 0)
+    #     except asyncio.TimeoutError:
+    #         assert not func_started
+    #     else:
+    #         assert False
+    #
+    # asyncio.run(test_waitfor())
 
-    if timeout is None:
-        return await fut
 
-    if timeout <= 0:
-        fut = ensure_future(fut, loop=loop)
+    if timeout is not None and timeout <= 0:
+        fut = ensure_future(fut)
 
         if fut.done():
             return fut.result()
 
-        await _cancel_and_wait(fut, loop=loop)
+        await _cancel_and_wait(fut)
         try:
             return fut.result()
         except exceptions.CancelledError as exc:
-            raise exceptions.TimeoutError() from exc
-
-    waiter = loop.create_future()
-    timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
-    cb = functools.partial(_release_waiter, waiter)
-
-    fut = ensure_future(fut, loop=loop)
-    fut.add_done_callback(cb)
-
-    try:
-        # wait until the future completes or the timeout
-        try:
-            await waiter
-        except exceptions.CancelledError:
-            if fut.done():
-                return fut.result()
-            else:
-                fut.remove_done_callback(cb)
-                # We must ensure that the task is not running
-                # after wait_for() returns.
-                # See https://bugs.python.org/issue32751
-                await _cancel_and_wait(fut, loop=loop)
-                raise
-
-        if fut.done():
-            return fut.result()
-        else:
-            fut.remove_done_callback(cb)
-            # We must ensure that the task is not running
-            # after wait_for() returns.
-            # See https://bugs.python.org/issue32751
-            await _cancel_and_wait(fut, loop=loop)
-            # In case task cancellation failed with some
-            # exception, we should re-raise it
-            # See https://bugs.python.org/issue40607
-            try:
-                return fut.result()
-            except exceptions.CancelledError as exc:
-                raise exceptions.TimeoutError() from exc
-    finally:
-        timeout_handle.cancel()
+            raise TimeoutError from exc
 
+    async with timeouts.timeout(timeout):
+        return await fut
 
 async def _wait(fs, timeout, return_when, loop):
     """Internal helper for wait().
@@ -541,9 +521,10 @@ async def _wait(fs, timeout, return_when, loop):
     return done, pending
 
 
-async def _cancel_and_wait(fut, loop):
+async def _cancel_and_wait(fut):
     """Cancel the *fut* future or task and wait until it completes."""
 
+    loop = events.get_running_loop()
     waiter = loop.create_future()
     cb = functools.partial(_release_waiter, waiter)
     fut.add_done_callback(cb)
index 9e7a5775a70383c1f18cbc074e8bd75da1f8af8c..b7cfffb76bd8f1729a9be469c717e4bf305f757d 100644 (file)
@@ -86,10 +86,9 @@ class FutureReprTests(unittest.IsolatedAsyncioTestCase):
         async def func():
             return asyncio.all_tasks()
 
-        # The repr() call should not raise RecursiveError at first.
-        # The check for returned string is not very reliable but
-        # exact comparison for the whole string is even weaker.
-        self.assertIn('...', repr(await asyncio.wait_for(func(), timeout=10)))
+        # The repr() call should not raise RecursionError at first.
+        waiter = await asyncio.wait_for(asyncio.Task(func()),timeout=10)
+        self.assertIn('...', repr(waiter))
 
 
 if __name__ == '__main__':
index 45498fa097f6bccab7c7e3781fdd4f857c334a75..ed80540b2b38520e98891c509859009152f205e9 100644 (file)
@@ -237,33 +237,6 @@ class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase):
         with self.assertRaises(FooException):
             await foo()
 
-    async def test_wait_for_self_cancellation(self):
-        async def inner():
-            try:
-                await asyncio.sleep(0.3)
-            except asyncio.CancelledError:
-                try:
-                    await asyncio.sleep(0.3)
-                except asyncio.CancelledError:
-                    await asyncio.sleep(0.3)
-
-            return 42
-
-        inner_task = asyncio.create_task(inner())
-
-        wait = asyncio.wait_for(inner_task, timeout=0.1)
-
-        # Test that wait_for itself is properly cancellable
-        # even when the initial task holds up the initial cancellation.
-        task = asyncio.create_task(wait)
-        await asyncio.sleep(0.2)
-        task.cancel()
-
-        with self.assertRaises(asyncio.CancelledError):
-            await task
-
-        self.assertEqual(await inner_task, 42)
-
     async def _test_cancel_wait_for(self, timeout):
         loop = asyncio.get_running_loop()
 
@@ -289,6 +262,106 @@ class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase):
     async def test_cancel_wait_for(self):
         await self._test_cancel_wait_for(60.0)
 
+    async def test_wait_for_cancel_suppressed(self):
+        # GH-86296: Supressing CancelledError is discouraged
+        # but if a task subpresses CancelledError and returns a value,
+        # `wait_for` should return the value instead of raising CancelledError.
+        # This is the same behavior as `asyncio.timeout`.
+
+        async def return_42():
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                return 42
+
+        res = await asyncio.wait_for(return_42(), timeout=0.1)
+        self.assertEqual(res, 42)
+
+
+    async def test_wait_for_issue86296(self):
+        # GH-86296: The task should get cancelled and not run to completion.
+        # inner completes in one cycle of the event loop so it
+        # completes before the task is cancelled.
+
+        async def inner():
+            return 'done'
+
+        inner_task = asyncio.create_task(inner())
+        reached_end = False
+
+        async def wait_for_coro():
+            await asyncio.wait_for(inner_task, timeout=100)
+            await asyncio.sleep(1)
+            nonlocal reached_end
+            reached_end = True
+
+        task = asyncio.create_task(wait_for_coro())
+        self.assertFalse(task.done())
+        # Run the task
+        await asyncio.sleep(0)
+        task.cancel()
+        with self.assertRaises(asyncio.CancelledError):
+            await task
+        self.assertTrue(inner_task.done())
+        self.assertEqual(await inner_task, 'done')
+        self.assertFalse(reached_end)
+
+
+class WaitForShieldTests(unittest.IsolatedAsyncioTestCase):
+
+    async def test_zero_timeout(self):
+        # `asyncio.shield` creates a new task which wraps the passed in
+        # awaitable and shields it from cancellation so with timeout=0
+        # the task returned by `asyncio.shield` aka shielded_task gets
+        # cancelled immediately and the task wrapped by it is scheduled
+        # to run.
+
+        async def coro():
+            await asyncio.sleep(0.01)
+            return 'done'
+
+        task = asyncio.create_task(coro())
+        with self.assertRaises(asyncio.TimeoutError):
+            shielded_task = asyncio.shield(task)
+            await asyncio.wait_for(shielded_task, timeout=0)
+
+        # Task is running in background
+        self.assertFalse(task.done())
+        self.assertFalse(task.cancelled())
+        self.assertTrue(shielded_task.cancelled())
+
+        # Wait for the task to complete
+        await asyncio.sleep(0.1)
+        self.assertTrue(task.done())
+
+
+    async def test_none_timeout(self):
+        # With timeout=None the timeout is disabled so it
+        # runs till completion.
+        async def coro():
+            await asyncio.sleep(0.1)
+            return 'done'
+
+        task = asyncio.create_task(coro())
+        await asyncio.wait_for(asyncio.shield(task), timeout=None)
+
+        self.assertTrue(task.done())
+        self.assertEqual(await task, "done")
+
+    async def test_shielded_timeout(self):
+        # shield prevents the task from being cancelled.
+        async def coro():
+            await asyncio.sleep(0.1)
+            return 'done'
+
+        task = asyncio.create_task(coro())
+        with self.assertRaises(asyncio.TimeoutError):
+            await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
+
+        self.assertFalse(task.done())
+        self.assertFalse(task.cancelled())
+        self.assertEqual(await task, "done")
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2022-10-22-09-26-43.gh-issue-96764.Dh9Y5L.rst b/Misc/NEWS.d/next/Library/2022-10-22-09-26-43.gh-issue-96764.Dh9Y5L.rst
new file mode 100644 (file)
index 0000000..a017429
--- /dev/null
@@ -0,0 +1 @@
+:func:`asyncio.wait_for` now uses :func:`asyncio.timeout` as its underlying implementation. Patch by Kumar Aditya.