From abb8bcb6d934d22144c40e22459dd2b62e7353b9 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 2 Oct 2016 15:07:23 -0700 Subject: [PATCH] Protect waiting for event loop from within event loop (#3658) * Protect waiting for event loop from within event loop * Faster fetching of loop attribute for ident check --- homeassistant/core.py | 2 + homeassistant/util/async.py | 13 +++++++ tests/common.py | 1 + tests/util/test_async.py | 77 +++++++++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+) diff --git a/homeassistant/core.py b/homeassistant/core.py index 43c20f18b75a..1ef0adc59614 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -199,6 +199,8 @@ class HomeAssistant(object): This method is a coroutine. """ + # pylint: disable=protected-access + self.loop._thread_ident = threading.get_ident() async_create_timer(self) async_monitor_worker_pool(self) self.bus.async_fire(EVENT_HOMEASSISTANT_START) diff --git a/homeassistant/util/async.py b/homeassistant/util/async.py index 54a3204c78df..ff498912fc2f 100644 --- a/homeassistant/util/async.py +++ b/homeassistant/util/async.py @@ -1,5 +1,6 @@ """Asyncio backports for Python 3.4.3 compatibility.""" import concurrent.futures +import threading from asyncio import coroutines from asyncio.futures import Future @@ -97,6 +98,10 @@ def run_coroutine_threadsafe(coro, loop): Return a concurrent.futures.Future to access the result. """ + ident = loop.__dict__.get("_thread_ident") + if ident is not None and ident == threading.get_ident(): + raise RuntimeError('Cannot be called from within the event loop') + if not coroutines.iscoroutine(coro): raise TypeError('A coroutine object is required') future = concurrent.futures.Future() @@ -122,6 +127,10 @@ def fire_coroutine_threadsafe(coro, loop): is intended for fire-and-forget use. This reduces the work involved to fire the function on the loop. """ + ident = loop.__dict__.get("_thread_ident") + if ident is not None and ident == threading.get_ident(): + raise RuntimeError('Cannot be called from within the event loop') + if not coroutines.iscoroutine(coro): raise TypeError('A coroutine object is required: %s' % coro) @@ -139,6 +148,10 @@ def run_callback_threadsafe(loop, callback, *args): Return a concurrent.futures.Future to access the result. """ + ident = loop.__dict__.get("_thread_ident") + if ident is not None and ident == threading.get_ident(): + raise RuntimeError('Cannot be called from within the event loop') + future = concurrent.futures.Future() def run_callback(): diff --git a/tests/common.py b/tests/common.py index ceb9bf3c058c..b44cbee4b6fe 100644 --- a/tests/common.py +++ b/tests/common.py @@ -58,6 +58,7 @@ def get_test_home_assistant(num_threads=None): stop_event = threading.Event() def run_loop(): + loop._thread_ident = threading.get_ident() loop.run_forever() loop.close() stop_event.set() diff --git a/tests/util/test_async.py b/tests/util/test_async.py index 079097f33268..f88887e3c6e3 100644 --- a/tests/util/test_async.py +++ b/tests/util/test_async.py @@ -1,10 +1,87 @@ """Tests for async util methods from Python source.""" import asyncio from asyncio import test_utils +from unittest.mock import MagicMock, patch + +import pytest from homeassistant.util import async as hasync +@patch('asyncio.coroutines.iscoroutine', return_value=True) +@patch('concurrent.futures.Future') +@patch('threading.get_ident') +def test_run_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __): + """Testing calling run_coroutine_threadsafe from inside an event loop.""" + coro = MagicMock() + loop = MagicMock() + + loop._thread_ident = None + mock_ident.return_value = 5 + hasync.run_coroutine_threadsafe(coro, loop) + assert len(loop.call_soon_threadsafe.mock_calls) == 1 + + loop._thread_ident = 5 + mock_ident.return_value = 5 + with pytest.raises(RuntimeError): + hasync.run_coroutine_threadsafe(coro, loop) + assert len(loop.call_soon_threadsafe.mock_calls) == 1 + + loop._thread_ident = 1 + mock_ident.return_value = 5 + hasync.run_coroutine_threadsafe(coro, loop) + assert len(loop.call_soon_threadsafe.mock_calls) == 2 + + +@patch('asyncio.coroutines.iscoroutine', return_value=True) +@patch('concurrent.futures.Future') +@patch('threading.get_ident') +def test_fire_coroutine_threadsafe_from_inside_event_loop(mock_ident, _, __): + """Testing calling fire_coroutine_threadsafe from inside an event loop.""" + coro = MagicMock() + loop = MagicMock() + + loop._thread_ident = None + mock_ident.return_value = 5 + hasync.fire_coroutine_threadsafe(coro, loop) + assert len(loop.call_soon_threadsafe.mock_calls) == 1 + + loop._thread_ident = 5 + mock_ident.return_value = 5 + with pytest.raises(RuntimeError): + hasync.fire_coroutine_threadsafe(coro, loop) + assert len(loop.call_soon_threadsafe.mock_calls) == 1 + + loop._thread_ident = 1 + mock_ident.return_value = 5 + hasync.fire_coroutine_threadsafe(coro, loop) + assert len(loop.call_soon_threadsafe.mock_calls) == 2 + + +@patch('concurrent.futures.Future') +@patch('threading.get_ident') +def test_run_callback_threadsafe_from_inside_event_loop(mock_ident, _): + """Testing calling run_callback_threadsafe from inside an event loop.""" + callback = MagicMock() + loop = MagicMock() + + loop._thread_ident = None + mock_ident.return_value = 5 + hasync.run_callback_threadsafe(loop, callback) + assert len(loop.call_soon_threadsafe.mock_calls) == 1 + + loop._thread_ident = 5 + mock_ident.return_value = 5 + with pytest.raises(RuntimeError): + hasync.run_callback_threadsafe(loop, callback) + assert len(loop.call_soon_threadsafe.mock_calls) == 1 + + loop._thread_ident = 1 + mock_ident.return_value = 5 + hasync.run_callback_threadsafe(loop, callback) + assert len(loop.call_soon_threadsafe.mock_calls) == 2 + + class RunCoroutineThreadsafeTests(test_utils.TestCase): """Test case for asyncio.run_coroutine_threadsafe."""