Add thread safety checks to async_create_task (#116339)

* Add thread safety checks to async_create_task

Calling async_create_task from a thread almost always results in an
fast crash. Since most internals are using async_create_background_task
or other task APIs, and this is the one integrations seem to get wrong
the most, add a thread safety check here

* Add thread safety checks to async_create_task

Calling async_create_task from a thread almost always results in an
fast crash. Since most internals are using async_create_background_task
or other task APIs, and this is the one integrations seem to get wrong
the most, add a thread safety check here

* missed one

* Update homeassistant/core.py

* fix mocks

* one more internal

* more places where internal can be used

* more places where internal can be used

* more places where internal can be used

* internal one more place since this is high volume and was already eager_start
This commit is contained in:
J. Nick Koston 2024-04-28 17:29:00 -05:00 committed by GitHub
parent b8ddf51e28
commit 164403de20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 70 additions and 23 deletions

View file

@ -735,7 +735,7 @@ async def async_setup_multi_components(
# to wait to be imported, and the sooner we can get the base platforms
# loaded the sooner we can start loading the rest of the integrations.
futures = {
domain: hass.async_create_task(
domain: hass.async_create_task_internal(
async_setup_component(hass, domain, config),
f"setup component {domain}",
eager_start=True,

View file

@ -1087,7 +1087,7 @@ class ConfigEntry:
target: target to call.
"""
task = hass.async_create_task(
task = hass.async_create_task_internal(
target, f"{name} {self.title} {self.domain} {self.entry_id}", eager_start
)
if eager_start and task.done():
@ -1643,7 +1643,7 @@ class ConfigEntries:
# starting a new flow with the 'unignore' step. If the integration doesn't
# implement async_step_unignore then this will be a no-op.
if entry.source == SOURCE_IGNORE:
self.hass.async_create_task(
self.hass.async_create_task_internal(
self.hass.config_entries.flow.async_init(
entry.domain,
context={"source": SOURCE_UNIGNORE},

View file

@ -773,7 +773,9 @@ class HomeAssistant:
target: target to call.
"""
self.loop.call_soon_threadsafe(
functools.partial(self.async_create_task, target, name, eager_start=True)
functools.partial(
self.async_create_task_internal, target, name, eager_start=True
)
)
@callback
@ -788,6 +790,37 @@ class HomeAssistant:
This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.
target: target to call.
"""
# We turned on asyncio debug in April 2024 in the dev containers
# in the hope of catching some of the issues that have been
# reported. It will take a while to get all the issues fixed in
# custom components.
#
# In 2025.5 we should guard the `verify_event_loop_thread`
# check with a check for the `hass.config.debug` flag being set as
# long term we don't want to be checking this in production
# environments since it is a performance hit.
self.verify_event_loop_thread("async_create_task")
return self.async_create_task_internal(target, name, eager_start)
@callback
def async_create_task_internal(
self,
target: Coroutine[Any, Any, _R],
name: str | None = None,
eager_start: bool = True,
) -> asyncio.Task[_R]:
"""Create a task from within the event loop, internal use only.
This method is intended to only be used by core internally
and should not be considered a stable API. We will make
breaking change to this function in the future and it
should not be used in integrations.
This method must be run in the event loop. If you are using this in your
integration, use the create task methods on the config entry instead.
target: target to call.
"""
if eager_start:
@ -2683,7 +2716,7 @@ class ServiceRegistry:
coro = self._execute_service(handler, service_call)
if not blocking:
self._hass.async_create_task(
self._hass.async_create_task_internal(
self._run_service_call_catch_exceptions(coro, service_call),
f"service call background {service_call.domain}.{service_call.service}",
eager_start=True,

View file

@ -1497,7 +1497,7 @@ class Entity(
is_remove = action == "remove"
self._removed_from_registry = is_remove
if action == "update" or is_remove:
self.hass.async_create_task(
self.hass.async_create_task_internal(
self._async_process_registry_update_or_remove(event), eager_start=True
)

View file

@ -146,7 +146,7 @@ class EntityComponent(Generic[_EntityT]):
# Look in config for Domain, Domain 2, Domain 3 etc and load them
for p_type, p_config in conf_util.config_per_platform(config, self.domain):
if p_type is not None:
self.hass.async_create_task(
self.hass.async_create_task_internal(
self.async_setup_platform(p_type, p_config),
f"EntityComponent setup platform {p_type} {self.domain}",
eager_start=True,

View file

@ -477,7 +477,7 @@ class EntityPlatform:
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
"""Schedule adding entities for a single platform async."""
task = self.hass.async_create_task(
task = self.hass.async_create_task_internal(
self.async_add_entities(new_entities, update_before_add=update_before_add),
f"EntityPlatform async_add_entities {self.domain}.{self.platform_name}",
eager_start=True,

View file

@ -85,7 +85,7 @@ def _async_integration_platform_component_loaded(
# At least one of the platforms is not loaded, we need to load them
# so we have to fall back to creating a task.
hass.async_create_task(
hass.async_create_task_internal(
_async_process_integration_platforms_for_component(
hass, integration, platforms_that_exist, integration_platforms_by_name
),
@ -206,7 +206,7 @@ async def async_process_integration_platforms(
# We use hass.async_create_task instead of asyncio.create_task because
# we want to make sure that startup waits for the task to complete.
#
future = hass.async_create_task(
future = hass.async_create_task_internal(
_async_process_integration_platforms(
hass, platform_name, top_level_components.copy(), process_job
),

View file

@ -659,7 +659,7 @@ class DynamicServiceIntentHandler(IntentHandler):
)
await self._run_then_background(
hass.async_create_task(
hass.async_create_task_internal(
hass.services.async_call(
domain,
service,

View file

@ -236,7 +236,9 @@ class RestoreStateData:
# Dump the initial states now. This helps minimize the risk of having
# old states loaded by overwriting the last states once Home Assistant
# has started and the old states have been read.
self.hass.async_create_task(_async_dump_states(), "RestoreStateData dump")
self.hass.async_create_task_internal(
_async_dump_states(), "RestoreStateData dump"
)
# Dump states periodically
cancel_interval = async_track_time_interval(

View file

@ -734,7 +734,7 @@ class _ScriptRun:
)
trace_set_result(params=params, running_script=running_script)
response_data = await self._async_run_long_action(
self._hass.async_create_task(
self._hass.async_create_task_internal(
self._hass.services.async_call(
**params,
blocking=True,
@ -1208,7 +1208,7 @@ class _ScriptRun:
async def _async_run_script(self, script: Script) -> None:
"""Execute a script."""
result = await self._async_run_long_action(
self._hass.async_create_task(
self._hass.async_create_task_internal(
script.async_run(self._variables, self._context), eager_start=True
)
)

View file

@ -468,7 +468,7 @@ class Store(Generic[_T]):
# wrote. Reschedule the timer to the next write time.
self._async_reschedule_delayed_write(self._next_write_time)
return
self.hass.async_create_task(
self.hass.async_create_task_internal(
self._async_callback_delayed_write(), eager_start=True
)

View file

@ -600,7 +600,7 @@ def _async_when_setup(
_LOGGER.exception("Error handling when_setup callback for %s", component)
if component in hass.config.components:
hass.async_create_task(
hass.async_create_task_internal(
when_setup(), f"when setup {component}", eager_start=True
)
return

View file

@ -234,7 +234,7 @@ async def async_test_home_assistant(
orig_async_add_job = hass.async_add_job
orig_async_add_executor_job = hass.async_add_executor_job
orig_async_create_task = hass.async_create_task
orig_async_create_task_internal = hass.async_create_task_internal
orig_tz = dt_util.DEFAULT_TIME_ZONE
def async_add_job(target, *args, eager_start: bool = False):
@ -263,18 +263,18 @@ async def async_test_home_assistant(
return orig_async_add_executor_job(target, *args)
def async_create_task(coroutine, name=None, eager_start=True):
def async_create_task_internal(coroutine, name=None, eager_start=True):
"""Create task."""
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
fut = asyncio.Future()
fut.set_result(None)
return fut
return orig_async_create_task(coroutine, name, eager_start)
return orig_async_create_task_internal(coroutine, name, eager_start)
hass.async_add_job = async_add_job
hass.async_add_executor_job = async_add_executor_job
hass.async_create_task = async_create_task
hass.async_create_task_internal = async_create_task_internal
hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

View file

@ -319,7 +319,7 @@ async def test_async_create_task_schedule_coroutine() -> None:
async def job():
pass
ha.HomeAssistant.async_create_task(hass, job(), eager_start=False)
ha.HomeAssistant.async_create_task_internal(hass, job(), eager_start=False)
assert len(hass.loop.call_soon.mock_calls) == 0
assert len(hass.loop.create_task.mock_calls) == 1
assert len(hass.add_job.mock_calls) == 0
@ -332,7 +332,7 @@ async def test_async_create_task_eager_start_schedule_coroutine() -> None:
async def job():
pass
ha.HomeAssistant.async_create_task(hass, job(), eager_start=True)
ha.HomeAssistant.async_create_task_internal(hass, job(), eager_start=True)
# Should create the task directly since 3.12 supports eager_start
assert len(hass.loop.create_task.mock_calls) == 0
assert len(hass.add_job.mock_calls) == 0
@ -345,7 +345,7 @@ async def test_async_create_task_schedule_coroutine_with_name() -> None:
async def job():
pass
task = ha.HomeAssistant.async_create_task(
task = ha.HomeAssistant.async_create_task_internal(
hass, job(), "named task", eager_start=False
)
assert len(hass.loop.call_soon.mock_calls) == 0
@ -3470,3 +3470,15 @@ async def test_async_remove_thread_safety(hass: HomeAssistant) -> None:
await hass.async_add_executor_job(
hass.services.async_remove, "test_domain", "test_service"
)
async def test_async_create_task_thread_safety(hass: HomeAssistant) -> None:
"""Test async_create_task thread safety."""
async def _any_coro():
pass
with pytest.raises(
RuntimeError, match="Detected code that calls async_create_task from a thread."
):
await hass.async_add_executor_job(hass.async_create_task, _any_coro)