mirror of
https://github.com/python/cpython
synced 2024-09-20 02:28:13 +00:00
bpo-46752: Slight improvements to TaskGroup API (GH-31398)
* Remove task group names (for now) We're not sure that they are needed, and once in the code we would never be able to get rid of them. Yury wrote: > Ideally, there should be a way for someone to build a "trace" > of taskgroups/task leading to the current running task. > We could do that using contextvars, but I'm not sure we should > do that in 3.11. * Pass optional name on to task in create_task() * Remove a bunch of unused stuff
This commit is contained in:
parent
2a38e1ab65
commit
d85121660e
|
@ -3,10 +3,6 @@
|
|||
|
||||
__all__ = ["TaskGroup"]
|
||||
|
||||
import itertools
|
||||
import textwrap
|
||||
import traceback
|
||||
import types
|
||||
import weakref
|
||||
|
||||
from . import events
|
||||
|
@ -15,12 +11,7 @@
|
|||
|
||||
class TaskGroup:
|
||||
|
||||
def __init__(self, *, name=None):
|
||||
if name is None:
|
||||
self._name = f'tg-{_name_counter()}'
|
||||
else:
|
||||
self._name = str(name)
|
||||
|
||||
def __init__(self):
|
||||
self._entered = False
|
||||
self._exiting = False
|
||||
self._aborting = False
|
||||
|
@ -33,11 +24,8 @@ def __init__(self, *, name=None):
|
|||
self._base_error = None
|
||||
self._on_completed_fut = None
|
||||
|
||||
def get_name(self):
|
||||
return self._name
|
||||
|
||||
def __repr__(self):
|
||||
msg = f'<TaskGroup {self._name!r}'
|
||||
msg = f'<TaskGroup'
|
||||
if self._tasks:
|
||||
msg += f' tasks:{len(self._tasks)}'
|
||||
if self._unfinished_tasks:
|
||||
|
@ -152,12 +140,13 @@ async def __aexit__(self, et, exc, tb):
|
|||
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
|
||||
raise me from None
|
||||
|
||||
def create_task(self, coro):
|
||||
def create_task(self, coro, *, name=None):
|
||||
if not self._entered:
|
||||
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
|
||||
if self._exiting and self._unfinished_tasks == 0:
|
||||
raise RuntimeError(f"TaskGroup {self!r} is finished")
|
||||
task = self._loop.create_task(coro)
|
||||
tasks._set_task_name(task, name)
|
||||
task.add_done_callback(self._on_task_done)
|
||||
self._unfinished_tasks += 1
|
||||
self._tasks.add(task)
|
||||
|
@ -230,6 +219,3 @@ def _on_task_done(self, task):
|
|||
# # after TaskGroup is finished.
|
||||
self._parent_cancel_requested = True
|
||||
self._parent_task.cancel()
|
||||
|
||||
|
||||
_name_counter = itertools.count(1).__next__
|
||||
|
|
|
@ -368,10 +368,10 @@ async def crash_after(t):
|
|||
raise ValueError(t)
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
async with taskgroups.TaskGroup() as g1:
|
||||
g1.create_task(crash_after(0.1))
|
||||
|
||||
async with taskgroups.TaskGroup(name='g2') as g2:
|
||||
async with taskgroups.TaskGroup() as g2:
|
||||
g2.create_task(crash_after(0.2))
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
|
@ -387,10 +387,10 @@ async def crash_after(t):
|
|||
raise ValueError(t)
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
async with taskgroups.TaskGroup() as g1:
|
||||
g1.create_task(crash_after(10))
|
||||
|
||||
async with taskgroups.TaskGroup(name='g2') as g2:
|
||||
async with taskgroups.TaskGroup() as g2:
|
||||
g2.create_task(crash_after(0.1))
|
||||
|
||||
r = asyncio.create_task(runner())
|
||||
|
@ -407,7 +407,7 @@ async def crash_soon():
|
|||
1 / 0
|
||||
|
||||
async def runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
async with taskgroups.TaskGroup() as g1:
|
||||
g1.create_task(crash_soon())
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
|
@ -430,7 +430,7 @@ async def crash_soon():
|
|||
1 / 0
|
||||
|
||||
async def nested_runner():
|
||||
async with taskgroups.TaskGroup(name='g1') as g1:
|
||||
async with taskgroups.TaskGroup() as g1:
|
||||
g1.create_task(crash_soon())
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
|
@ -692,3 +692,10 @@ async def runner():
|
|||
|
||||
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
|
||||
self.assertGreaterEqual(nhydras, 10)
|
||||
|
||||
async def test_taskgroup_task_name(self):
|
||||
async def coro():
|
||||
await asyncio.sleep(0)
|
||||
async with taskgroups.TaskGroup() as g:
|
||||
t = g.create_task(coro(), name="yolo")
|
||||
self.assertEqual(t.get_name(), "yolo")
|
||||
|
|
Loading…
Reference in a new issue