bpo-46752: Add TaskGroup; add Task..cancelled(),.uncancel() (GH-31270)

asyncio/taskgroups.py is an adaptation of taskgroup.py from EdgeDb, with the following key changes:

- Allow creating new tasks as long as the last task hasn't finished
- Raise [Base]ExceptionGroup (directly) rather than TaskGroupError deriving from MultiError
- Instead of monkey-patching the parent task's cancel() method,
  add a new public API to Task

The Task class has a new internal flag, `_cancel_requested`, which is set when `.cancel()` is called successfully. The `.cancelling()` method returns the value of this flag. Further `.cancel()` calls while this flag is set return False. To reset this flag, call `.uncancel()`.

Thus, a Task that catches and ignores `CancelledError` should call `.uncancel()` if it wants to be cancellable again; until it does so, it is deemed to be busy with uninterruptible cleanup.

This new Task API helps solve the problem where TaskGroup needs to distinguish between whether the parent task being cancelled "from the outside" vs. "from inside".

Co-authored-by: Yury Selivanov <yury@edgedb.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
Guido van Rossum 2022-02-15 15:42:04 -08:00 committed by GitHub
parent 08ec80113b
commit 602630ac18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 1100 additions and 3 deletions

View file

@ -17,6 +17,7 @@
from .streams import *
from .subprocess import *
from .tasks import *
from .taskgroups import *
from .threads import *
from .transports import *

View file

@ -8,7 +8,7 @@
def _task_repr_info(task):
info = base_futures._future_repr_info(task)
if task._must_cancel:
if task.cancelling() and not task.done():
# replace status
info[0] = 'cancelling'

235
Lib/asyncio/taskgroups.py Normal file
View file

@ -0,0 +1,235 @@
# Adapted with permission from the EdgeDB project.
__all__ = ["TaskGroup"]
import itertools
import textwrap
import traceback
import types
import weakref
from . import events
from . import exceptions
from . import tasks
class TaskGroup:
def __init__(self, *, name=None):
if name is None:
self._name = f'tg-{_name_counter()}'
else:
self._name = str(name)
self._entered = False
self._exiting = False
self._aborting = False
self._loop = None
self._parent_task = None
self._parent_cancel_requested = False
self._tasks = weakref.WeakSet()
self._unfinished_tasks = 0
self._errors = []
self._base_error = None
self._on_completed_fut = None
def get_name(self):
return self._name
def __repr__(self):
msg = f'<TaskGroup {self._name!r}'
if self._tasks:
msg += f' tasks:{len(self._tasks)}'
if self._unfinished_tasks:
msg += f' unfinished:{self._unfinished_tasks}'
if self._errors:
msg += f' errors:{len(self._errors)}'
if self._aborting:
msg += ' cancelling'
elif self._entered:
msg += ' entered'
msg += '>'
return msg
async def __aenter__(self):
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
self._entered = True
if self._loop is None:
self._loop = events.get_running_loop()
self._parent_task = tasks.current_task(self._loop)
if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
return self
async def __aexit__(self, et, exc, tb):
self._exiting = True
propagate_cancellation_error = None
if (exc is not None and
self._is_base_error(exc) and
self._base_error is None):
self._base_error = exc
if et is exceptions.CancelledError:
if self._parent_cancel_requested:
# Only if we did request task to cancel ourselves
# we mark it as no longer cancelled.
self._parent_task.uncancel()
else:
propagate_cancellation_error = et
if et is not None and not self._aborting:
# Our parent task is being cancelled:
#
# async with TaskGroup() as g:
# g.create_task(...)
# await ... # <- CancelledError
#
if et is exceptions.CancelledError:
propagate_cancellation_error = et
# or there's an exception in "async with":
#
# async with TaskGroup() as g:
# g.create_task(...)
# 1 / 0
#
self._abort()
# We use while-loop here because "self._on_completed_fut"
# can be cancelled multiple times if our parent task
# is being cancelled repeatedly (or even once, when
# our own cancellation is already in progress)
while self._unfinished_tasks:
if self._on_completed_fut is None:
self._on_completed_fut = self._loop.create_future()
try:
await self._on_completed_fut
except exceptions.CancelledError as ex:
if not self._aborting:
# Our parent task is being cancelled:
#
# async def wrapper():
# async with TaskGroup() as g:
# g.create_task(foo)
#
# "wrapper" is being cancelled while "foo" is
# still running.
propagate_cancellation_error = ex
self._abort()
self._on_completed_fut = None
assert self._unfinished_tasks == 0
self._on_completed_fut = None # no longer needed
if self._base_error is not None:
raise self._base_error
if propagate_cancellation_error is not None:
# The wrapping task was cancelled; since we're done with
# closing all child tasks, just propagate the cancellation
# request now.
raise propagate_cancellation_error
if et is not None and et is not exceptions.CancelledError:
self._errors.append(exc)
if self._errors:
# Exceptions are heavy objects that can have object
# cycles (bad for GC); let's not keep a reference to
# a bunch of them.
errors = self._errors
self._errors = None
me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
raise me from None
def create_task(self, coro):
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and self._unfinished_tasks == 0:
raise RuntimeError(f"TaskGroup {self!r} is finished")
task = self._loop.create_task(coro)
task.add_done_callback(self._on_task_done)
self._unfinished_tasks += 1
self._tasks.add(task)
return task
# Since Python 3.8 Tasks propagate all exceptions correctly,
# except for KeyboardInterrupt and SystemExit which are
# still considered special.
def _is_base_error(self, exc: BaseException) -> bool:
assert isinstance(exc, BaseException)
return isinstance(exc, (SystemExit, KeyboardInterrupt))
def _abort(self):
self._aborting = True
for t in self._tasks:
if not t.done():
t.cancel()
def _on_task_done(self, task):
self._unfinished_tasks -= 1
assert self._unfinished_tasks >= 0
if self._on_completed_fut is not None and not self._unfinished_tasks:
if not self._on_completed_fut.done():
self._on_completed_fut.set_result(True)
if task.cancelled():
return
exc = task.exception()
if exc is None:
return
self._errors.append(exc)
if self._is_base_error(exc) and self._base_error is None:
self._base_error = exc
if self._parent_task.done():
# Not sure if this case is possible, but we want to handle
# it anyways.
self._loop.call_exception_handler({
'message': f'Task {task!r} has errored out but its parent '
f'task {self._parent_task} is already completed',
'exception': exc,
'task': task,
})
return
self._abort()
if not self._parent_task.cancelling():
# If parent task *is not* being cancelled, it means that we want
# to manually cancel it to abort whatever is being run right now
# in the TaskGroup. But we want to mark parent task as
# "not cancelled" later in __aexit__. Example situation that
# we need to handle:
#
# async def foo():
# try:
# async with TaskGroup() as g:
# g.create_task(crash_soon())
# await something # <- this needs to be canceled
# # by the TaskGroup, e.g.
# # foo() needs to be cancelled
# except Exception:
# # Ignore any exceptions raised in the TaskGroup
# pass
# await something_else # this line has to be called
# # after TaskGroup is finished.
self._parent_cancel_requested = True
self._parent_task.cancel()
_name_counter = itertools.count(1).__next__

View file

@ -105,6 +105,7 @@ def __init__(self, coro, *, loop=None, name=None):
else:
self._name = str(name)
self._cancel_requested = False
self._must_cancel = False
self._fut_waiter = None
self._coro = coro
@ -201,6 +202,9 @@ def cancel(self, msg=None):
self._log_traceback = False
if self.done():
return False
if self._cancel_requested:
return False
self._cancel_requested = True
if self._fut_waiter is not None:
if self._fut_waiter.cancel(msg=msg):
# Leave self._fut_waiter; it may be a Task that
@ -212,6 +216,16 @@ def cancel(self, msg=None):
self._cancel_message = msg
return True
def cancelling(self):
return self._cancel_requested
def uncancel(self):
if self._cancel_requested:
self._cancel_requested = False
return True
else:
return False
def __step(self, exc=None):
if self.done():
raise exceptions.InvalidStateError(
@ -634,7 +648,7 @@ def _ensure_future(coro_or_future, *, loop=None):
loop = events._get_event_loop(stacklevel=4)
try:
return loop.create_task(coro_or_future)
except RuntimeError:
except RuntimeError:
if not called_wrap_awaitable:
coro_or_future.close()
raise

View file

@ -0,0 +1,694 @@
# Adapted with permission from the EdgeDB project.
import asyncio
from asyncio import taskgroups
import unittest
# To prevent a warning "test altered the execution environment"
def tearDownModule():
asyncio.set_event_loop_policy(None)
class MyExc(Exception):
pass
class MyBaseExc(BaseException):
pass
def get_error_types(eg):
return {type(exc) for exc in eg.exceptions}
class TestTaskGroup(unittest.IsolatedAsyncioTestCase):
async def test_taskgroup_01(self):
async def foo1():
await asyncio.sleep(0.1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
self.assertEqual(t1.result(), 42)
self.assertEqual(t2.result(), 11)
async def test_taskgroup_02(self):
async def foo1():
await asyncio.sleep(0.1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
await asyncio.sleep(0.15)
t2 = g.create_task(foo2())
self.assertEqual(t1.result(), 42)
self.assertEqual(t2.result(), 11)
async def test_taskgroup_03(self):
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(0.2)
return 11
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
await asyncio.sleep(0.15)
# cancel t1 explicitly, i.e. everything should continue
# working as expected.
t1.cancel()
t2 = g.create_task(foo2())
self.assertTrue(t1.cancelled())
self.assertEqual(t2.result(), 11)
async def test_taskgroup_04(self):
NUM = 0
t2_cancel = False
t2 = None
async def foo1():
await asyncio.sleep(0.1)
1 / 0
async def foo2():
nonlocal NUM, t2_cancel
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
t2_cancel = True
raise
NUM += 1
async def runner():
nonlocal NUM, t2
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
t2 = g.create_task(foo2())
NUM += 10
with self.assertRaises(ExceptionGroup) as cm:
await asyncio.create_task(runner())
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertEqual(NUM, 0)
self.assertTrue(t2_cancel)
self.assertTrue(t2.cancelled())
async def test_taskgroup_05(self):
NUM = 0
t2_cancel = False
runner_cancel = False
async def foo1():
await asyncio.sleep(0.1)
1 / 0
async def foo2():
nonlocal NUM, t2_cancel
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
t2_cancel = True
raise
NUM += 1
async def runner():
nonlocal NUM, runner_cancel
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
g.create_task(foo1())
g.create_task(foo1())
g.create_task(foo2())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
runner_cancel = True
raise
NUM += 10
# The 3 foo1 sub tasks can be racy when the host is busy - if the
# cancellation happens in the middle, we'll see partial sub errors here
with self.assertRaises(ExceptionGroup) as cm:
await asyncio.create_task(runner())
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertEqual(NUM, 0)
self.assertTrue(t2_cancel)
self.assertTrue(runner_cancel)
async def test_taskgroup_06(self):
NUM = 0
async def foo():
nonlocal NUM
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
NUM += 1
raise
async def runner():
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 5)
async def test_taskgroup_07(self):
NUM = 0
async def foo():
nonlocal NUM
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
NUM += 1
raise
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 15)
async def test_taskgroup_08(self):
async def foo():
await asyncio.sleep(0.1)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g:
for _ in range(5):
g.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_09(self):
t1 = t2 = None
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
nonlocal t1, t2
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
await asyncio.sleep(0.1)
1 / 0
try:
await runner()
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {ZeroDivisionError})
else:
self.fail('ExceptionGroup was not raised')
self.assertTrue(t1.cancelled())
self.assertTrue(t2.cancelled())
async def test_taskgroup_10(self):
t1 = t2 = None
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
nonlocal t1, t2
async with taskgroups.TaskGroup() as g:
t1 = g.create_task(foo1())
t2 = g.create_task(foo2())
1 / 0
try:
await runner()
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {ZeroDivisionError})
else:
self.fail('ExceptionGroup was not raised')
self.assertTrue(t1.cancelled())
self.assertTrue(t2.cancelled())
async def test_taskgroup_11(self):
async def foo():
await asyncio.sleep(0.1)
1 / 0
async def runner():
async with taskgroups.TaskGroup():
async with taskgroups.TaskGroup() as g2:
for _ in range(5):
g2.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_12(self):
async def foo():
await asyncio.sleep(0.1)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g1:
g1.create_task(asyncio.sleep(10))
async with taskgroups.TaskGroup() as g2:
for _ in range(5):
g2.create_task(foo())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_13(self):
async def crash_after(t):
await asyncio.sleep(t)
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_after(0.1))
async with taskgroups.TaskGroup(name='g2') as g2:
g2.create_task(crash_after(0.2))
r = asyncio.create_task(runner())
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ValueError})
async def test_taskgroup_14(self):
async def crash_after(t):
await asyncio.sleep(t)
raise ValueError(t)
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_after(10))
async with taskgroups.TaskGroup(name='g2') as g2:
g2.create_task(crash_after(0.1))
r = asyncio.create_task(runner())
with self.assertRaises(ExceptionGroup) as cm:
await r
self.assertEqual(get_error_types(cm.exception), {ExceptionGroup})
self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ValueError})
async def test_taskgroup_15(self):
async def crash_soon():
await asyncio.sleep(0.3)
1 / 0
async def runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(0.5)
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_16(self):
async def crash_soon():
await asyncio.sleep(0.3)
1 / 0
async def nested_runner():
async with taskgroups.TaskGroup(name='g1') as g1:
g1.create_task(crash_soon())
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(0.5)
raise
async def runner():
t = asyncio.create_task(nested_runner())
await t
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_17(self):
NUM = 0
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
raise
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
self.assertEqual(NUM, 10)
async def test_taskgroup_18(self):
NUM = 0
async def runner():
nonlocal NUM
async with taskgroups.TaskGroup():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
NUM += 10
# This isn't a good idea, but we have to support
# this weird case.
raise MyExc
r = asyncio.create_task(runner())
await asyncio.sleep(0.1)
self.assertFalse(r.done())
r.cancel()
try:
await r
except ExceptionGroup as t:
self.assertEqual(get_error_types(t),{MyExc})
else:
self.fail('ExceptionGroup was not raised')
self.assertEqual(NUM, 10)
async def test_taskgroup_19(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise MyExc
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
r = asyncio.create_task(runner())
try:
await r
except ExceptionGroup as t:
self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError})
else:
self.fail('TasgGroupError was not raised')
async def test_taskgroup_20(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise KeyboardInterrupt
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(KeyboardInterrupt):
await runner()
async def test_taskgroup_20a(self):
async def crash_soon():
await asyncio.sleep(0.1)
1 / 0
async def nested():
try:
await asyncio.sleep(10)
finally:
raise MyBaseExc
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(BaseExceptionGroup) as cm:
await runner()
self.assertEqual(
get_error_types(cm.exception), {MyBaseExc, ZeroDivisionError}
)
async def _test_taskgroup_21(self):
# This test doesn't work as asyncio, currently, doesn't
# correctly propagate KeyboardInterrupt (or SystemExit) --
# those cause the event loop itself to crash.
# (Compare to the previous (passing) test -- that one raises
# a plain exception but raises KeyboardInterrupt in nested();
# this test does it the other way around.)
async def crash_soon():
await asyncio.sleep(0.1)
raise KeyboardInterrupt
async def nested():
try:
await asyncio.sleep(10)
finally:
raise TypeError
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(KeyboardInterrupt):
await runner()
async def test_taskgroup_21a(self):
async def crash_soon():
await asyncio.sleep(0.1)
raise MyBaseExc
async def nested():
try:
await asyncio.sleep(10)
finally:
raise TypeError
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(crash_soon())
await nested()
with self.assertRaises(BaseExceptionGroup) as cm:
await runner()
self.assertEqual(get_error_types(cm.exception), {MyBaseExc, TypeError})
async def test_taskgroup_22(self):
async def foo1():
await asyncio.sleep(1)
return 42
async def foo2():
await asyncio.sleep(2)
return 11
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(foo1())
g.create_task(foo2())
r = asyncio.create_task(runner())
await asyncio.sleep(0.05)
r.cancel()
with self.assertRaises(asyncio.CancelledError):
await r
async def test_taskgroup_23(self):
async def do_job(delay):
await asyncio.sleep(delay)
async with taskgroups.TaskGroup() as g:
for count in range(10):
await asyncio.sleep(0.1)
g.create_task(do_job(0.3))
if count == 5:
self.assertLess(len(g._tasks), 5)
await asyncio.sleep(1.35)
self.assertEqual(len(g._tasks), 0)
async def test_taskgroup_24(self):
async def root(g):
await asyncio.sleep(0.1)
g.create_task(coro1(0.1))
g.create_task(coro1(0.2))
async def coro1(delay):
await asyncio.sleep(delay)
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(root(g))
await runner()
async def test_taskgroup_25(self):
nhydras = 0
async def hydra(g):
nonlocal nhydras
nhydras += 1
await asyncio.sleep(0.01)
g.create_task(hydra(g))
g.create_task(hydra(g))
async def hercules():
while nhydras < 10:
await asyncio.sleep(0.015)
1 / 0
async def runner():
async with taskgroups.TaskGroup() as g:
g.create_task(hydra(g))
g.create_task(hercules())
with self.assertRaises(ExceptionGroup) as cm:
await runner()
self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError})
self.assertGreaterEqual(nhydras, 10)

View file

@ -496,6 +496,51 @@ async def run():
# This also distinguishes from the initial has_cycle=None.
self.assertEqual(has_cycle, False)
def test_cancelling(self):
loop = asyncio.new_event_loop()
async def task():
await asyncio.sleep(10)
try:
t = self.new_task(loop, task())
self.assertFalse(t.cancelling())
self.assertNotIn(" cancelling ", repr(t))
self.assertTrue(t.cancel())
self.assertTrue(t.cancelling())
self.assertIn(" cancelling ", repr(t))
self.assertFalse(t.cancel())
with self.assertRaises(asyncio.CancelledError):
loop.run_until_complete(t)
finally:
loop.close()
def test_uncancel(self):
loop = asyncio.new_event_loop()
async def task():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
asyncio.current_task().uncancel()
await asyncio.sleep(10)
try:
t = self.new_task(loop, task())
loop.run_until_complete(asyncio.sleep(0.01))
self.assertTrue(t.cancel()) # Cancel first sleep
self.assertIn(" cancelling ", repr(t))
loop.run_until_complete(asyncio.sleep(0.01))
self.assertNotIn(" cancelling ", repr(t)) # after .uncancel()
self.assertTrue(t.cancel()) # Cancel second sleep
with self.assertRaises(asyncio.CancelledError):
loop.run_until_complete(t)
finally:
loop.close()
def test_cancel(self):
def gen():

View file

@ -0,0 +1,2 @@
Add task groups to asyncio (structured concurrency, inspired by Trio's nurseries).
This also introduces a change to task cancellation, where a cancelled task can't be cancelled again until it calls .uncancel().

View file

@ -91,6 +91,7 @@ typedef struct {
PyObject *task_context;
int task_must_cancel;
int task_log_destroy_pending;
int task_cancel_requested;
} TaskObj;
typedef struct {
@ -2039,6 +2040,7 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
Py_CLEAR(self->task_fut_waiter);
self->task_must_cancel = 0;
self->task_log_destroy_pending = 1;
self->task_cancel_requested = 0;
Py_INCREF(coro);
Py_XSETREF(self->task_coro, coro);
@ -2205,6 +2207,11 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg)
Py_RETURN_FALSE;
}
if (self->task_cancel_requested) {
Py_RETURN_FALSE;
}
self->task_cancel_requested = 1;
if (self->task_fut_waiter) {
PyObject *res;
int is_true;
@ -2232,6 +2239,56 @@ _asyncio_Task_cancel_impl(TaskObj *self, PyObject *msg)
Py_RETURN_TRUE;
}
/*[clinic input]
_asyncio.Task.cancelling
Return True if the task is in the process of being cancelled.
This is set once .cancel() is called
and remains set until .uncancel() is called.
As long as this flag is set, further .cancel() calls will be ignored,
until .uncancel() is called to reset it.
[clinic start generated code]*/
static PyObject *
_asyncio_Task_cancelling_impl(TaskObj *self)
/*[clinic end generated code: output=803b3af96f917d7e input=c50e50f9c3ca4676]*/
/*[clinic end generated code]*/
{
if (self->task_cancel_requested) {
Py_RETURN_TRUE;
}
else {
Py_RETURN_FALSE;
}
}
/*[clinic input]
_asyncio.Task.uncancel
Reset the flag returned by cancelling().
This should be used by tasks that catch CancelledError
and wish to continue indefinitely until they are cancelled again.
Returns the previous value of the flag.
[clinic start generated code]*/
static PyObject *
_asyncio_Task_uncancel_impl(TaskObj *self)
/*[clinic end generated code: output=58184d236a817d3c input=5db95e28fcb6f7cd]*/
/*[clinic end generated code]*/
{
if (self->task_cancel_requested) {
self->task_cancel_requested = 0;
Py_RETURN_TRUE;
}
else {
Py_RETURN_FALSE;
}
}
/*[clinic input]
_asyncio.Task.get_stack
@ -2455,6 +2512,8 @@ static PyMethodDef TaskType_methods[] = {
_ASYNCIO_TASK_SET_RESULT_METHODDEF
_ASYNCIO_TASK_SET_EXCEPTION_METHODDEF
_ASYNCIO_TASK_CANCEL_METHODDEF
_ASYNCIO_TASK_CANCELLING_METHODDEF
_ASYNCIO_TASK_UNCANCEL_METHODDEF
_ASYNCIO_TASK_GET_STACK_METHODDEF
_ASYNCIO_TASK_PRINT_STACK_METHODDEF
_ASYNCIO_TASK__MAKE_CANCELLED_ERROR_METHODDEF

View file

@ -447,6 +447,53 @@ exit:
return return_value;
}
PyDoc_STRVAR(_asyncio_Task_cancelling__doc__,
"cancelling($self, /)\n"
"--\n"
"\n"
"Return True if the task is in the process of being cancelled.\n"
"\n"
"This is set once .cancel() is called\n"
"and remains set until .uncancel() is called.\n"
"\n"
"As long as this flag is set, further .cancel() calls will be ignored,\n"
"until .uncancel() is called to reset it.");
#define _ASYNCIO_TASK_CANCELLING_METHODDEF \
{"cancelling", (PyCFunction)_asyncio_Task_cancelling, METH_NOARGS, _asyncio_Task_cancelling__doc__},
static PyObject *
_asyncio_Task_cancelling_impl(TaskObj *self);
static PyObject *
_asyncio_Task_cancelling(TaskObj *self, PyObject *Py_UNUSED(ignored))
{
return _asyncio_Task_cancelling_impl(self);
}
PyDoc_STRVAR(_asyncio_Task_uncancel__doc__,
"uncancel($self, /)\n"
"--\n"
"\n"
"Reset the flag returned by cancelling().\n"
"\n"
"This should be used by tasks that catch CancelledError\n"
"and wish to continue indefinitely until they are cancelled again.\n"
"\n"
"Returns the previous value of the flag.");
#define _ASYNCIO_TASK_UNCANCEL_METHODDEF \
{"uncancel", (PyCFunction)_asyncio_Task_uncancel, METH_NOARGS, _asyncio_Task_uncancel__doc__},
static PyObject *
_asyncio_Task_uncancel_impl(TaskObj *self);
static PyObject *
_asyncio_Task_uncancel(TaskObj *self, PyObject *Py_UNUSED(ignored))
{
return _asyncio_Task_uncancel_impl(self);
}
PyDoc_STRVAR(_asyncio_Task_get_stack__doc__,
"get_stack($self, /, *, limit=None)\n"
"--\n"
@ -871,4 +918,4 @@ _asyncio__leave_task(PyObject *module, PyObject *const *args, Py_ssize_t nargs,
exit:
return return_value;
}
/*[clinic end generated code: output=0d127162ac92e0c0 input=a9049054013a1b77]*/
/*[clinic end generated code: output=c02708a9d6a774cc input=a9049054013a1b77]*/