]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.11] gh-111085: Fix invalid state handling in TaskGroup and Timeout (GH-111111...
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Sat, 21 Oct 2023 19:40:07 +0000 (21:40 +0200)
committerGitHub <noreply@github.com>
Sat, 21 Oct 2023 19:40:07 +0000 (19:40 +0000)
asyncio.TaskGroup and asyncio.Timeout classes now raise proper RuntimeError
if they are improperly used.

* When they are used without entering the context manager.
* When they are used after finishing.
* When the context manager is entered more than once (simultaneously or
  sequentially).
* If there is no current task when entering the context manager.

They now remain in a consistent state after an exception is thrown,
so subsequent operations can be performed correctly (if they are allowed).

(cherry picked from commit 6c23635f2b7067ef091a550954e09f8b7c329e3f)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
Lib/asyncio/taskgroups.py
Lib/asyncio/timeouts.py
Lib/test/test_asyncio/test_taskgroups.py
Lib/test/test_asyncio/test_timeouts.py
Lib/test/test_asyncio/utils.py
Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst [new file with mode: 0644]

index 0fdea3697ece3d600ee3c467000ba091e664d971..bfdbe63049f9bdb7e068a8a60aec1b2794b4d3d2 100644 (file)
@@ -54,16 +54,14 @@ class TaskGroup:
     async def __aenter__(self):
         if self._entered:
             raise RuntimeError(
-                f"TaskGroup {self!r} has been already entered")
-        self._entered = True
-
+                f"TaskGroup {self!r} has already been entered")
         if self._loop is None:
             self._loop = events.get_running_loop()
-
         self._parent_task = tasks.current_task(self._loop)
         if self._parent_task is None:
             raise RuntimeError(
                 f'TaskGroup {self!r} cannot determine the parent task')
+        self._entered = True
 
         return self
 
index 029c468739bf2d46bb699da1384d38cf45c004a8..30042abb3ad804d8ed5d4255cf6949b0f00c3e40 100644 (file)
@@ -49,8 +49,9 @@ class Timeout:
 
     def reschedule(self, when: Optional[float]) -> None:
         """Reschedule the timeout."""
-        assert self._state is not _State.CREATED
         if self._state is not _State.ENTERED:
+            if self._state is _State.CREATED:
+                raise RuntimeError("Timeout has not been entered")
             raise RuntimeError(
                 f"Cannot change state of {self._state.value} Timeout",
             )
@@ -82,11 +83,14 @@ class Timeout:
         return f"<Timeout [{self._state.value}]{info_str}>"
 
     async def __aenter__(self) -> "Timeout":
+        if self._state is not _State.CREATED:
+            raise RuntimeError("Timeout has already been entered")
+        task = tasks.current_task()
+        if task is None:
+            raise RuntimeError("Timeout should be used inside a task")
         self._state = _State.ENTERED
-        self._task = tasks.current_task()
+        self._task = task
         self._cancelling = self._task.cancelling()
-        if self._task is None:
-            raise RuntimeError("Timeout should be used inside a task")
         self.reschedule(self._when)
         return self
 
index 6a0231f2859a625ec5dbb7be06c5b883644b42b2..7a18362b54e4695a463c84548a682d30a15383f2 100644 (file)
@@ -8,6 +8,8 @@ import contextlib
 from asyncio import taskgroups
 import unittest
 
+from test.test_asyncio.utils import await_without_task
+
 
 # To prevent a warning "test altered the execution environment"
 def tearDownModule():
@@ -779,6 +781,49 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
 
         await asyncio.create_task(main())
 
+    async def test_taskgroup_already_entered(self):
+        tg = taskgroups.TaskGroup()
+        async with tg:
+            with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+                async with tg:
+                    pass
+
+    async def test_taskgroup_double_enter(self):
+        tg = taskgroups.TaskGroup()
+        async with tg:
+            pass
+        with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+            async with tg:
+                pass
+
+    async def test_taskgroup_finished(self):
+        tg = taskgroups.TaskGroup()
+        async with tg:
+            pass
+        coro = asyncio.sleep(0)
+        with self.assertRaisesRegex(RuntimeError, "is finished"):
+            tg.create_task(coro)
+        # We still have to await coro to avoid a warning
+        await coro
+
+    async def test_taskgroup_not_entered(self):
+        tg = taskgroups.TaskGroup()
+        coro = asyncio.sleep(0)
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            tg.create_task(coro)
+        # We still have to await coro to avoid a warning
+        await coro
+
+    async def test_taskgroup_without_parent_task(self):
+        tg = taskgroups.TaskGroup()
+        with self.assertRaisesRegex(RuntimeError, "parent task"):
+            await await_without_task(tg.__aenter__())
+        coro = asyncio.sleep(0)
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            tg.create_task(coro)
+        # We still have to await coro to avoid a warning
+        await coro
+
 
 if __name__ == "__main__":
     unittest.main()
index 5a4093e94707bcce437bf7829d22b141ef05d430..bfa3f1bff694add051bcae952add36053e049c6c 100644 (file)
@@ -6,11 +6,12 @@ import time
 import asyncio
 from asyncio import tasks
 
+from test.test_asyncio.utils import await_without_task
+
 
 def tearDownModule():
     asyncio.set_event_loop_policy(None)
 
-
 class TimeoutTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_timeout_basic(self):
@@ -258,6 +259,51 @@ class TimeoutTests(unittest.IsolatedAsyncioTestCase):
         cause = exc.exception.__cause__
         assert isinstance(cause, asyncio.CancelledError)
 
+    async def test_timeout_already_entered(self):
+        async with asyncio.timeout(0.01) as cm:
+            with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+                async with cm:
+                    pass
+
+    async def test_timeout_double_enter(self):
+        async with asyncio.timeout(0.01) as cm:
+            pass
+        with self.assertRaisesRegex(RuntimeError, "has already been entered"):
+            async with cm:
+                pass
+
+    async def test_timeout_finished(self):
+        async with asyncio.timeout(0.01) as cm:
+            pass
+        with self.assertRaisesRegex(RuntimeError, "finished"):
+            cm.reschedule(0.02)
+
+    async def test_timeout_expired(self):
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(0.01) as cm:
+                await asyncio.sleep(1)
+        with self.assertRaisesRegex(RuntimeError, "expired"):
+            cm.reschedule(0.02)
+
+    async def test_timeout_expiring(self):
+        async with asyncio.timeout(0.01) as cm:
+            with self.assertRaises(asyncio.CancelledError):
+                await asyncio.sleep(1)
+            with self.assertRaisesRegex(RuntimeError, "expiring"):
+                cm.reschedule(0.02)
+
+    async def test_timeout_not_entered(self):
+        cm = asyncio.timeout(0.01)
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            cm.reschedule(0.02)
+
+    async def test_timeout_without_task(self):
+        cm = asyncio.timeout(0.01)
+        with self.assertRaisesRegex(RuntimeError, "task"):
+            await await_without_task(cm.__aenter__())
+        with self.assertRaisesRegex(RuntimeError, "has not been entered"):
+            cm.reschedule(0.02)
+
 
 if __name__ == '__main__':
     unittest.main()
index d6f60db10f9b3f2346a3ea313b899ca64b50240e..7940855b19efed25e836005ad47ca102620a872a 100644 (file)
@@ -612,3 +612,18 @@ def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
     sock.family = family
     sock.gettimeout.return_value = 0.0
     return sock
+
+
+async def await_without_task(coro):
+    exc = None
+    def func():
+        try:
+            for _ in coro.__await__():
+                pass
+        except BaseException as err:
+            nonlocal exc
+            exc = err
+    asyncio.get_running_loop().call_soon(func)
+    await asyncio.sleep(0)
+    if exc is not None:
+        raise exc
diff --git a/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst b/Misc/NEWS.d/next/Library/2023-10-20-15-29-10.gh-issue-110910.u2oPwX.rst
new file mode 100644 (file)
index 0000000..c750447
--- /dev/null
@@ -0,0 +1,3 @@
+Fix invalid state handling in :class:`asyncio.TaskGroup` and
+:class:`asyncio.Timeout`. They now raise proper RuntimeError if they are
+improperly used and are left in consistent state after this.