Remove config entry specifics from FlowManager (#85565)

This commit is contained in:
Erik Montnemery 2023-01-17 15:26:17 +01:00 committed by GitHub
parent 0f3221eac7
commit 3cd6bd87a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 119 additions and 85 deletions

View file

@ -761,6 +761,15 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
super().__init__(hass)
self.config_entries = config_entries
self._hass_config = hass_config
self._initializing: dict[str, dict[str, asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
if not (current := self._initializing.get(handler)):
return
await asyncio.wait(current.values())
@callback
def _async_has_other_discovery_flows(self, flow_id: str) -> bool:
@ -770,12 +779,76 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
for flow in self._progress.values()
)
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
"""Start a configuration flow."""
if context is None:
context = {}
flow_id = uuid_util.random_uuid_hex()
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, {})[flow_id] = init_done
task = asyncio.create_task(self._async_init(flow_id, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task)
try:
flow, result = await task
finally:
self._initialize_tasks[handler].remove(task)
self._initializing[handler].pop(flow_id)
if result["type"] != data_entry_flow.FlowResultType.ABORT:
await self.async_post_init(flow, result)
return result
async def _async_init(
self,
flow_id: str,
handler: str,
context: dict,
data: Any,
) -> tuple[data_entry_flow.FlowHandler, FlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise data_entry_flow.UnknownFlow("Flow was not created")
flow.hass = self.hass
flow.handler = handler
flow.flow_id = flow_id
flow.context = context
flow.init_data = data
self._async_add_flow_progress(flow)
try:
result = await self._async_handle_step(flow, flow.init_step, data)
finally:
init_done = self._initializing[handler][flow_id]
if not init_done.done():
init_done.set_result(None)
return flow, result
async def async_shutdown(self) -> None:
"""Cancel any initializing flows."""
for task_list in self._initialize_tasks.values():
for task in task_list:
task.cancel()
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow)
# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
init_done = self._initializing[flow.handler].get(flow.flow_id)
if init_done and not init_done.done():
init_done.set_result(None)
# Remove notification if no other discovery config entries in progress
if not self._async_has_other_discovery_flows(flow.flow_id):
persistent_notification.async_dismiss(self.hass, DISCOVERY_NOTIFICATION_ID)

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import abc
import asyncio
from collections.abc import Iterable, Mapping
import copy
from dataclasses import dataclass
@ -55,7 +54,7 @@ class BaseServiceInfo:
class FlowError(HomeAssistantError):
"""Error while configuring an account."""
"""Base class for data entry errors."""
class UnknownHandler(FlowError):
@ -137,18 +136,9 @@ class FlowManager(abc.ABC):
) -> None:
"""Initialize the flow manager."""
self.hass = hass
self._initializing: dict[str, list[asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
if not (current := self._initializing.get(handler)):
return
await asyncio.wait(current)
@abc.abstractmethod
async def async_create_flow(
self,
@ -166,7 +156,7 @@ class FlowManager(abc.ABC):
async def async_finish_flow(
self, flow: FlowHandler, result: FlowResult
) -> FlowResult:
"""Finish a config flow and add an entry."""
"""Finish a data entry flow."""
async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None:
"""Entry has finished executing its first step asynchronously."""
@ -219,35 +209,9 @@ class FlowManager(abc.ABC):
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
"""Start a configuration flow."""
"""Start a data entry flow."""
if context is None:
context = {}
init_done: asyncio.Future = asyncio.Future()
self._initializing.setdefault(handler, []).append(init_done)
task = asyncio.create_task(self._async_init(init_done, handler, context, data))
self._initialize_tasks.setdefault(handler, []).append(task)
try:
flow, result = await task
finally:
self._initialize_tasks[handler].remove(task)
self._initializing[handler].remove(init_done)
if result["type"] != FlowResultType.ABORT:
await self.async_post_init(flow, result)
return result
async def _async_init(
self,
init_done: asyncio.Future,
handler: str,
context: dict,
data: Any,
) -> tuple[FlowHandler, FlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
raise UnknownFlow("Flow was not created")
@ -257,19 +221,18 @@ class FlowManager(abc.ABC):
flow.context = context
flow.init_data = data
self._async_add_flow_progress(flow)
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result
async def async_shutdown(self) -> None:
"""Cancel any initializing flows."""
for task_list in self._initialize_tasks.values():
for task in task_list:
task.cancel()
result = await self._async_handle_step(flow, flow.init_step, data)
if result["type"] != FlowResultType.ABORT:
await self.async_post_init(flow, result)
return result
async def async_configure(
self, flow_id: str, user_input: dict | None = None
) -> FlowResult:
"""Continue a configuration flow."""
"""Continue a data entry flow."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
@ -354,22 +317,16 @@ class FlowManager(abc.ABC):
try:
flow.async_remove()
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception("Error removing %s config flow: %s", flow.handler, err)
_LOGGER.exception("Error removing %s flow: %s", flow.handler, err)
async def _async_handle_step(
self,
flow: FlowHandler,
step_id: str,
user_input: dict | BaseServiceInfo | None,
step_done: asyncio.Future | None = None,
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
) -> FlowResult:
"""Handle a step of a flow."""
method = f"async_step_{step_id}"
if not hasattr(flow, method):
self._async_remove_flow_progress(flow.flow_id)
if step_done:
step_done.set_result(None)
raise UnknownStep(
f"Handler {flow.__class__.__name__} doesn't support step {step_id}"
)
@ -381,13 +338,6 @@ class FlowManager(abc.ABC):
flow.flow_id, flow.handler, err.reason, err.description_placeholders
)
# Mark the step as done.
# We do this before calling async_finish_flow because config entries will hit a
# circular dependency where async_finish_flow sets up new entry, which needs the
# integration to be set up, which is waiting for init to be done.
if step_done:
step_done.set_result(None)
if not isinstance(result["type"], FlowResultType):
result["type"] = FlowResultType(result["type"]) # type: ignore[unreachable]
report(
@ -424,7 +374,7 @@ class FlowManager(abc.ABC):
class FlowHandler:
"""Handle the configuration flow of a component."""
"""Handle a data entry flow."""
# Set by flow manager
cur_step: FlowResult | None = None
@ -519,7 +469,7 @@ class FlowHandler:
description: str | None = None,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
"""Finish config flow and create a config entry."""
"""Finish flow."""
flow_result = FlowResult(
version=self.VERSION,
type=FlowResultType.CREATE_ENTRY,
@ -541,7 +491,7 @@ class FlowHandler:
reason: str,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
"""Abort the config flow."""
"""Abort the flow."""
return _create_abort_data(
self.flow_id, self.handler, reason, description_placeholders
)
@ -626,7 +576,7 @@ class FlowHandler:
@callback
def async_remove(self) -> None:
"""Notification that the config flow has been removed."""
"""Notification that the flow has been removed."""
@callback

View file

@ -92,7 +92,9 @@ async def test_discover_config_flow(hass):
with patch.dict(
discovery.CONFIG_ENTRY_HANDLERS, {"mock-service": "mock-component"}
), patch("homeassistant.data_entry_flow.FlowManager.async_init") as m_init:
), patch(
"homeassistant.config_entries.ConfigEntriesFlowManager.async_init"
) as m_init:
await mock_discovery(hass, discover)
assert len(m_init.mock_calls) == 1

View file

@ -3537,3 +3537,29 @@ async def test_options_flow_options_not_mutated() -> None:
"sub_list": ["one", "two"],
}
assert entry.options == {"sub_dict": {"1": "one"}, "sub_list": ["one"]}
async def test_initializing_flows_canceled_on_shutdown(hass: HomeAssistant, manager):
"""Test that initializing flows are canceled on shutdown."""
class MockFlowHandler(config_entries.ConfigFlow):
"""Define a mock flow handler."""
VERSION = 1
async def async_step_reauth(self, data):
"""Mock Reauth."""
await asyncio.sleep(1)
with patch.dict(
config_entries.HANDLERS, {"comp": MockFlowHandler, "test": MockFlowHandler}
):
task = asyncio.create_task(
manager.flow.async_init("test", context={"source": "reauth"})
)
await hass.async_block_till_done()
await manager.flow.async_shutdown()
with pytest.raises(asyncio.exceptions.CancelledError):
await task

View file

@ -1,5 +1,4 @@
"""Test the flow classes."""
import asyncio
import logging
from unittest.mock import Mock, patch
@ -181,7 +180,7 @@ async def test_abort_calls_async_remove_with_exception(manager, caplog):
with caplog.at_level(logging.ERROR):
await manager.async_init("test")
assert "Error removing test config flow: error" in caplog.text
assert "Error removing test flow: error" in caplog.text
TestFlow.async_remove.assert_called_once()
@ -419,22 +418,6 @@ async def test_abort_flow_exception(manager):
assert form["description_placeholders"] == {"placeholder": "yo"}
async def test_initializing_flows_canceled_on_shutdown(hass, manager):
"""Test that initializing flows are canceled on shutdown."""
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
await asyncio.sleep(1)
task = asyncio.create_task(manager.async_init("test"))
await hass.async_block_till_done()
await manager.async_shutdown()
with pytest.raises(asyncio.exceptions.CancelledError):
await task
async def test_init_unknown_flow(manager):
"""Test that UnknownFlow is raised when async_create_flow returns None."""