]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-111085: Fix invalid state handling in TaskGroup and Timeout (#111111)
authorSerhiy Storchaka <storchaka@gmail.com>
Sat, 21 Oct 2023 19:18:34 +0000 (22:18 +0300)
committerGitHub <noreply@github.com>
Sat, 21 Oct 2023 19:18:34 +0000 (22:18 +0300)
asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.

* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
  sequentially).
* If there is no current task when entering the context manager.

They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).

Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
Lib/asyncio/taskgroups.py
Lib/asyncio/timeouts.py
Lib/test/test_asyncio/test_taskgroups.py
Lib/test/test_asyncio/test_timeouts.py
Lib/test/test_asyncio/utils.py
Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst [new file with mode: 0644]

index 24238c4f5f998d0f9b26d1ef9ebfb8f9547f3395..91be0decc41c4265df5dfab0fc4c907c1617663f 100644 (file)
@@ -54,16 +54,14 @@ class TaskGroup:
     async def __aenter__(self):
         if self._entered:
             raise RuntimeError(
-                f"TaskGroup {self!r} has been already entered")
-        self._entered = True
-
+                f"TaskGroup {self!r} has already been entered")
         if self._loop is None:
             self._loop = events.get_running_loop()
-
         self._parent_task = tasks.current_task(self._loop)
         if self._parent_task is None:
             raise RuntimeError(
                 f'TaskGroup {self!r} cannot determine the parent task')
+        self._entered = True
 
         return self
 
index 029c468739bf2d46bb699da1384d38cf45c004a8..30042abb3ad804d8ed5d4255cf6949b0f00c3e40 100644 (file)
@@ -49,8 +49,9 @@ class Timeout:
 
     def reschedule(self, when: Optional[float]) -> None:
         """Reschedule the timeout."""
-        assert self._state is not _State.CREATED
         if self._state is not _State.ENTERED:
+            if self._state is _State.CREATED:
+                raise RuntimeError("Timeout has not been entered")
             raise RuntimeError(
                 f"Cannot change state of {self._state.value} Timeout",
             )
@@ -82,11 +83,14 @@ class Timeout:
         return f"<Timeout [{self._state.value}]{info_str}>"
 
     async def __aenter__(self) -> "Timeout":
+        if self._state is not _State.CREATED:
+            raise RuntimeError("Timeout has already been entered")
+        task = tasks.current_task()
+        if task is None:
+            raise RuntimeError("Timeout should be used inside a task")
         self._state = _State.ENTERED
-        self._task = tasks.current_task()
+        self._task = task
         self._cancelling = self._task.cancelling()
-        if self._task is None:
-            raise RuntimeError("Timeout should be used inside a task")
         self.reschedule(self._when)
         return self
 
index 6a0231f2859a625ec5dbb7be06c5b883644b42b2..7a18362b54e4695a463c84548a682d30a15383f2 100644 (file)
@@ -8,6 +8,8 @@ import contextlib
 from asyncio import taskgroups
 import unittest
 
+from test.test_asyncio.utils import await_without_task
+
 
 # To prevent a warning "test altered the execution environment"
 def tearDownModule():
@@ -779,6 +781,49 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
 
         await asyncio.create_task(main())
 
+    async def test_taskgroup_already_entered(self):
+        tg = taskgroups.TaskGroup()
+        async with tg:
+            with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+                async with tg:
+                    pass
+
+    async def test_taskgroup_double_enter(self):
+        tg = taskgroups.TaskGroup()
+        async with tg:
+            pass
+        with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+            async with tg:
+                pass
+
+    async def test_taskgroup_finished(self):
+        tg = taskgroups.TaskGroup()
+        async with tg:
+            pass
+        coro = asyncio.sleep(0)
+        with self.assertRaisesRegex(RuntimeError, "is finished"):
+            tg.create_task(coro)
+        # We still have to await coro to avoid a warning
+        await coro
+
+    async def test_taskgroup_not_entered(self):
+        tg = taskgroups.TaskGroup()
+        coro = asyncio.sleep(0)
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            tg.create_task(coro)
+        # We still have to await coro to avoid a warning
+        await coro
+
+    async def test_taskgroup_without_parent_task(self):
+        tg = taskgroups.TaskGroup()
+        with self.assertRaisesRegex(RuntimeError, "parent task"):
+            await await_without_task(tg.__aenter__())
+        coro = asyncio.sleep(0)
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            tg.create_task(coro)
+        # We still have to await coro to avoid a warning
+        await coro
+
 
 if __name__ == "__main__":
     unittest.main()
index e9b59b953518b3fc1e332240c1faeeebdf24083d..f54e79e4d8e600c511d87b32b6cad9b7073f747d 100644 (file)
@@ -5,11 +5,12 @@ import time
 
 import asyncio
 
+from test.test_asyncio.utils import await_without_task
+
 
 def tearDownModule():
     asyncio.set_event_loop_policy(None)
 
-
 class TimeoutTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_timeout_basic(self):
@@ -257,6 +258,51 @@ class TimeoutTests(unittest.IsolatedAsyncioTestCase):
         cause = exc.exception.__cause__
         assert isinstance(cause, asyncio.CancelledError)
 
+    async def test_timeout_already_entered(self):
+        async with asyncio.timeout(0.01) as cm:
+            with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+                async with cm:
+                    pass
+
+    async def test_timeout_double_enter(self):
+        async with asyncio.timeout(0.01) as cm:
+            pass
+        with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+            async with cm:
+                pass
+
+    async def test_timeout_finished(self):
+        async with asyncio.timeout(0.01) as cm:
+            pass
+        with self.assertRaisesRegex(RuntimeError, "finished"):
+            cm.reschedule(0.02)
+
+    async def test_timeout_expired(self):
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.01) as cm:
+                await asyncio.sleep(1)
+        with self.assertRaisesRegex(RuntimeError, "expired"):
+            cm.reschedule(0.02)
+
+    async def test_timeout_expiring(self):
+        async with asyncio.timeout(0.01) as cm:
+            with self.assertRaises(asyncio.CancelledError):
+                await asyncio.sleep(1)
+            with self.assertRaisesRegex(RuntimeError, "expiring"):
+                cm.reschedule(0.02)
+
+    async def test_timeout_not_entered(self):
+        cm = asyncio.timeout(0.01)
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            cm.reschedule(0.02)
+
+    async def test_timeout_without_task(self):
+        cm = asyncio.timeout(0.01)
+        with self.assertRaisesRegex(RuntimeError, "task"):
+            await await_without_task(cm.__aenter__())
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            cm.reschedule(0.02)
+
 
 if __name__ == '__main__':
     unittest.main()
index 83fac4a26aff9ebf7d64a61833e805a9e6d53c37..18869b3290a8aef517ba6b27d935db14aa4a6f0f 100644 (file)
@@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
     sock.family = family
     sock.gettimeout.return_value = 0.0
     return sock
+
+
+async def await_without_task(coro):
+    exc = None
+    def func():
+        try:
+            for _ in coro.__await__():
+                pass
+        except BaseException as err:
+            nonlocal exc
+            exc = err
+    asyncio.get_running_loop().call_soon(func)
+    await asyncio.sleep(0)
+    if exc is not None:
+        raise exc
diff --git a/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst b/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst
new file mode 100644 (file)
index 0000000..c750447
--- /dev/null
@@ -0,0 +1,3 @@
+Fix invalid state handling in :class:`asyncio.TaskGroup` and
+:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
+improperly used and are left in consistent state after this.