bpo-46752: Slight improvements to TaskGroup API (GH-31398)

* Remove task group names (for now)

We're not sure that they are needed, and once in the code
we would never be able to get rid of them.

Yury wrote:

> Ideally, there should be a way for someone to build a "trace"
> of taskgroups/task leading to the current running task.
> We could do that using contextvars, but I'm not sure we should
> do that in 3.11.

* Pass optional name on to task in create_task()

* Remove a bunch of unused stuff
This commit is contained in:
Guido van Rossum 2022-02-17 21:30:44 -08:00 committed by GitHub
parent 2a38e1ab65
commit d85121660e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 24 deletions

View file

@ -3,10 +3,6 @@
__all__ = ["TaskGroup"]
import itertools
import textwrap
import traceback
import types
import weakref
from . import events
@ -15,12 +11,7 @@
class TaskGroup:
def __init__(self, *, name=None):
if name is None:
self._name = f'tg-{_name_counter()}'
else:
self._name = str(name)
def __init__(self):
self._entered = False
self._exiting = False
self._aborting = False
@ -33,11 +24,8 @@ def __init__(self, *, name=None):
self._base_error = None
self._on_completed_fut = None
def get_name(self):
return self._name
def __repr__(self):
msg = f'<TaskGroup {self._name!r}'
msg = f'<TaskGroup'
if self._tasks:
msg += f' tasks:{len(self._tasks)}'
if self._unfinished_tasks:
@ -152,12 +140,13 @@ async def __aexit__(self, et, exc, tb):
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
raise me from None
def create_task(self, coro):
def create_task(self, coro, *, name=None):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and self._unfinished_tasks == 0:
raise RuntimeError(f"TaskGroup {self!r} is finished")
task = self._loop.create_task(coro)
tasks._set_task_name(task, name)
task.add_done_callback(self._on_task_done)
self._unfinished_tasks += 1
self._tasks.add(task)
@ -230,6 +219,3 @@ def _on_task_done(self, task):
# # after TaskGroup is finished.
self._parent_cancel_requested = True
self._parent_task.cancel()
_name_counter = itertools.count(1).__next__

View file

@ -368,10 +368,10 @@ async def crash_after(t):
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_after(0.1))
async with taskgroups.TaskGroup(name='g2') as g2:
async with taskgroups.TaskGroup() as g2:
g2.create_task(crash_after(0.2))
r = asyncio.create_task(runner())
@ -387,10 +387,10 @@ async def crash_after(t):
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_after(10))
async with taskgroups.TaskGroup(name='g2') as g2:
async with taskgroups.TaskGroup() as g2:
g2.create_task(crash_after(0.1))
r = asyncio.create_task(runner())
@ -407,7 +407,7 @@ async def crash_soon():
1 / 0
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
@ -430,7 +430,7 @@ async def crash_soon():
1 / 0
async def nested_runner():
async with taskgroups.TaskGroup(name='g1') as g1:
async with taskgroups.TaskGroup() as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
@ -692,3 +692,10 @@ async def runner():
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertGreaterEqual(nhydras, 10)
async def test_taskgroup_task_name(self):
async def coro():
await asyncio.sleep(0)
async with taskgroups.TaskGroup() as g:
t = g.create_task(coro(), name="yolo")
self.assertEqual(t.get_name(), "yolo")