Wait for registries to load at startup (#46265)

* Wait for registries to load at startup

* Don't decorate new functions with @bind_hass

* Fix typing errors in zwave_js

* Load registries in async_test_home_assistant

* Tweak

* Typo

* Tweak

* Explicitly silence mypy errors

* Fix tests

* Fix more tests

* Fix test

* Improve docstring

* Wait for registries to load
This commit is contained in:
Erik Montnemery 2021-02-11 17:36:19 +01:00 committed by GitHub
parent 888c9e120d
commit ed31cc363b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 131 additions and 107 deletions

View file

@ -17,6 +17,7 @@ from homeassistant import config as conf_util, config_entries, core, loader
from homeassistant.components import http
from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import area_registry, device_registry, entity_registry
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import (
DATA_SETUP,
@ -510,10 +511,12 @@ async def _async_set_up_integrations(
stage_2_domains = domains_to_setup - logging_domains - debuggers - stage_1_domains
# Kick off loading the registries. They don't need to be awaited.
asyncio.create_task(hass.helpers.device_registry.async_get_registry())
asyncio.create_task(hass.helpers.entity_registry.async_get_registry())
asyncio.create_task(hass.helpers.area_registry.async_get_registry())
# Load the registries
await asyncio.gather(
device_registry.async_load(hass),
entity_registry.async_load(hass),
area_registry.async_load(hass),
)
# Start setup
if stage_1_domains:

View file

@ -131,8 +131,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# grab device in device registry attached to this node
dev_id = get_device_id(client, node)
device = dev_reg.async_get_device({dev_id})
# note: removal of entity registry is handled by core
dev_reg.async_remove_device(device.id)
# note: removal of entity registry entry is handled by core
dev_reg.async_remove_device(device.id) # type: ignore
@callback
def async_on_value_notification(notification: ValueNotification) -> None:
@ -149,7 +149,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
ATTR_NODE_ID: notification.node.node_id,
ATTR_HOME_ID: client.driver.controller.home_id,
ATTR_ENDPOINT: notification.endpoint,
ATTR_DEVICE_ID: device.id,
ATTR_DEVICE_ID: device.id, # type: ignore
ATTR_COMMAND_CLASS: notification.command_class,
ATTR_COMMAND_CLASS_NAME: notification.command_class_name,
ATTR_LABEL: notification.metadata.label,
@ -170,7 +170,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
ATTR_DOMAIN: DOMAIN,
ATTR_NODE_ID: notification.node.node_id,
ATTR_HOME_ID: client.driver.controller.home_id,
ATTR_DEVICE_ID: device.id,
ATTR_DEVICE_ID: device.id, # type: ignore
ATTR_LABEL: notification.notification_label,
ATTR_PARAMETERS: notification.parameters,
},

View file

@ -1,5 +1,5 @@
"""Provide a way to connect devices to one physical location."""
from asyncio import Event, gather
from asyncio import gather
from collections import OrderedDict
from typing import Container, Dict, Iterable, List, MutableMapping, Optional, cast
@ -154,24 +154,23 @@ class AreaRegistry:
return data
@callback
def async_get(hass: HomeAssistantType) -> AreaRegistry:
"""Get area registry."""
return cast(AreaRegistry, hass.data[DATA_REGISTRY])
async def async_load(hass: HomeAssistantType) -> None:
"""Load area registry."""
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = AreaRegistry(hass)
await hass.data[DATA_REGISTRY].async_load()
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry:
"""Return area registry instance."""
reg_or_evt = hass.data.get(DATA_REGISTRY)
"""Get area registry.
if not reg_or_evt:
evt = hass.data[DATA_REGISTRY] = Event()
reg = AreaRegistry(hass)
await reg.async_load()
hass.data[DATA_REGISTRY] = reg
evt.set()
return reg
if isinstance(reg_or_evt, Event):
evt = reg_or_evt
await evt.wait()
return cast(AreaRegistry, hass.data.get(DATA_REGISTRY))
return cast(AreaRegistry, reg_or_evt)
This is deprecated and will be removed in the future. Use async_get instead.
"""
return async_get(hass)

View file

@ -2,16 +2,16 @@
from collections import OrderedDict
import logging
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast
import attr
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import Event, callback
from homeassistant.loader import bind_hass
import homeassistant.util.uuid as uuid_util
from .debounce import Debouncer
from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
# mypy: disallow_any_generics
@ -593,12 +593,26 @@ class DeviceRegistry:
self._async_update_device(dev_id, area_id=None)
@singleton(DATA_REGISTRY)
@callback
def async_get(hass: HomeAssistantType) -> DeviceRegistry:
"""Get device registry."""
return cast(DeviceRegistry, hass.data[DATA_REGISTRY])
async def async_load(hass: HomeAssistantType) -> None:
"""Load device registry."""
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = DeviceRegistry(hass)
await hass.data[DATA_REGISTRY].async_load()
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
"""Create entity registry."""
reg = DeviceRegistry(hass)
await reg.async_load()
return reg
"""Get device registry.
This is deprecated and will be removed in the future. Use async_get instead.
"""
return async_get(hass)
@callback

View file

@ -19,6 +19,7 @@ from typing import (
Optional,
Tuple,
Union,
cast,
)
import attr
@ -35,10 +36,10 @@ from homeassistant.const import (
)
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.loader import bind_hass
from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml
from .singleton import singleton
from .typing import UNDEFINED, HomeAssistantType, UndefinedType
if TYPE_CHECKING:
@ -568,12 +569,26 @@ class EntityRegistry:
self._add_index(entry)
@singleton(DATA_REGISTRY)
@callback
def async_get(hass: HomeAssistantType) -> EntityRegistry:
"""Get entity registry."""
return cast(EntityRegistry, hass.data[DATA_REGISTRY])
async def async_load(hass: HomeAssistantType) -> None:
"""Load entity registry."""
assert DATA_REGISTRY not in hass.data
hass.data[DATA_REGISTRY] = EntityRegistry(hass)
await hass.data[DATA_REGISTRY].async_load()
@bind_hass
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
"""Create entity registry."""
reg = EntityRegistry(hass)
await reg.async_load()
return reg
"""Get entity registry.
This is deprecated and will be removed in the future. Use async_get instead.
"""
return async_get(hass)
@callback

View file

@ -146,7 +146,7 @@ def get_test_home_assistant():
# pylint: disable=protected-access
async def async_test_home_assistant(loop):
async def async_test_home_assistant(loop, load_registries=True):
"""Return a Home Assistant object pointing at test config dir."""
hass = ha.HomeAssistant()
store = auth_store.AuthStore(hass)
@ -280,6 +280,15 @@ async def async_test_home_assistant(loop):
hass.config_entries._entries = []
hass.config_entries._store._async_ensure_stop_listener = lambda: None
# Load the registries
if load_registries:
await asyncio.gather(
device_registry.async_load(hass),
entity_registry.async_load(hass),
area_registry.async_load(hass),
)
await hass.async_block_till_done()
hass.state = ha.CoreState.running
# Mock async_start

View file

@ -38,19 +38,17 @@ async def mock_discovery(hass, discoveries, config=BASE_CONFIG):
"""Mock discoveries."""
with patch("homeassistant.components.zeroconf.async_get_instance"), patch(
"homeassistant.components.zeroconf.async_setup", return_value=True
):
assert await async_setup_component(hass, "discovery", config)
await hass.async_block_till_done()
await hass.async_start()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
with patch.object(discovery, "_discover", discoveries), patch(
), patch.object(discovery, "_discover", discoveries), patch(
"homeassistant.components.discovery.async_discover", return_value=mock_coro()
) as mock_discover, patch(
"homeassistant.components.discovery.async_load_platform",
return_value=mock_coro(),
) as mock_platform:
assert await async_setup_component(hass, "discovery", config)
await hass.async_block_till_done()
await hass.async_start()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
async_fire_time_changed(hass, utcnow())
# Work around an issue where our loop.call_soon not get caught
await hass.async_block_till_done()

View file

@ -3,6 +3,8 @@ from asyncio import Event
from datetime import timedelta
from unittest.mock import patch
import pytest
from homeassistant.bootstrap import async_from_config_dict
from homeassistant.components import sensor
from homeassistant.const import (
@ -403,6 +405,7 @@ async def test_setup_valid_device_class(hass):
assert "device_class" not in state.attributes
@pytest.mark.parametrize("load_registries", [False])
async def test_creating_sensor_loads_group(hass):
"""Test setting up template sensor loads group component first."""
order = []

View file

@ -121,7 +121,17 @@ def hass_storage():
@pytest.fixture
def hass(loop, hass_storage, request):
def load_registries():
"""Fixture to control the loading of registries when setting up the hass fixture.
To avoid loading the registries, tests can be marked with:
@pytest.mark.parametrize("load_registries", [False])
"""
return True
@pytest.fixture
def hass(loop, load_registries, hass_storage, request):
"""Fixture to provide a test instance of Home Assistant."""
def exc_handle(loop, context):
@ -141,7 +151,7 @@ def hass(loop, hass_storage, request):
orig_exception_handler(loop, context)
exceptions = []
hass = loop.run_until_complete(async_test_home_assistant(loop))
hass = loop.run_until_complete(async_test_home_assistant(loop, load_registries))
orig_exception_handler = loop.get_exception_handler()
loop.set_exception_handler(exc_handle)

View file

@ -1,7 +1,4 @@
"""Tests for the Area Registry."""
import asyncio
import unittest.mock
import pytest
from homeassistant.core import callback
@ -164,6 +161,7 @@ async def test_load_area(hass, registry):
assert list(registry.areas) == list(registry2.areas)
@pytest.mark.parametrize("load_registries", [False])
async def test_loading_area_from_storage(hass, hass_storage):
"""Test loading stored areas on start."""
hass_storage[area_registry.STORAGE_KEY] = {
@ -171,20 +169,7 @@ async def test_loading_area_from_storage(hass, hass_storage):
"data": {"areas": [{"id": "12345A", "name": "mock"}]},
}
registry = await area_registry.async_get_registry(hass)
await area_registry.async_load(hass)
registry = area_registry.async_get(hass)
assert len(registry.areas) == 1
async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
with unittest.mock.patch(
"homeassistant.helpers.area_registry.AreaRegistry.async_load"
) as mock_load:
results = await asyncio.gather(
area_registry.async_get_registry(hass),
area_registry.async_get_registry(hass),
)
mock_load.assert_called_once_with()
assert results[0] == results[1]

View file

@ -1,5 +1,4 @@
"""Tests for the Device Registry."""
import asyncio
import time
from unittest.mock import patch
@ -135,6 +134,7 @@ async def test_multiple_config_entries(registry):
assert entry2.config_entries == {"123", "456"}
@pytest.mark.parametrize("load_registries", [False])
async def test_loading_from_storage(hass, hass_storage):
"""Test loading stored devices on start."""
hass_storage[device_registry.STORAGE_KEY] = {
@ -167,7 +167,8 @@ async def test_loading_from_storage(hass, hass_storage):
},
}
registry = await device_registry.async_get_registry(hass)
await device_registry.async_load(hass)
registry = device_registry.async_get(hass)
assert len(registry.devices) == 1
assert len(registry.deleted_devices) == 1
@ -687,20 +688,6 @@ async def test_update_remove_config_entries(hass, registry, update_events):
assert update_events[4]["device_id"] == entry3.id
async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
with patch(
"homeassistant.helpers.device_registry.DeviceRegistry.async_load"
) as mock_load:
results = await asyncio.gather(
device_registry.async_get_registry(hass),
device_registry.async_get_registry(hass),
)
mock_load.assert_called_once_with()
assert results[0] == results[1]
async def test_update_sw_version(registry):
"""Verify that we can update software version of a device."""
entry = registry.async_get_or_create(
@ -798,10 +785,16 @@ async def test_cleanup_startup(hass):
assert len(mock_call.mock_calls) == 1
@pytest.mark.parametrize("load_registries", [False])
async def test_cleanup_entity_registry_change(hass):
"""Test we run a cleanup when entity registry changes."""
await device_registry.async_get_registry(hass)
ent_reg = await entity_registry.async_get_registry(hass)
"""Test we run a cleanup when entity registry changes.
Don't pre-load the registries as the debouncer will then not be waiting for
EVENT_ENTITY_REGISTRY_UPDATED events.
"""
await device_registry.async_load(hass)
await entity_registry.async_load(hass)
ent_reg = entity_registry.async_get(hass)
with patch(
"homeassistant.helpers.device_registry.Debouncer.async_call"

View file

@ -1,6 +1,4 @@
"""Tests for the Entity Registry."""
import asyncio
import unittest.mock
from unittest.mock import patch
import pytest
@ -219,6 +217,7 @@ def test_is_registered(registry):
assert not registry.async_is_registered("light.non_existing")
@pytest.mark.parametrize("load_registries", [False])
async def test_loading_extra_values(hass, hass_storage):
"""Test we load extra data from the registry."""
hass_storage[entity_registry.STORAGE_KEY] = {
@ -258,7 +257,8 @@ async def test_loading_extra_values(hass, hass_storage):
},
}
registry = await entity_registry.async_get_registry(hass)
await entity_registry.async_load(hass)
registry = entity_registry.async_get(hass)
assert len(registry.entities) == 4
@ -350,6 +350,7 @@ async def test_removing_area_id(registry):
assert entry_w_area != entry_wo_area
@pytest.mark.parametrize("load_registries", [False])
async def test_migration(hass):
"""Test migration from old data to new."""
mock_config = MockConfigEntry(domain="test-platform", entry_id="test-config-id")
@ -366,7 +367,8 @@ async def test_migration(hass):
with patch("os.path.isfile", return_value=True), patch("os.remove"), patch(
"homeassistant.helpers.entity_registry.load_yaml", return_value=old_conf
):
registry = await entity_registry.async_get_registry(hass)
await entity_registry.async_load(hass)
registry = entity_registry.async_get(hass)
assert registry.async_is_registered("light.kitchen")
entry = registry.async_get_or_create(
@ -427,20 +429,6 @@ async def test_loading_invalid_entity_id(hass, hass_storage):
assert valid_entity_id(entity_invalid_start.entity_id)
async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
with unittest.mock.patch(
"homeassistant.helpers.entity_registry.EntityRegistry.async_load"
) as mock_load:
results = await asyncio.gather(
entity_registry.async_get_registry(hass),
entity_registry.async_get_registry(hass),
)
mock_load.assert_called_once_with()
assert results[0] == results[1]
async def test_update_entity_unique_id(registry):
"""Test entity's unique_id is updated."""
mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1")
@ -794,7 +782,7 @@ async def test_disable_device_disables_entities(hass, registry):
async def test_disabled_entities_excluded_from_entity_list(hass, registry):
"""Test that disabled entities are exclduded from async_entries_for_device."""
"""Test that disabled entities are excluded from async_entries_for_device."""
device_registry = mock_device_registry(hass)
config_entry = MockConfigEntry(domain="light")

View file

@ -71,6 +71,7 @@ async def test_load_hassio(hass):
assert bootstrap._get_domains(hass, {}) == {"hassio"}
@pytest.mark.parametrize("load_registries", [False])
async def test_empty_setup(hass):
"""Test an empty set up loads the core."""
await bootstrap.async_from_config_dict({}, hass)
@ -91,6 +92,7 @@ async def test_core_failure_loads_safe_mode(hass, caplog):
assert "group" not in hass.config.components
@pytest.mark.parametrize("load_registries", [False])
async def test_setting_up_config(hass):
"""Test we set up domains in config."""
await bootstrap._async_set_up_integrations(
@ -100,6 +102,7 @@ async def test_setting_up_config(hass):
assert "group" in hass.config.components
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_after_deps_all_present(hass):
"""Test after_dependencies when all present."""
order = []
@ -144,6 +147,7 @@ async def test_setup_after_deps_all_present(hass):
assert order == ["logger", "root", "first_dep", "second_dep"]
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_after_deps_in_stage_1_ignored(hass):
"""Test after_dependencies are ignored in stage 1."""
# This test relies on this
@ -190,6 +194,7 @@ async def test_setup_after_deps_in_stage_1_ignored(hass):
assert order == ["cloud", "an_after_dep", "normal_integration"]
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_after_deps_via_platform(hass):
"""Test after_dependencies set up via platform."""
order = []
@ -239,6 +244,7 @@ async def test_setup_after_deps_via_platform(hass):
assert order == ["after_dep_of_platform_int", "platform_int"]
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_after_deps_not_trigger_load(hass):
"""Test after_dependencies does not trigger loading it."""
order = []
@ -277,6 +283,7 @@ async def test_setup_after_deps_not_trigger_load(hass):
assert "second_dep" in hass.config.components
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_after_deps_not_present(hass):
"""Test after_dependencies when referenced integration doesn't exist."""
order = []