From ed31cc363b3a457ed9249f7cf0386f2ec86e1850 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 11 Feb 2021 17:36:19 +0100 Subject: [PATCH] Wait for registries to load at startup (#46265) * Wait for registries to load at startup * Don't decorate new functions with @bind_hass * Fix typing errors in zwave_js * Load registries in async_test_home_assistant * Tweak * Typo * Tweak * Explicitly silence mypy errors * Fix tests * Fix more tests * Fix test * Improve docstring * Wait for registries to load --- homeassistant/bootstrap.py | 11 ++++-- homeassistant/components/zwave_js/__init__.py | 8 ++-- homeassistant/helpers/area_registry.py | 37 +++++++++---------- homeassistant/helpers/device_registry.py | 28 ++++++++++---- homeassistant/helpers/entity_registry.py | 27 +++++++++++--- tests/common.py | 11 +++++- tests/components/discovery/test_init.py | 14 +++---- tests/components/template/test_sensor.py | 3 ++ tests/conftest.py | 14 ++++++- tests/helpers/test_area_registry.py | 21 ++--------- tests/helpers/test_device_registry.py | 31 ++++++---------- tests/helpers/test_entity_registry.py | 26 ++++--------- tests/test_bootstrap.py | 7 ++++ 13 files changed, 131 insertions(+), 107 deletions(-) diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 0f5bda7fbf24..fd2d580a879e 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -17,6 +17,7 @@ from homeassistant import config as conf_util, config_entries, core, loader from homeassistant.components import http from homeassistant.const import REQUIRED_NEXT_PYTHON_DATE, REQUIRED_NEXT_PYTHON_VER from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import area_registry, device_registry, entity_registry from homeassistant.helpers.typing import ConfigType from homeassistant.setup import ( DATA_SETUP, @@ -510,10 +511,12 @@ async def _async_set_up_integrations( stage_2_domains = domains_to_setup - logging_domains - debuggers - stage_1_domains - # Kick off loading the registries. They don't need to be awaited. - asyncio.create_task(hass.helpers.device_registry.async_get_registry()) - asyncio.create_task(hass.helpers.entity_registry.async_get_registry()) - asyncio.create_task(hass.helpers.area_registry.async_get_registry()) + # Load the registries + await asyncio.gather( + device_registry.async_load(hass), + entity_registry.async_load(hass), + area_registry.async_load(hass), + ) # Start setup if stage_1_domains: diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index 01b8f4785c5d..d5624551b279 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -131,8 +131,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # grab device in device registry attached to this node dev_id = get_device_id(client, node) device = dev_reg.async_get_device({dev_id}) - # note: removal of entity registry is handled by core - dev_reg.async_remove_device(device.id) + # note: removal of entity registry entry is handled by core + dev_reg.async_remove_device(device.id) # type: ignore @callback def async_on_value_notification(notification: ValueNotification) -> None: @@ -149,7 +149,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ATTR_NODE_ID: notification.node.node_id, ATTR_HOME_ID: client.driver.controller.home_id, ATTR_ENDPOINT: notification.endpoint, - ATTR_DEVICE_ID: device.id, + ATTR_DEVICE_ID: device.id, # type: ignore ATTR_COMMAND_CLASS: notification.command_class, ATTR_COMMAND_CLASS_NAME: notification.command_class_name, ATTR_LABEL: notification.metadata.label, @@ -170,7 +170,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ATTR_DOMAIN: DOMAIN, ATTR_NODE_ID: notification.node.node_id, ATTR_HOME_ID: client.driver.controller.home_id, - ATTR_DEVICE_ID: device.id, + ATTR_DEVICE_ID: device.id, # type: ignore ATTR_LABEL: notification.notification_label, ATTR_PARAMETERS: notification.parameters, }, diff --git a/homeassistant/helpers/area_registry.py b/homeassistant/helpers/area_registry.py index bdd231686e22..a41e748d1adb 100644 --- a/homeassistant/helpers/area_registry.py +++ b/homeassistant/helpers/area_registry.py @@ -1,5 +1,5 @@ """Provide a way to connect devices to one physical location.""" -from asyncio import Event, gather +from asyncio import gather from collections import OrderedDict from typing import Container, Dict, Iterable, List, MutableMapping, Optional, cast @@ -154,24 +154,23 @@ class AreaRegistry: return data +@callback +def async_get(hass: HomeAssistantType) -> AreaRegistry: + """Get area registry.""" + return cast(AreaRegistry, hass.data[DATA_REGISTRY]) + + +async def async_load(hass: HomeAssistantType) -> None: + """Load area registry.""" + assert DATA_REGISTRY not in hass.data + hass.data[DATA_REGISTRY] = AreaRegistry(hass) + await hass.data[DATA_REGISTRY].async_load() + + @bind_hass async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry: - """Return area registry instance.""" - reg_or_evt = hass.data.get(DATA_REGISTRY) + """Get area registry. - if not reg_or_evt: - evt = hass.data[DATA_REGISTRY] = Event() - - reg = AreaRegistry(hass) - await reg.async_load() - - hass.data[DATA_REGISTRY] = reg - evt.set() - return reg - - if isinstance(reg_or_evt, Event): - evt = reg_or_evt - await evt.wait() - return cast(AreaRegistry, hass.data.get(DATA_REGISTRY)) - - return cast(AreaRegistry, reg_or_evt) + This is deprecated and will be removed in the future. Use async_get instead. + """ + return async_get(hass) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index c449d2ed4d02..0d62b2cab479 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -2,16 +2,16 @@ from collections import OrderedDict import logging import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union, cast import attr from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import Event, callback +from homeassistant.loader import bind_hass import homeassistant.util.uuid as uuid_util from .debounce import Debouncer -from .singleton import singleton from .typing import UNDEFINED, HomeAssistantType, UndefinedType # mypy: disallow_any_generics @@ -593,12 +593,26 @@ class DeviceRegistry: self._async_update_device(dev_id, area_id=None) -@singleton(DATA_REGISTRY) +@callback +def async_get(hass: HomeAssistantType) -> DeviceRegistry: + """Get device registry.""" + return cast(DeviceRegistry, hass.data[DATA_REGISTRY]) + + +async def async_load(hass: HomeAssistantType) -> None: + """Load device registry.""" + assert DATA_REGISTRY not in hass.data + hass.data[DATA_REGISTRY] = DeviceRegistry(hass) + await hass.data[DATA_REGISTRY].async_load() + + +@bind_hass async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry: - """Create entity registry.""" - reg = DeviceRegistry(hass) - await reg.async_load() - return reg + """Get device registry. + + This is deprecated and will be removed in the future. Use async_get instead. + """ + return async_get(hass) @callback diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 052e7398ba16..51985f7bae45 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -19,6 +19,7 @@ from typing import ( Optional, Tuple, Union, + cast, ) import attr @@ -35,10 +36,10 @@ from homeassistant.const import ( ) from homeassistant.core import Event, callback, split_entity_id, valid_entity_id from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED +from homeassistant.loader import bind_hass from homeassistant.util import slugify from homeassistant.util.yaml import load_yaml -from .singleton import singleton from .typing import UNDEFINED, HomeAssistantType, UndefinedType if TYPE_CHECKING: @@ -568,12 +569,26 @@ class EntityRegistry: self._add_index(entry) -@singleton(DATA_REGISTRY) +@callback +def async_get(hass: HomeAssistantType) -> EntityRegistry: + """Get entity registry.""" + return cast(EntityRegistry, hass.data[DATA_REGISTRY]) + + +async def async_load(hass: HomeAssistantType) -> None: + """Load entity registry.""" + assert DATA_REGISTRY not in hass.data + hass.data[DATA_REGISTRY] = EntityRegistry(hass) + await hass.data[DATA_REGISTRY].async_load() + + +@bind_hass async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: - """Create entity registry.""" - reg = EntityRegistry(hass) - await reg.async_load() - return reg + """Get entity registry. + + This is deprecated and will be removed in the future. Use async_get instead. + """ + return async_get(hass) @callback diff --git a/tests/common.py b/tests/common.py index ab5da25e38d3..c07716dbfc9a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -146,7 +146,7 @@ def get_test_home_assistant(): # pylint: disable=protected-access -async def async_test_home_assistant(loop): +async def async_test_home_assistant(loop, load_registries=True): """Return a Home Assistant object pointing at test config dir.""" hass = ha.HomeAssistant() store = auth_store.AuthStore(hass) @@ -280,6 +280,15 @@ async def async_test_home_assistant(loop): hass.config_entries._entries = [] hass.config_entries._store._async_ensure_stop_listener = lambda: None + # Load the registries + if load_registries: + await asyncio.gather( + device_registry.async_load(hass), + entity_registry.async_load(hass), + area_registry.async_load(hass), + ) + await hass.async_block_till_done() + hass.state = ha.CoreState.running # Mock async_start diff --git a/tests/components/discovery/test_init.py b/tests/components/discovery/test_init.py index fd66e59ef212..2c1e41e82851 100644 --- a/tests/components/discovery/test_init.py +++ b/tests/components/discovery/test_init.py @@ -38,19 +38,17 @@ async def mock_discovery(hass, discoveries, config=BASE_CONFIG): """Mock discoveries.""" with patch("homeassistant.components.zeroconf.async_get_instance"), patch( "homeassistant.components.zeroconf.async_setup", return_value=True - ): - assert await async_setup_component(hass, "discovery", config) - await hass.async_block_till_done() - await hass.async_start() - hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) - await hass.async_block_till_done() - - with patch.object(discovery, "_discover", discoveries), patch( + ), patch.object(discovery, "_discover", discoveries), patch( "homeassistant.components.discovery.async_discover", return_value=mock_coro() ) as mock_discover, patch( "homeassistant.components.discovery.async_load_platform", return_value=mock_coro(), ) as mock_platform: + assert await async_setup_component(hass, "discovery", config) + await hass.async_block_till_done() + await hass.async_start() + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() async_fire_time_changed(hass, utcnow()) # Work around an issue where our loop.call_soon not get caught await hass.async_block_till_done() diff --git a/tests/components/template/test_sensor.py b/tests/components/template/test_sensor.py index 7f560fa0abb1..9d014f86a367 100644 --- a/tests/components/template/test_sensor.py +++ b/tests/components/template/test_sensor.py @@ -3,6 +3,8 @@ from asyncio import Event from datetime import timedelta from unittest.mock import patch +import pytest + from homeassistant.bootstrap import async_from_config_dict from homeassistant.components import sensor from homeassistant.const import ( @@ -403,6 +405,7 @@ async def test_setup_valid_device_class(hass): assert "device_class" not in state.attributes +@pytest.mark.parametrize("load_registries", [False]) async def test_creating_sensor_loads_group(hass): """Test setting up template sensor loads group component first.""" order = [] diff --git a/tests/conftest.py b/tests/conftest.py index 6e3edbd73e86..3fc2dc748cbd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -121,7 +121,17 @@ def hass_storage(): @pytest.fixture -def hass(loop, hass_storage, request): +def load_registries(): + """Fixture to control the loading of registries when setting up the hass fixture. + + To avoid loading the registries, tests can be marked with: + @pytest.mark.parametrize("load_registries", [False]) + """ + return True + + +@pytest.fixture +def hass(loop, load_registries, hass_storage, request): """Fixture to provide a test instance of Home Assistant.""" def exc_handle(loop, context): @@ -141,7 +151,7 @@ def hass(loop, hass_storage, request): orig_exception_handler(loop, context) exceptions = [] - hass = loop.run_until_complete(async_test_home_assistant(loop)) + hass = loop.run_until_complete(async_test_home_assistant(loop, load_registries)) orig_exception_handler = loop.get_exception_handler() loop.set_exception_handler(exc_handle) diff --git a/tests/helpers/test_area_registry.py b/tests/helpers/test_area_registry.py index ec008dde7da1..2b06202c8621 100644 --- a/tests/helpers/test_area_registry.py +++ b/tests/helpers/test_area_registry.py @@ -1,7 +1,4 @@ """Tests for the Area Registry.""" -import asyncio -import unittest.mock - import pytest from homeassistant.core import callback @@ -164,6 +161,7 @@ async def test_load_area(hass, registry): assert list(registry.areas) == list(registry2.areas) +@pytest.mark.parametrize("load_registries", [False]) async def test_loading_area_from_storage(hass, hass_storage): """Test loading stored areas on start.""" hass_storage[area_registry.STORAGE_KEY] = { @@ -171,20 +169,7 @@ async def test_loading_area_from_storage(hass, hass_storage): "data": {"areas": [{"id": "12345A", "name": "mock"}]}, } - registry = await area_registry.async_get_registry(hass) + await area_registry.async_load(hass) + registry = area_registry.async_get(hass) assert len(registry.areas) == 1 - - -async def test_loading_race_condition(hass): - """Test only one storage load called when concurrent loading occurred .""" - with unittest.mock.patch( - "homeassistant.helpers.area_registry.AreaRegistry.async_load" - ) as mock_load: - results = await asyncio.gather( - area_registry.async_get_registry(hass), - area_registry.async_get_registry(hass), - ) - - mock_load.assert_called_once_with() - assert results[0] == results[1] diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 01959174335c..a128f8aa3903 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -1,5 +1,4 @@ """Tests for the Device Registry.""" -import asyncio import time from unittest.mock import patch @@ -135,6 +134,7 @@ async def test_multiple_config_entries(registry): assert entry2.config_entries == {"123", "456"} +@pytest.mark.parametrize("load_registries", [False]) async def test_loading_from_storage(hass, hass_storage): """Test loading stored devices on start.""" hass_storage[device_registry.STORAGE_KEY] = { @@ -167,7 +167,8 @@ async def test_loading_from_storage(hass, hass_storage): }, } - registry = await device_registry.async_get_registry(hass) + await device_registry.async_load(hass) + registry = device_registry.async_get(hass) assert len(registry.devices) == 1 assert len(registry.deleted_devices) == 1 @@ -687,20 +688,6 @@ async def test_update_remove_config_entries(hass, registry, update_events): assert update_events[4]["device_id"] == entry3.id -async def test_loading_race_condition(hass): - """Test only one storage load called when concurrent loading occurred .""" - with patch( - "homeassistant.helpers.device_registry.DeviceRegistry.async_load" - ) as mock_load: - results = await asyncio.gather( - device_registry.async_get_registry(hass), - device_registry.async_get_registry(hass), - ) - - mock_load.assert_called_once_with() - assert results[0] == results[1] - - async def test_update_sw_version(registry): """Verify that we can update software version of a device.""" entry = registry.async_get_or_create( @@ -798,10 +785,16 @@ async def test_cleanup_startup(hass): assert len(mock_call.mock_calls) == 1 +@pytest.mark.parametrize("load_registries", [False]) async def test_cleanup_entity_registry_change(hass): - """Test we run a cleanup when entity registry changes.""" - await device_registry.async_get_registry(hass) - ent_reg = await entity_registry.async_get_registry(hass) + """Test we run a cleanup when entity registry changes. + + Don't pre-load the registries as the debouncer will then not be waiting for + EVENT_ENTITY_REGISTRY_UPDATED events. + """ + await device_registry.async_load(hass) + await entity_registry.async_load(hass) + ent_reg = entity_registry.async_get(hass) with patch( "homeassistant.helpers.device_registry.Debouncer.async_call" diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index b176f7022d53..71cfb3315918 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -1,6 +1,4 @@ """Tests for the Entity Registry.""" -import asyncio -import unittest.mock from unittest.mock import patch import pytest @@ -219,6 +217,7 @@ def test_is_registered(registry): assert not registry.async_is_registered("light.non_existing") +@pytest.mark.parametrize("load_registries", [False]) async def test_loading_extra_values(hass, hass_storage): """Test we load extra data from the registry.""" hass_storage[entity_registry.STORAGE_KEY] = { @@ -258,7 +257,8 @@ async def test_loading_extra_values(hass, hass_storage): }, } - registry = await entity_registry.async_get_registry(hass) + await entity_registry.async_load(hass) + registry = entity_registry.async_get(hass) assert len(registry.entities) == 4 @@ -350,6 +350,7 @@ async def test_removing_area_id(registry): assert entry_w_area != entry_wo_area +@pytest.mark.parametrize("load_registries", [False]) async def test_migration(hass): """Test migration from old data to new.""" mock_config = MockConfigEntry(domain="test-platform", entry_id="test-config-id") @@ -366,7 +367,8 @@ async def test_migration(hass): with patch("os.path.isfile", return_value=True), patch("os.remove"), patch( "homeassistant.helpers.entity_registry.load_yaml", return_value=old_conf ): - registry = await entity_registry.async_get_registry(hass) + await entity_registry.async_load(hass) + registry = entity_registry.async_get(hass) assert registry.async_is_registered("light.kitchen") entry = registry.async_get_or_create( @@ -427,20 +429,6 @@ async def test_loading_invalid_entity_id(hass, hass_storage): assert valid_entity_id(entity_invalid_start.entity_id) -async def test_loading_race_condition(hass): - """Test only one storage load called when concurrent loading occurred .""" - with unittest.mock.patch( - "homeassistant.helpers.entity_registry.EntityRegistry.async_load" - ) as mock_load: - results = await asyncio.gather( - entity_registry.async_get_registry(hass), - entity_registry.async_get_registry(hass), - ) - - mock_load.assert_called_once_with() - assert results[0] == results[1] - - async def test_update_entity_unique_id(registry): """Test entity's unique_id is updated.""" mock_config = MockConfigEntry(domain="light", entry_id="mock-id-1") @@ -794,7 +782,7 @@ async def test_disable_device_disables_entities(hass, registry): async def test_disabled_entities_excluded_from_entity_list(hass, registry): - """Test that disabled entities are exclduded from async_entries_for_device.""" + """Test that disabled entities are excluded from async_entries_for_device.""" device_registry = mock_device_registry(hass) config_entry = MockConfigEntry(domain="light") diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index fc653c25d0ba..c035f6f1d1d7 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -71,6 +71,7 @@ async def test_load_hassio(hass): assert bootstrap._get_domains(hass, {}) == {"hassio"} +@pytest.mark.parametrize("load_registries", [False]) async def test_empty_setup(hass): """Test an empty set up loads the core.""" await bootstrap.async_from_config_dict({}, hass) @@ -91,6 +92,7 @@ async def test_core_failure_loads_safe_mode(hass, caplog): assert "group" not in hass.config.components +@pytest.mark.parametrize("load_registries", [False]) async def test_setting_up_config(hass): """Test we set up domains in config.""" await bootstrap._async_set_up_integrations( @@ -100,6 +102,7 @@ async def test_setting_up_config(hass): assert "group" in hass.config.components +@pytest.mark.parametrize("load_registries", [False]) async def test_setup_after_deps_all_present(hass): """Test after_dependencies when all present.""" order = [] @@ -144,6 +147,7 @@ async def test_setup_after_deps_all_present(hass): assert order == ["logger", "root", "first_dep", "second_dep"] +@pytest.mark.parametrize("load_registries", [False]) async def test_setup_after_deps_in_stage_1_ignored(hass): """Test after_dependencies are ignored in stage 1.""" # This test relies on this @@ -190,6 +194,7 @@ async def test_setup_after_deps_in_stage_1_ignored(hass): assert order == ["cloud", "an_after_dep", "normal_integration"] +@pytest.mark.parametrize("load_registries", [False]) async def test_setup_after_deps_via_platform(hass): """Test after_dependencies set up via platform.""" order = [] @@ -239,6 +244,7 @@ async def test_setup_after_deps_via_platform(hass): assert order == ["after_dep_of_platform_int", "platform_int"] +@pytest.mark.parametrize("load_registries", [False]) async def test_setup_after_deps_not_trigger_load(hass): """Test after_dependencies does not trigger loading it.""" order = [] @@ -277,6 +283,7 @@ async def test_setup_after_deps_not_trigger_load(hass): assert "second_dep" in hass.config.components +@pytest.mark.parametrize("load_registries", [False]) async def test_setup_after_deps_not_present(hass): """Test after_dependencies when referenced integration doesn't exist.""" order = []