mirror of
https://github.com/home-assistant/core
synced 2024-10-05 11:17:53 +00:00
Add unique ID to config entries (#29806)
* Add unique ID to config entries * Unload existing entries with same unique ID if flow with unique ID is finished * Remove unused exception * Fix typing * silence pylint * Fix tests * Add unique ID to Hue * Address typing comment * Tweaks to comments * lint
This commit is contained in:
parent
87ca61ddd7
commit
d851cb6f9e
|
@ -4,11 +4,11 @@ import logging
|
|||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant import config_entries, core
|
||||
from homeassistant.const import CONF_FILENAME, CONF_HOST
|
||||
from homeassistant.helpers import config_validation as cv, device_registry as dr
|
||||
|
||||
from .bridge import HueBridge
|
||||
from .bridge import HueBridge, normalize_bridge_id
|
||||
from .config_flow import ( # Loading the config flow file will register the flow
|
||||
configured_hosts,
|
||||
)
|
||||
|
@ -102,7 +102,9 @@ async def async_setup(hass, config):
|
|||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass, entry):
|
||||
async def async_setup_entry(
|
||||
hass: core.HomeAssistant, entry: config_entries.ConfigEntry
|
||||
):
|
||||
"""Set up a bridge from a config entry."""
|
||||
host = entry.data["host"]
|
||||
config = hass.data[DATA_CONFIGS].get(host)
|
||||
|
@ -121,6 +123,13 @@ async def async_setup_entry(hass, entry):
|
|||
|
||||
hass.data[DOMAIN][host] = bridge
|
||||
config = bridge.api.config
|
||||
|
||||
# For backwards compat
|
||||
if entry.unique_id is None:
|
||||
hass.config_entries.async_update_entry(
|
||||
entry, unique_id=normalize_bridge_id(config.bridgeid)
|
||||
)
|
||||
|
||||
device_registry = await dr.async_get_registry(hass)
|
||||
device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
|
|
|
@ -201,3 +201,25 @@ async def get_bridge(hass, host, username=None):
|
|||
except aiohue.AiohueException:
|
||||
LOGGER.exception("Unknown Hue linking error occurred")
|
||||
raise AuthenticationRequired
|
||||
|
||||
|
||||
def normalize_bridge_id(bridge_id: str):
|
||||
"""Normalize a bridge identifier.
|
||||
|
||||
There are three sources where we receive bridge ID from:
|
||||
- ssdp/upnp: <host>/description.xml, field root/device/serialNumber
|
||||
- nupnp: "id" field
|
||||
- Hue Bridge API: config.bridgeid
|
||||
|
||||
The SSDP/UPNP source does not contain the middle 4 characters compared
|
||||
to the other sources. In all our tests the middle 4 characters are "fffe".
|
||||
"""
|
||||
if len(bridge_id) == 16:
|
||||
return bridge_id[0:6] + bridge_id[-6:]
|
||||
|
||||
if len(bridge_id) == 12:
|
||||
return bridge_id
|
||||
|
||||
LOGGER.warning("Unexpected bridge id number found: %s", bridge_id)
|
||||
|
||||
return bridge_id
|
||||
|
|
|
@ -12,7 +12,7 @@ from homeassistant.components.ssdp import ATTR_MANUFACTURERURL, ATTR_NAME
|
|||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import aiohttp_client
|
||||
|
||||
from .bridge import get_bridge
|
||||
from .bridge import get_bridge, normalize_bridge_id
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .errors import AuthenticationRequired, CannotConnect
|
||||
|
||||
|
@ -154,17 +154,15 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
if host in configured_hosts(self.hass):
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
# This value is based off host/description.xml and is, weirdly, missing
|
||||
# 4 characters in the middle of the serial compared to results returned
|
||||
# from the NUPNP API or when querying the bridge API for bridgeid.
|
||||
# (on first gen Hue hub)
|
||||
serial = discovery_info.get("serial")
|
||||
bridge_id = discovery_info.get("serial")
|
||||
|
||||
await self.async_set_unique_id(normalize_bridge_id(bridge_id))
|
||||
|
||||
return await self.async_step_import(
|
||||
{
|
||||
"host": host,
|
||||
# This format is the legacy format that Hue used for discovery
|
||||
"path": f"phue-{serial}.conf",
|
||||
"path": f"phue-{bridge_id}.conf",
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -180,6 +178,10 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
if host in configured_hosts(self.hass):
|
||||
return self.async_abort(reason="already_configured")
|
||||
|
||||
await self.async_set_unique_id(
|
||||
normalize_bridge_id(homekit_info["properties"]["id"].replace(":", ""))
|
||||
)
|
||||
|
||||
return await self.async_step_import({"host": host})
|
||||
|
||||
async def async_step_import(self, import_info):
|
||||
|
@ -234,18 +236,9 @@ class HueFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
host = bridge.host
|
||||
bridge_id = bridge.config.bridgeid
|
||||
|
||||
same_hub_entries = [
|
||||
entry.entry_id
|
||||
for entry in self.hass.config_entries.async_entries(DOMAIN)
|
||||
if entry.data["bridge_id"] == bridge_id or entry.data["host"] == host
|
||||
]
|
||||
|
||||
if same_hub_entries:
|
||||
await asyncio.wait(
|
||||
[
|
||||
self.hass.config_entries.async_remove(entry_id)
|
||||
for entry_id in same_hub_entries
|
||||
]
|
||||
if self.unique_id is None:
|
||||
await self.async_set_unique_id(
|
||||
normalize_bridge_id(bridge_id), raise_on_progress=False
|
||||
)
|
||||
|
||||
return self.async_create_entry(
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
|
||||
import uuid
|
||||
import weakref
|
||||
|
||||
|
@ -75,6 +75,10 @@ class OperationNotAllowed(ConfigError):
|
|||
"""Raised when a config entry operation is not allowed."""
|
||||
|
||||
|
||||
class UniqueIdInProgress(data_entry_flow.AbortFlow):
|
||||
"""Error to indicate that the unique Id is in progress."""
|
||||
|
||||
|
||||
class ConfigEntry:
|
||||
"""Hold a configuration entry."""
|
||||
|
||||
|
@ -85,6 +89,7 @@ class ConfigEntry:
|
|||
"title",
|
||||
"data",
|
||||
"options",
|
||||
"unique_id",
|
||||
"system_options",
|
||||
"source",
|
||||
"connection_class",
|
||||
|
@ -104,6 +109,7 @@ class ConfigEntry:
|
|||
connection_class: str,
|
||||
system_options: dict,
|
||||
options: Optional[dict] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
entry_id: Optional[str] = None,
|
||||
state: str = ENTRY_STATE_NOT_LOADED,
|
||||
) -> None:
|
||||
|
@ -138,6 +144,9 @@ class ConfigEntry:
|
|||
# State of the entry (LOADED, NOT_LOADED)
|
||||
self.state = state
|
||||
|
||||
# Unique ID of this entry.
|
||||
self.unique_id = unique_id
|
||||
|
||||
# Listeners to call on update
|
||||
self.update_listeners: List = []
|
||||
|
||||
|
@ -533,11 +542,15 @@ class ConfigEntries:
|
|||
self,
|
||||
entry: ConfigEntry,
|
||||
*,
|
||||
unique_id: Union[str, dict, None] = _UNDEF,
|
||||
data: dict = _UNDEF,
|
||||
options: dict = _UNDEF,
|
||||
system_options: dict = _UNDEF,
|
||||
) -> None:
|
||||
"""Update a config entry."""
|
||||
if unique_id is not _UNDEF:
|
||||
entry.unique_id = cast(Optional[str], unique_id)
|
||||
|
||||
if data is not _UNDEF:
|
||||
entry.data = data
|
||||
|
||||
|
@ -602,6 +615,25 @@ class ConfigEntries:
|
|||
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
return result
|
||||
|
||||
# Check if config entry exists with unique ID. Unload it.
|
||||
existing_entry = None
|
||||
unique_id = flow.context.get("unique_id")
|
||||
|
||||
if unique_id is not None:
|
||||
for check_entry in self.async_entries(result["handler"]):
|
||||
if check_entry.unique_id == unique_id:
|
||||
existing_entry = check_entry
|
||||
break
|
||||
|
||||
# Unload the entry before setting up the new one.
|
||||
# We will remove it only after the other one is set up,
|
||||
# so that device customizations are not getting lost.
|
||||
if (
|
||||
existing_entry is not None
|
||||
and existing_entry.state not in UNRECOVERABLE_STATES
|
||||
):
|
||||
await self.async_unload(existing_entry.entry_id)
|
||||
|
||||
entry = ConfigEntry(
|
||||
version=result["version"],
|
||||
domain=result["handler"],
|
||||
|
@ -611,12 +643,16 @@ class ConfigEntries:
|
|||
system_options={},
|
||||
source=flow.context["source"],
|
||||
connection_class=flow.CONNECTION_CLASS,
|
||||
unique_id=unique_id,
|
||||
)
|
||||
self._entries.append(entry)
|
||||
self._async_schedule_save()
|
||||
|
||||
await self.async_setup(entry.entry_id)
|
||||
|
||||
if existing_entry is not None:
|
||||
await self.async_remove(existing_entry.entry_id)
|
||||
|
||||
result["result"] = entry
|
||||
return result
|
||||
|
||||
|
@ -687,6 +723,8 @@ async def _old_conf_migrator(old_config: Dict[str, Any]) -> Dict[str, Any]:
|
|||
class ConfigFlow(data_entry_flow.FlowHandler):
|
||||
"""Base class for config flows with some helpers."""
|
||||
|
||||
unique_id = None
|
||||
|
||||
def __init_subclass__(cls, domain: Optional[str] = None, **kwargs: Any) -> None:
|
||||
"""Initialize a subclass, register if possible."""
|
||||
super().__init_subclass__(**kwargs) # type: ignore
|
||||
|
@ -701,6 +739,27 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
|||
"""Get the options flow for this handler."""
|
||||
raise data_entry_flow.UnknownHandler
|
||||
|
||||
async def async_set_unique_id(
|
||||
self, unique_id: str, *, raise_on_progress: bool = True
|
||||
) -> Optional[ConfigEntry]:
|
||||
"""Set a unique ID for the config flow.
|
||||
|
||||
Returns optionally existing config entry with same ID.
|
||||
"""
|
||||
if raise_on_progress:
|
||||
for progress in self._async_in_progress():
|
||||
if progress["context"].get("unique_id") == unique_id:
|
||||
raise UniqueIdInProgress("already_in_progress")
|
||||
|
||||
# pylint: disable=no-member
|
||||
self.context["unique_id"] = unique_id
|
||||
|
||||
for entry in self._async_current_entries():
|
||||
if entry.unique_id == unique_id:
|
||||
return entry
|
||||
|
||||
return None
|
||||
|
||||
@callback
|
||||
def _async_current_entries(self) -> List[ConfigEntry]:
|
||||
"""Return current entries."""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Classes to help gather user submissions."""
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
import uuid
|
||||
|
||||
import voluptuous as vol
|
||||
|
@ -36,6 +36,16 @@ class UnknownStep(FlowError):
|
|||
"""Unknown step specified."""
|
||||
|
||||
|
||||
class AbortFlow(FlowError):
|
||||
"""Exception to indicate a flow needs to be aborted."""
|
||||
|
||||
def __init__(self, reason: str, description_placeholders: Optional[Dict] = None):
|
||||
"""Initialize an abort flow exception."""
|
||||
super().__init__(f"Flow aborted: {reason}")
|
||||
self.reason = reason
|
||||
self.description_placeholders = description_placeholders
|
||||
|
||||
|
||||
class FlowManager:
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
|
@ -131,7 +141,12 @@ class FlowManager:
|
|||
)
|
||||
)
|
||||
|
||||
result: Dict = await getattr(flow, method)(user_input)
|
||||
try:
|
||||
result: Dict = await getattr(flow, method)(user_input)
|
||||
except AbortFlow as err:
|
||||
result = _create_abort_data(
|
||||
flow.flow_id, flow.handler, err.reason, err.description_placeholders
|
||||
)
|
||||
|
||||
if result["type"] not in (
|
||||
RESULT_TYPE_FORM,
|
||||
|
@ -228,13 +243,9 @@ class FlowHandler:
|
|||
self, *, reason: str, description_placeholders: Optional[Dict] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Abort the config flow."""
|
||||
return {
|
||||
"type": RESULT_TYPE_ABORT,
|
||||
"flow_id": self.flow_id,
|
||||
"handler": self.handler,
|
||||
"reason": reason,
|
||||
"description_placeholders": description_placeholders,
|
||||
}
|
||||
return _create_abort_data(
|
||||
self.flow_id, cast(str, self.handler), reason, description_placeholders
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_external_step(
|
||||
|
@ -259,3 +270,20 @@ class FlowHandler:
|
|||
"handler": self.handler,
|
||||
"step_id": next_step_id,
|
||||
}
|
||||
|
||||
|
||||
@callback
|
||||
def _create_abort_data(
|
||||
flow_id: str,
|
||||
handler: str,
|
||||
reason: str,
|
||||
description_placeholders: Optional[Dict] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Return the definition of an external step for the user to take."""
|
||||
return {
|
||||
"type": RESULT_TYPE_ABORT,
|
||||
"flow_id": flow_id,
|
||||
"handler": handler,
|
||||
"reason": reason,
|
||||
"description_placeholders": description_placeholders,
|
||||
}
|
||||
|
|
|
@ -671,6 +671,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
|
|||
options={},
|
||||
system_options={},
|
||||
connection_class=config_entries.CONN_CLASS_UNKNOWN,
|
||||
unique_id=None,
|
||||
):
|
||||
"""Initialize a mock config entry."""
|
||||
kwargs = {
|
||||
|
@ -682,6 +683,7 @@ class MockConfigEntry(config_entries.ConfigEntry):
|
|||
"version": version,
|
||||
"title": title,
|
||||
"connection_class": connection_class,
|
||||
"unique_id": unique_id,
|
||||
}
|
||||
if source is not None:
|
||||
kwargs["source"] = source
|
||||
|
|
|
@ -19,6 +19,7 @@ async def test_flow_works(hass, aioclient_mock):
|
|||
|
||||
flow = config_flow.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
flow.context = {}
|
||||
await flow.async_step_init()
|
||||
|
||||
with patch("aiohue.Bridge") as mock_bridge:
|
||||
|
@ -349,28 +350,33 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass):
|
|||
accessible via a single IP. So when we create a new entry, we'll remove
|
||||
all existing entries that either have same IP or same bridge_id.
|
||||
"""
|
||||
MockConfigEntry(
|
||||
domain="hue", data={"host": "0.0.0.0", "bridge_id": "id-1234"}
|
||||
).add_to_hass(hass)
|
||||
orig_entry = MockConfigEntry(
|
||||
domain="hue",
|
||||
data={"host": "0.0.0.0", "bridge_id": "id-1234"},
|
||||
unique_id="id-1234",
|
||||
)
|
||||
orig_entry.add_to_hass(hass)
|
||||
|
||||
MockConfigEntry(
|
||||
domain="hue", data={"host": "1.2.3.4", "bridge_id": "id-1234"}
|
||||
domain="hue",
|
||||
data={"host": "1.2.3.4", "bridge_id": "id-5678"},
|
||||
unique_id="id-5678",
|
||||
).add_to_hass(hass)
|
||||
|
||||
assert len(hass.config_entries.async_entries("hue")) == 2
|
||||
|
||||
flow = config_flow.HueFlowHandler()
|
||||
flow.hass = hass
|
||||
flow.context = {}
|
||||
|
||||
bridge = Mock()
|
||||
bridge.username = "username-abc"
|
||||
bridge.config.bridgeid = "id-1234"
|
||||
bridge.config.name = "Mock Bridge"
|
||||
bridge.host = "0.0.0.0"
|
||||
|
||||
with patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)):
|
||||
result = await flow.async_step_import({"host": "0.0.0.0"})
|
||||
with patch.object(
|
||||
config_flow, "_find_username_from_config", return_value="mock-user"
|
||||
), patch.object(config_flow, "get_bridge", return_value=mock_coro(bridge)):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"hue", data={"host": "2.2.2.2"}, context={"source": "import"}
|
||||
)
|
||||
|
||||
assert result["type"] == "create_entry"
|
||||
assert result["title"] == "Mock Bridge"
|
||||
|
@ -379,9 +385,11 @@ async def test_creating_entry_removes_entries_for_same_host_or_bridge(hass):
|
|||
"bridge_id": "id-1234",
|
||||
"username": "username-abc",
|
||||
}
|
||||
# We did not process the result of this entry but already removed the old
|
||||
# ones. So we should have 0 entries.
|
||||
assert len(hass.config_entries.async_entries("hue")) == 0
|
||||
entries = hass.config_entries.async_entries("hue")
|
||||
assert len(entries) == 2
|
||||
new_entry = entries[-1]
|
||||
assert orig_entry.entry_id != new_entry.entry_id
|
||||
assert new_entry.unique_id == "id-1234"
|
||||
|
||||
|
||||
async def test_bridge_homekit(hass):
|
||||
|
@ -398,6 +406,7 @@ async def test_bridge_homekit(hass):
|
|||
"host": "0.0.0.0",
|
||||
"serial": "1234",
|
||||
"manufacturerURL": config_flow.HUE_MANUFACTURERURL,
|
||||
"properties": {"id": "aa:bb:cc:dd:ee:ff"},
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -175,3 +175,19 @@ async def test_unload_entry(hass):
|
|||
assert await hue.async_unload_entry(hass, entry)
|
||||
assert len(mock_bridge.return_value.async_reset.mock_calls) == 1
|
||||
assert hass.data[hue.DOMAIN] == {}
|
||||
|
||||
|
||||
async def test_setting_unique_id(hass):
|
||||
"""Test we set unique ID if not set yet."""
|
||||
entry = MockConfigEntry(domain=hue.DOMAIN, data={"host": "0.0.0.0"})
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch.object(hue, "HueBridge") as mock_bridge, patch(
|
||||
"homeassistant.helpers.device_registry.async_get_registry",
|
||||
return_value=mock_coro(Mock()),
|
||||
):
|
||||
mock_bridge.return_value.async_setup.return_value = mock_coro(True)
|
||||
mock_bridge.return_value.api.config = Mock(bridgeid="mock-id")
|
||||
assert await async_setup_component(hass, hue.DOMAIN, {}) is True
|
||||
|
||||
assert entry.unique_id == "mock-id"
|
||||
|
|
|
@ -1001,3 +1001,110 @@ async def test_reload_entry_entity_registry_works(hass):
|
|||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_unload_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_unqiue_id_persisted(hass, manager):
|
||||
"""Test that a unique ID is stored in the config entry."""
|
||||
mock_setup_entry = MagicMock(return_value=mock_coro(True))
|
||||
|
||||
mock_integration(hass, MockModule("comp", async_setup_entry=mock_setup_entry))
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
await self.async_set_unique_id("mock-unique-id")
|
||||
return self.async_create_entry(title="mock-title", data={})
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
await manager.flow.async_init(
|
||||
"comp", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
p_hass, p_entry = mock_setup_entry.mock_calls[0][1]
|
||||
|
||||
assert p_hass is hass
|
||||
assert p_entry.unique_id == "mock-unique-id"
|
||||
|
||||
|
||||
async def test_unique_id_existing_entry(hass, manager):
|
||||
"""Test that we remove an entry if there already is an entry with unique ID."""
|
||||
hass.config.components.add("comp")
|
||||
MockConfigEntry(
|
||||
domain="comp",
|
||||
state=config_entries.ENTRY_STATE_LOADED,
|
||||
unique_id="mock-unique-id",
|
||||
).add_to_hass(hass)
|
||||
|
||||
async_setup_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
|
||||
async_unload_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
|
||||
async_remove_entry = MagicMock(side_effect=lambda _, _2: mock_coro(True))
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup_entry=async_setup_entry,
|
||||
async_unload_entry=async_unload_entry,
|
||||
async_remove_entry=async_remove_entry,
|
||||
),
|
||||
)
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
existing_entry = await self.async_set_unique_id("mock-unique-id")
|
||||
|
||||
assert existing_entry is not None
|
||||
|
||||
return self.async_create_entry(title="mock-title", data={"via": "flow"})
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
result = await manager.flow.async_init(
|
||||
"comp", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
|
||||
entries = hass.config_entries.async_entries("comp")
|
||||
assert len(entries) == 1
|
||||
assert entries[0].data == {"via": "flow"}
|
||||
|
||||
assert len(async_setup_entry.mock_calls) == 1
|
||||
assert len(async_unload_entry.mock_calls) == 1
|
||||
assert len(async_remove_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_unique_id_in_progress(hass, manager):
|
||||
"""Test that we abort if there is already a flow in progress with same unique id."""
|
||||
mock_integration(hass, MockModule("comp"))
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
await self.async_set_unique_id("mock-unique-id")
|
||||
return self.async_show_form(step_id="discovery")
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
|
||||
# Create one to be in progress
|
||||
result = await manager.flow.async_init(
|
||||
"comp", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
|
||||
# Will be canceled
|
||||
result2 = await manager.flow.async_init(
|
||||
"comp", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
assert result2["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result2["reason"] == "already_in_progress"
|
||||
|
|
|
@ -94,7 +94,7 @@ async def test_configure_two_steps(manager):
|
|||
|
||||
|
||||
async def test_show_form(manager):
|
||||
"""Test that abort removes the flow from progress."""
|
||||
"""Test that we can show a form."""
|
||||
schema = vol.Schema({vol.Required("username"): str, vol.Required("password"): str})
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
|
@ -271,3 +271,17 @@ async def test_external_step(hass, manager):
|
|||
result = await manager.async_configure(result["flow_id"])
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["title"] == "Hello"
|
||||
|
||||
|
||||
async def test_abort_flow_exception(manager):
|
||||
"""Test that the AbortFlow exception works."""
|
||||
|
||||
@manager.mock_reg_handler("test")
|
||||
class TestFlow(data_entry_flow.FlowHandler):
|
||||
async def async_step_init(self, user_input=None):
|
||||
raise data_entry_flow.AbortFlow("mock-reason", {"placeholder": "yo"})
|
||||
|
||||
form = await manager.async_init("test")
|
||||
assert form["type"] == "abort"
|
||||
assert form["reason"] == "mock-reason"
|
||||
assert form["description_placeholders"] == {"placeholder": "yo"}
|
||||
|
|
Loading…
Reference in a new issue