Protect waiting for event loop from within event loop (#3658)

* Protect waiting for event loop from within event loop

* Faster fetching of loop attribute for ident check
This commit is contained in:
Paulus Schoutsen 2016-10-02 15:07:23 -07:00 committed by GitHub
parent e455daa61d
commit abb8bcb6d9
4 changed files with 93 additions and 0 deletions

View file

@ -199,6 +199,8 @@ class HomeAssistant(object):
This method is a coroutine.
"""
# pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident()
async_create_timer(self)
async_monitor_worker_pool(self)
self.bus.async_fire(EVENT_HOMEASSISTANT_START)

View file

@ -1,5 +1,6 @@
"""Asyncio backports for Python 3.4.3 compatibility."""
import concurrent.futures
import threading
from asyncio import coroutines
from asyncio.futures import Future
@ -97,6 +98,10 @@ def run_coroutine_threadsafe(coro, loop):
Return a concurrent.futures.Future to access the result.
"""
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
if not coroutines.iscoroutine(coro):
raise TypeError('A coroutine object is required')
future = concurrent.futures.Future()
@ -122,6 +127,10 @@ def fire_coroutine_threadsafe(coro, loop):
is intended for fire-and-forget use. This reduces the
work involved to fire the function on the loop.
"""
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
if not coroutines.iscoroutine(coro):
raise TypeError('A coroutine object is required: %s' % coro)
@ -139,6 +148,10 @@ def run_callback_threadsafe(loop, callback, *args):
Return a concurrent.futures.Future to access the result.
"""
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
future = concurrent.futures.Future()
def run_callback():

View file

@ -58,6 +58,7 @@ def get_test_home_assistant(num_threads=None):
stop_event = threading.Event()
def run_loop():
loop._thread_ident = threading.get_ident()
loop.run_forever()
loop.close()
stop_event.set()

View file

@ -1,10 +1,87 @@
"""Tests for async util methods from Python source."""
import asyncio
from asyncio import test_utils
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.util import async as hasync
@patch('asyncio.coroutines.iscoroutine', return_value=True)
@patch('concurrent.futures.Future')
@patch('threading.get_ident')
def test_run_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __):
"""Testing calling run_coroutine_threadsafe from inside an event loop."""
coro = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
hasync.run_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
hasync.run_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
hasync.run_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
@patch('asyncio.coroutines.iscoroutine', return_value=True)
@patch('concurrent.futures.Future')
@patch('threading.get_ident')
def test_fire_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __):
"""Testing calling fire_coroutine_threadsafe from inside an event loop."""
coro = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
@patch('concurrent.futures.Future')
@patch('threading.get_ident')
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _):
"""Testing calling run_callback_threadsafe from inside an event loop."""
callback = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
hasync.run_callback_threadsafe(loop, callback)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
hasync.run_callback_threadsafe(loop, callback)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
hasync.run_callback_threadsafe(loop, callback)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
class RunCoroutineThreadsafeTests(test_utils.TestCase):
"""Test case for asyncio.run_coroutine_threadsafe."""