Adjust async_step_discovery methods for BaseServiceInfo (#60285)

Co-authored-by: epenet <epenet@users.noreply.github.com>
This commit is contained in:
epenet 2021-11-25 02:30:02 +01:00 committed by GitHub
parent 0920e74aa2
commit 75057949d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 19 deletions

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Iterable, Mapping
from contextvars import ContextVar
import dataclasses
from enum import Enum
import functools
import logging
@ -1360,13 +1361,13 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by Homekit discovery."""
return await self.async_step_discovery(cast(dict, discovery_info))
return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_mqtt(
self, discovery_info: MqttServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by MQTT discovery."""
return await self.async_step_discovery(cast(dict, discovery_info))
return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_ssdp(
self, discovery_info: DiscoveryInfoType
@ -1378,19 +1379,19 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by Zeroconf discovery."""
return await self.async_step_discovery(cast(dict, discovery_info))
return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by DHCP discovery."""
return await self.async_step_discovery(cast(dict, discovery_info))
return await self.async_step_discovery(dataclasses.asdict(discovery_info))
async def async_step_usb(
self, discovery_info: UsbServiceInfo
) -> data_entry_flow.FlowResult:
"""Handle a flow initialized by USB discovery."""
return await self.async_step_discovery(cast(dict, discovery_info))
return await self.async_step_discovery(dataclasses.asdict(discovery_info))
@callback
def async_create_entry( # pylint: disable=arguments-differ

View file

@ -5,6 +5,7 @@ from unittest.mock import patch
from spotipy import SpotifyException
from homeassistant import data_entry_flow, setup
from homeassistant.components import zeroconf
from homeassistant.components.spotify.const import DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER, SOURCE_ZEROCONF
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET
@ -12,6 +13,15 @@ from homeassistant.helpers import config_entry_oauth2_flow
from tests.common import MockConfigEntry
BLANK_ZEROCONF_INFO = zeroconf.ZeroconfServiceInfo(
host="1.2.3.4",
hostname="mock_hostname",
name="mock_name",
port=None,
properties={},
type="mock_type",
)
async def test_abort_if_no_configuration(hass):
"""Check flow aborts when no configuration is present."""
@ -23,7 +33,7 @@ async def test_abort_if_no_configuration(hass):
assert result["reason"] == "missing_configuration"
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_ZEROCONF}
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=BLANK_ZEROCONF_INFO
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
@ -35,7 +45,7 @@ async def test_zeroconf_abort_if_existing_entry(hass):
MockConfigEntry(domain=DOMAIN).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_ZEROCONF}
DOMAIN, context={"source": SOURCE_ZEROCONF}, data=BLANK_ZEROCONF_INFO
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT

View file

@ -220,7 +220,9 @@ async def test_step_discovery(hass, flow_handler, local_impl):
)
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}
TEST_DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=data_entry_flow.BaseServiceInfo(),
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
@ -242,7 +244,9 @@ async def test_abort_discovered_multiple(hass, flow_handler, local_impl):
assert result["step_id"] == "pick_implementation"
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}
TEST_DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=data_entry_flow.BaseServiceInfo(),
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT

View file

@ -9,7 +9,7 @@ import pytest
from homeassistant import config_entries, data_entry_flow, loader
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CoreState, callback
from homeassistant.data_entry_flow import RESULT_TYPE_ABORT
from homeassistant.data_entry_flow import RESULT_TYPE_ABORT, BaseServiceInfo
from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryNotReady,
@ -2350,13 +2350,13 @@ async def test_async_setup_update_entry(hass):
@pytest.mark.parametrize(
"discovery_source",
(
config_entries.SOURCE_DISCOVERY,
config_entries.SOURCE_SSDP,
config_entries.SOURCE_USB,
config_entries.SOURCE_HOMEKIT,
config_entries.SOURCE_DHCP,
config_entries.SOURCE_ZEROCONF,
config_entries.SOURCE_HASSIO,
(config_entries.SOURCE_DISCOVERY, {}),
(config_entries.SOURCE_SSDP, {}),
(config_entries.SOURCE_USB, BaseServiceInfo()),
(config_entries.SOURCE_HOMEKIT, BaseServiceInfo()),
(config_entries.SOURCE_DHCP, BaseServiceInfo()),
(config_entries.SOURCE_ZEROCONF, BaseServiceInfo()),
(config_entries.SOURCE_HASSIO, {}),
),
)
async def test_flow_with_default_discovery(hass, manager, discovery_source):
@ -2382,7 +2382,7 @@ async def test_flow_with_default_discovery(hass, manager, discovery_source):
with patch.dict(config_entries.HANDLERS, {"comp": TestFlow}):
# Create one to be in progress
result = await manager.flow.async_init(
"comp", context={"source": discovery_source}
"comp", context={"source": discovery_source[0]}, data=discovery_source[1]
)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
@ -2403,7 +2403,7 @@ async def test_flow_with_default_discovery(hass, manager, discovery_source):
entry = hass.config_entries.async_entries("comp")[0]
assert entry.title == "yo"
assert entry.source == discovery_source
assert entry.source == discovery_source[0]
assert entry.unique_id is None