]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46752: Slight improvements to TaskGroup API (GH-31398)
authorGuido van Rossum <guido@python.org>
Fri, 18 Feb 2022 05:30:44 +0000 (21:30 -0800)
committerGitHub <noreply@github.com>
Fri, 18 Feb 2022 05:30:44 +0000 (21:30 -0800)
* Remove task group names (for now)

We're not sure that they are needed, and once in the code
we would never be able to get rid of them.

Yury wrote:

> Ideally, there should be a way for someone to build a "trace"
> of taskgroups/task leading to the current running task.
> We could do that using contextvars, but I'm not sure we should
> do that in 3.11.

* Pass optional name on to task in create_task()

* Remove a bunch of unused stuff

Lib/asyncio/taskgroups.py
Lib/test/test_asyncio/test_taskgroups.py

index 718277892c51c90e3b77c52f37dcf50a1cce8992..57b0eafefc16fea0f9b475dec48b8cc7ae4f9b32 100644 (file)
@@ -3,10 +3,6 @@
 
 __all__ = ["TaskGroup"]
 
-import itertools
-import textwrap
-import traceback
-import types
 import weakref
 
 from . import events
@@ -15,12 +11,7 @@ from . import tasks
 
 class TaskGroup:
 
-    def __init__(self, *, name=None):
-        if name is None:
-            self._name = f'tg-{_name_counter()}'
-        else:
-            self._name = str(name)
-
+    def __init__(self):
         self._entered = False
         self._exiting = False
         self._aborting = False
@@ -33,11 +24,8 @@ class TaskGroup:
         self._base_error = None
         self._on_completed_fut = None
 
-    def get_name(self):
-        return self._name
-
     def __repr__(self):
-        msg = f'<TaskGroup {self._name!r}'
+        msg = f'<TaskGroup'
         if self._tasks:
             msg += f' tasks:{len(self._tasks)}'
         if self._unfinished_tasks:
@@ -152,12 +140,13 @@ class TaskGroup:
             me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
             raise me from None
 
-    def create_task(self, coro):
+    def create_task(self, coro, *, name=None):
         if not self._entered:
             raise RuntimeError(f"TaskGroup {self!r} has not been entered")
         if self._exiting and self._unfinished_tasks == 0:
             raise RuntimeError(f"TaskGroup {self!r} is finished")
         task = self._loop.create_task(coro)
+        tasks._set_task_name(task, name)
         task.add_done_callback(self._on_task_done)
         self._unfinished_tasks += 1
         self._tasks.add(task)
@@ -230,6 +219,3 @@ class TaskGroup:
             #                                 # after TaskGroup is finished.
             self._parent_cancel_requested = True
             self._parent_task.cancel()
-
-
-_name_counter = itertools.count(1).__next__
index ea6ee2ed43d2f8dd5cc21e6f9b953c1adae7988a..aab1fd1ebb38d87283b78bbadde7dae42c2c364f 100644 (file)
@@ -368,10 +368,10 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
             raise ValueError(t)
 
         async def runner():
-            async with taskgroups.TaskGroup(name='g1') as g1:
+            async with taskgroups.TaskGroup() as g1:
                 g1.create_task(crash_after(0.1))
 
-                async with taskgroups.TaskGroup(name='g2') as g2:
+                async with taskgroups.TaskGroup() as g2:
                     g2.create_task(crash_after(0.2))
 
         r = asyncio.create_task(runner())
@@ -387,10 +387,10 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
             raise ValueError(t)
 
         async def runner():
-            async with taskgroups.TaskGroup(name='g1') as g1:
+            async with taskgroups.TaskGroup() as g1:
                 g1.create_task(crash_after(10))
 
-                async with taskgroups.TaskGroup(name='g2') as g2:
+                async with taskgroups.TaskGroup() as g2:
                     g2.create_task(crash_after(0.1))
 
         r = asyncio.create_task(runner())
@@ -407,7 +407,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
             1 / 0
 
         async def runner():
-            async with taskgroups.TaskGroup(name='g1') as g1:
+            async with taskgroups.TaskGroup() as g1:
                 g1.create_task(crash_soon())
                 try:
                     await asyncio.sleep(10)
@@ -430,7 +430,7 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
             1 / 0
 
         async def nested_runner():
-            async with taskgroups.TaskGroup(name='g1') as g1:
+            async with taskgroups.TaskGroup() as g1:
                 g1.create_task(crash_soon())
                 try:
                     await asyncio.sleep(10)
@@ -692,3 +692,10 @@ class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
 
         self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
         self.assertGreaterEqual(nhydras, 10)
+
+    async def test_taskgroup_task_name(self):
+        async def coro():
+            await asyncio.sleep(0)
+        async with taskgroups.TaskGroup() as g:
+            t = g.create_task(coro(), name="yolo")
+            self.assertEqual(t.get_name(), "yolo")