GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705)

This commit is contained in:
Kumar Aditya 2022-08-30 00:01:11 +05:30 committed by GitHub
parent 3d180e3ab2
commit e5b2453e61
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 19 deletions

View file

@ -2,6 +2,7 @@
'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server')
import collections
import socket
import sys
import weakref
@ -128,7 +129,7 @@ def __init__(self, loop=None):
else:
self._loop = loop
self._paused = False
self._drain_waiter = None
self._drain_waiters = collections.deque()
self._connection_lost = False
def pause_writing(self):
@ -143,38 +144,34 @@ def resume_writing(self):
if self._loop.get_debug():
logger.debug("%r resumes writing", self)
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
for waiter in self._drain_waiters:
if not waiter.done():
waiter.set_result(None)
def connection_lost(self, exc):
self._connection_lost = True
# Wake up the writer if currently paused.
# Wake up the writer(s) if currently paused.
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
for waiter in self._drain_waiters:
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter
self._drain_waiters.append(waiter)
try:
await waiter
finally:
self._drain_waiters.remove(waiter)
def _get_close_waiter(self, stream):
raise NotImplementedError

View file

@ -864,6 +864,25 @@ def test_streamreaderprotocol_constructor_use_global_loop(self):
self.assertEqual(cm.filename, __file__)
self.assertIs(protocol._loop, self.loop)
def test_multiple_drain(self):
# See https://github.com/python/cpython/issues/74116
drained = 0
async def drainer(stream):
nonlocal drained
await stream._drain_helper()
drained += 1
async def main():
loop = asyncio.get_running_loop()
stream = asyncio.streams.FlowControlMixin(loop)
stream.pause_writing()
loop.call_later(0.1, stream.resume_writing)
await asyncio.gather(*[drainer(stream) for _ in range(10)])
self.assertEqual(drained, 10)
self.loop.run_until_complete(main())
def test_drain_raises(self):
# See http://bugs.python.org/issue25441

View file

@ -0,0 +1 @@
Allow :meth:`asyncio.StreamWriter.drain` to be awaited concurrently by multiple tasks. Patch by Kumar Aditya.