Protect waiting for event loop from within event loop (#3658)

* Protect waiting for event loop from within event loop

* Faster fetching of loop attribute for ident check
This commit is contained in:
Paulus Schoutsen 2016-10-02 15:07:23 -07:00 committed by GitHub
parent e455daa61d
commit abb8bcb6d9
4 changed files with 93 additions and 0 deletions

View file

@ -199,6 +199,8 @@ class HomeAssistant(object):
This method is a coroutine. This method is a coroutine.
""" """
# pylint: disable=protected-access
self.loop._thread_ident = threading.get_ident()
async_create_timer(self) async_create_timer(self)
async_monitor_worker_pool(self) async_monitor_worker_pool(self)
self.bus.async_fire(EVENT_HOMEASSISTANT_START) self.bus.async_fire(EVENT_HOMEASSISTANT_START)

View file

@ -1,5 +1,6 @@
"""Asyncio backports for Python 3.4.3 compatibility.""" """Asyncio backports for Python 3.4.3 compatibility."""
import concurrent.futures import concurrent.futures
import threading
from asyncio import coroutines from asyncio import coroutines
from asyncio.futures import Future from asyncio.futures import Future
@ -97,6 +98,10 @@ def run_coroutine_threadsafe(coro, loop):
Return a concurrent.futures.Future to access the result. Return a concurrent.futures.Future to access the result.
""" """
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
if not coroutines.iscoroutine(coro): if not coroutines.iscoroutine(coro):
raise TypeError('A coroutine object is required') raise TypeError('A coroutine object is required')
future = concurrent.futures.Future() future = concurrent.futures.Future()
@ -122,6 +127,10 @@ def fire_coroutine_threadsafe(coro, loop):
is intended for fire-and-forget use. This reduces the is intended for fire-and-forget use. This reduces the
work involved to fire the function on the loop. work involved to fire the function on the loop.
""" """
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
if not coroutines.iscoroutine(coro): if not coroutines.iscoroutine(coro):
raise TypeError('A coroutine object is required: %s' % coro) raise TypeError('A coroutine object is required: %s' % coro)
@ -139,6 +148,10 @@ def run_callback_threadsafe(loop, callback, *args):
Return a concurrent.futures.Future to access the result. Return a concurrent.futures.Future to access the result.
""" """
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
future = concurrent.futures.Future() future = concurrent.futures.Future()
def run_callback(): def run_callback():

View file

@ -58,6 +58,7 @@ def get_test_home_assistant(num_threads=None):
stop_event = threading.Event() stop_event = threading.Event()
def run_loop(): def run_loop():
loop._thread_ident = threading.get_ident()
loop.run_forever() loop.run_forever()
loop.close() loop.close()
stop_event.set() stop_event.set()

View file

@ -1,10 +1,87 @@
"""Tests for async util methods from Python source.""" """Tests for async util methods from Python source."""
import asyncio import asyncio
from asyncio import test_utils from asyncio import test_utils
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.util import async as hasync from homeassistant.util import async as hasync
@patch('asyncio.coroutines.iscoroutine', return_value=True)
@patch('concurrent.futures.Future')
@patch('threading.get_ident')
def test_run_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __):
"""Testing calling run_coroutine_threadsafe from inside an event loop."""
coro = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
hasync.run_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
hasync.run_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
hasync.run_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
@patch('asyncio.coroutines.iscoroutine', return_value=True)
@patch('concurrent.futures.Future')
@patch('threading.get_ident')
def test_fire_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __):
"""Testing calling fire_coroutine_threadsafe from inside an event loop."""
coro = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
hasync.fire_coroutine_threadsafe(coro, loop)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
@patch('concurrent.futures.Future')
@patch('threading.get_ident')
def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _):
"""Testing calling run_callback_threadsafe from inside an event loop."""
callback = MagicMock()
loop = MagicMock()
loop._thread_ident = None
mock_ident.return_value = 5
hasync.run_callback_threadsafe(loop, callback)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 5
mock_ident.return_value = 5
with pytest.raises(RuntimeError):
hasync.run_callback_threadsafe(loop, callback)
assert len(loop.call_soon_threadsafe.mock_calls) == 1
loop._thread_ident = 1
mock_ident.return_value = 5
hasync.run_callback_threadsafe(loop, callback)
assert len(loop.call_soon_threadsafe.mock_calls) == 2
class RunCoroutineThreadsafeTests(test_utils.TestCase): class RunCoroutineThreadsafeTests(test_utils.TestCase):
"""Test case for asyncio.run_coroutine_threadsafe.""" """Test case for asyncio.run_coroutine_threadsafe."""