Automatically clean up executor as part of closing loop (#43284)

This commit is contained in:
Paulus Schoutsen 2020-11-16 15:43:48 +01:00 committed by GitHub
parent 5d83f0a911
commit 819dd27925
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 166 deletions

View file

@ -15,11 +15,7 @@ import yarl
from homeassistant import config as conf_util, config_entries, core, loader
from homeassistant.components import http
from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
REQUIRED_NEXT_PYTHON_DATE,
REQUIRED_NEXT_PYTHON_VER,
)
from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import (
@ -142,11 +138,9 @@ async def async_setup_hass(
_LOGGER.warning("Detected that frontend did not load. Activating safe mode")
# Ask integrations to shut down. It's messy but we can't
# do a clean stop without knowing what is broken
hass.async_track_tasks()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP, {})
with contextlib.suppress(asyncio.TimeoutError):
async with hass.timeout.async_timeout(10):
await hass.async_block_till_done()
await hass.async_stop()
safe_mode = True
old_config = hass.config

View file

@ -257,12 +257,9 @@ class HomeAssistant:
fire_coroutine_threadsafe(self.async_start(), self.loop)
# Run forever
try:
# Block until stopped
_LOGGER.info("Starting Home Assistant core loop")
self.loop.run_forever()
finally:
self.loop.close()
# Block until stopped
_LOGGER.info("Starting Home Assistant core loop")
self.loop.run_forever()
return self.exit_code
async def async_run(self, *, attach_signals: bool = True) -> int:
@ -559,16 +556,11 @@ class HomeAssistant:
"Timed out waiting for shutdown stage 3 to complete, the shutdown will continue"
)
# Python 3.9+ and backported in runner.py
await self.loop.shutdown_default_executor() # type: ignore
self.exit_code = exit_code
self.state = CoreState.stopped
if self._stopped is not None:
self._stopped.set()
else:
self.loop.stop()
@attr.s(slots=True, frozen=True)

View file

@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
import dataclasses
import logging
import sys
import threading
from typing import Any, Dict, Optional
from homeassistant import bootstrap
@ -77,29 +76,14 @@ class HassEventLoopPolicy(PolicyBase): # type: ignore
loop.set_default_executor, "sets default executor on the event loop"
)
# Python 3.9+
if hasattr(loop, "shutdown_default_executor"):
return loop
# Shut down executor when we shut down loop
orig_close = loop.close
# Copied from Python 3.9 source
def _do_shutdown(future: asyncio.Future) -> None:
try:
executor.shutdown(wait=True)
loop.call_soon_threadsafe(future.set_result, None)
except Exception as ex: # pylint: disable=broad-except
loop.call_soon_threadsafe(future.set_exception, ex)
def close() -> None:
executor.shutdown(wait=True)
orig_close()
async def shutdown_default_executor() -> None:
"""Schedule the shutdown of the default executor."""
future = loop.create_future()
thread = threading.Thread(target=_do_shutdown, args=(future,))
thread.start()
try:
await future
finally:
thread.join()
setattr(loop, "shutdown_default_executor", shutdown_default_executor)
loop.close = close # type: ignore
return loop

View file

@ -9,7 +9,6 @@ from io import StringIO
import json
import logging
import os
import sys
import threading
import time
import uuid
@ -109,24 +108,21 @@ def get_test_config_dir(*add_path):
def get_test_home_assistant():
"""Return a Home Assistant object pointing at test config directory."""
if sys.platform == "win32":
loop = asyncio.ProactorEventLoop()
else:
loop = asyncio.new_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
hass = loop.run_until_complete(async_test_home_assistant(loop))
stop_event = threading.Event()
loop_stop_event = threading.Event()
def run_loop():
"""Run event loop."""
# pylint: disable=protected-access
loop._thread_ident = threading.get_ident()
loop.run_forever()
stop_event.set()
loop_stop_event.set()
orig_stop = hass.stop
hass._stopped = Mock(set=loop.stop)
def start_hass(*mocks):
"""Start hass."""
@ -135,7 +131,7 @@ def get_test_home_assistant():
def stop_hass():
"""Stop hass."""
orig_stop()
stop_event.wait()
loop_stop_event.wait()
loop.close()
hass.start = start_hass

View file

@ -38,7 +38,11 @@ import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM
from tests.async_mock import MagicMock, Mock, PropertyMock, patch
from tests.common import async_mock_service, get_test_home_assistant
from tests.common import (
async_capture_events,
async_mock_service,
get_test_home_assistant,
)
PST = pytz.timezone("America/Los_Angeles")
@ -151,22 +155,14 @@ def test_async_run_hass_job_delegates_non_async():
assert len(hass.async_add_hass_job.mock_calls) == 1
def test_stage_shutdown():
async def test_stage_shutdown(hass):
"""Simulate a shutdown, test calling stuff."""
hass = get_test_home_assistant()
test_stop = []
test_final_write = []
test_close = []
test_all = []
test_stop = async_capture_events(hass, EVENT_HOMEASSISTANT_STOP)
test_final_write = async_capture_events(hass, EVENT_HOMEASSISTANT_FINAL_WRITE)
test_close = async_capture_events(hass, EVENT_HOMEASSISTANT_CLOSE)
test_all = async_capture_events(hass, MATCH_ALL)
hass.bus.listen(EVENT_HOMEASSISTANT_STOP, lambda event: test_stop.append(event))
hass.bus.listen(
EVENT_HOMEASSISTANT_FINAL_WRITE, lambda event: test_final_write.append(event)
)
hass.bus.listen(EVENT_HOMEASSISTANT_CLOSE, lambda event: test_close.append(event))
hass.bus.listen("*", lambda event: test_all.append(event))
hass.stop()
await hass.async_stop()
assert len(test_stop) == 1
assert len(test_close) == 1
@ -341,147 +337,139 @@ def test_state_as_dict():
assert state.as_dict() is state.as_dict()
class TestEventBus(unittest.TestCase):
"""Test EventBus methods."""
async def test_add_remove_listener(hass):
"""Test remove_listener method."""
old_count = len(hass.bus.async_listeners())
# pylint: disable=invalid-name
def setUp(self):
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
self.bus = self.hass.bus
def listener(_):
pass
# pylint: disable=invalid-name
def tearDown(self):
"""Stop down stuff we started."""
self.hass.stop()
unsub = hass.bus.async_listen("test", listener)
def test_add_remove_listener(self):
"""Test remove_listener method."""
self.hass.allow_pool = False
old_count = len(self.bus.listeners)
assert old_count + 1 == len(hass.bus.async_listeners())
def listener(_):
pass
# Remove listener
unsub()
assert old_count == len(hass.bus.async_listeners())
unsub = self.bus.listen("test", listener)
# Should do nothing now
unsub()
assert old_count + 1 == len(self.bus.listeners)
# Remove listener
unsub()
assert old_count == len(self.bus.listeners)
async def test_unsubscribe_listener(hass):
"""Test unsubscribe listener from returned function."""
calls = []
# Should do nothing now
unsub()
@ha.callback
def listener(event):
"""Mock listener."""
calls.append(event)
def test_unsubscribe_listener(self):
"""Test unsubscribe listener from returned function."""
calls = []
unsub = hass.bus.async_listen("test", listener)
@ha.callback
def listener(event):
"""Mock listener."""
calls.append(event)
hass.bus.async_fire("test")
await hass.async_block_till_done()
unsub = self.bus.listen("test", listener)
assert len(calls) == 1
self.bus.fire("test")
self.hass.block_till_done()
unsub()
assert len(calls) == 1
hass.bus.async_fire("event")
await hass.async_block_till_done()
unsub()
assert len(calls) == 1
self.bus.fire("event")
self.hass.block_till_done()
assert len(calls) == 1
async def test_listen_once_event_with_callback(hass):
"""Test listen_once_event method."""
runs = []
def test_listen_once_event_with_callback(self):
"""Test listen_once_event method."""
runs = []
@ha.callback
def event_handler(event):
runs.append(event)
@ha.callback
def event_handler(event):
runs.append(event)
hass.bus.async_listen_once("test_event", event_handler)
self.bus.listen_once("test_event", event_handler)
hass.bus.async_fire("test_event")
# Second time it should not increase runs
hass.bus.async_fire("test_event")
self.bus.fire("test_event")
# Second time it should not increase runs
self.bus.fire("test_event")
await hass.async_block_till_done()
assert len(runs) == 1
self.hass.block_till_done()
assert len(runs) == 1
def test_listen_once_event_with_coroutine(self):
"""Test listen_once_event method."""
runs = []
async def test_listen_once_event_with_coroutine(hass):
"""Test listen_once_event method."""
runs = []
async def event_handler(event):
runs.append(event)
async def event_handler(event):
runs.append(event)
self.bus.listen_once("test_event", event_handler)
hass.bus.async_listen_once("test_event", event_handler)
self.bus.fire("test_event")
# Second time it should not increase runs
self.bus.fire("test_event")
hass.bus.async_fire("test_event")
# Second time it should not increase runs
hass.bus.async_fire("test_event")
self.hass.block_till_done()
assert len(runs) == 1
await hass.async_block_till_done()
assert len(runs) == 1
def test_listen_once_event_with_thread(self):
"""Test listen_once_event method."""
runs = []
def event_handler(event):
runs.append(event)
async def test_listen_once_event_with_thread(hass):
"""Test listen_once_event method."""
runs = []
self.bus.listen_once("test_event", event_handler)
def event_handler(event):
runs.append(event)
self.bus.fire("test_event")
# Second time it should not increase runs
self.bus.fire("test_event")
hass.bus.async_listen_once("test_event", event_handler)
self.hass.block_till_done()
assert len(runs) == 1
hass.bus.async_fire("test_event")
# Second time it should not increase runs
hass.bus.async_fire("test_event")
def test_thread_event_listener(self):
"""Test thread event listener."""
thread_calls = []
await hass.async_block_till_done()
assert len(runs) == 1
def thread_listener(event):
thread_calls.append(event)
self.bus.listen("test_thread", thread_listener)
self.bus.fire("test_thread")
self.hass.block_till_done()
assert len(thread_calls) == 1
async def test_thread_event_listener(hass):
"""Test thread event listener."""
thread_calls = []
def test_callback_event_listener(self):
"""Test callback event listener."""
callback_calls = []
def thread_listener(event):
thread_calls.append(event)
@ha.callback
def callback_listener(event):
callback_calls.append(event)
hass.bus.async_listen("test_thread", thread_listener)
hass.bus.async_fire("test_thread")
await hass.async_block_till_done()
assert len(thread_calls) == 1
self.bus.listen("test_callback", callback_listener)
self.bus.fire("test_callback")
self.hass.block_till_done()
assert len(callback_calls) == 1
def test_coroutine_event_listener(self):
"""Test coroutine event listener."""
coroutine_calls = []
async def test_callback_event_listener(hass):
"""Test callback event listener."""
callback_calls = []
async def coroutine_listener(event):
coroutine_calls.append(event)
@ha.callback
def callback_listener(event):
callback_calls.append(event)
self.bus.listen("test_coroutine", coroutine_listener)
self.bus.fire("test_coroutine")
self.hass.block_till_done()
assert len(coroutine_calls) == 1
hass.bus.async_listen("test_callback", callback_listener)
hass.bus.async_fire("test_callback")
await hass.async_block_till_done()
assert len(callback_calls) == 1
async def test_coroutine_event_listener(hass):
"""Test coroutine event listener."""
coroutine_calls = []
async def coroutine_listener(event):
coroutine_calls.append(event)
hass.bus.async_listen("test_coroutine", coroutine_listener)
hass.bus.async_fire("test_coroutine")
await hass.async_block_till_done()
assert len(coroutine_calls) == 1
def test_state_init():