Add setup function to the component loader (#98148)

* Add setup function to the component loader

* Update test

* Setup the loader in safe mode and in check_config script
This commit is contained in:
Erik Montnemery 2023-08-15 10:59:42 +02:00 committed by GitHub
parent b1e5b3be34
commit 3b9d6f2dde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 48 deletions

View file

@ -134,6 +134,7 @@ async def async_setup_hass(
_LOGGER.info("Config directory: %s", runtime_config.config_dir)
loader.async_setup(hass)
config_dict = None
basic_setup_success = False
@ -185,6 +186,8 @@ async def async_setup_hass(
hass.config.internal_url = old_config.internal_url
hass.config.external_url = old_config.external_url
hass.config.config_dir = old_config.config_dir
# Setup loader cache after the config dir has been set
loader.async_setup(hass)
if safe_mode:
_LOGGER.info("Starting in safe mode")

View file

@ -166,6 +166,13 @@ class Manifest(TypedDict, total=False):
loggers: list[str]
def async_setup(hass: HomeAssistant) -> None:
"""Set up the necessary data structures."""
_async_mount_config_dir(hass)
hass.data[DATA_COMPONENTS] = {}
hass.data[DATA_INTEGRATIONS] = {}
def manifest_from_legacy_module(domain: str, module: ModuleType) -> Manifest:
"""Generate a manifest from a legacy module."""
return {
@ -802,9 +809,7 @@ class Integration:
def get_component(self) -> ComponentProtocol:
"""Return the component."""
cache: dict[str, ComponentProtocol] = self.hass.data.setdefault(
DATA_COMPONENTS, {}
)
cache: dict[str, ComponentProtocol] = self.hass.data[DATA_COMPONENTS]
if self.domain in cache:
return cache[self.domain]
@ -824,7 +829,7 @@ class Integration:
def get_platform(self, platform_name: str) -> ModuleType:
"""Return a platform for an integration."""
cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {})
cache: dict[str, ModuleType] = self.hass.data[DATA_COMPONENTS]
full_name = f"{self.domain}.{platform_name}"
if full_name in cache:
return cache[full_name]
@ -883,11 +888,7 @@ async def async_get_integrations(
hass: HomeAssistant, domains: Iterable[str]
) -> dict[str, Integration | Exception]:
"""Get integrations."""
if (cache := hass.data.get(DATA_INTEGRATIONS)) is None:
if not _async_mount_config_dir(hass):
return {domain: IntegrationNotFound(domain) for domain in domains}
cache = hass.data[DATA_INTEGRATIONS] = {}
cache = hass.data[DATA_INTEGRATIONS]
results: dict[str, Integration | Exception] = {}
needed: dict[str, asyncio.Future[None]] = {}
in_progress: dict[str, asyncio.Future[None]] = {}
@ -993,10 +994,7 @@ def _load_file(
comp_or_platform
]
if (cache := hass.data.get(DATA_COMPONENTS)) is None:
if not _async_mount_config_dir(hass):
return None
cache = hass.data[DATA_COMPONENTS] = {}
cache = hass.data[DATA_COMPONENTS]
for path in (f"{base}.{comp_or_platform}" for base in base_paths):
try:
@ -1066,7 +1064,7 @@ class Components:
def __getattr__(self, comp_name: str) -> ModuleWrapper:
"""Fetch a component."""
# Test integration cache
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name)
integration = self._hass.data[DATA_INTEGRATIONS].get(comp_name)
if isinstance(integration, Integration):
component: ComponentProtocol | None = integration.get_component()

View file

@ -11,7 +11,7 @@ import os
from typing import Any
from unittest.mock import patch
from homeassistant import core
from homeassistant import core, loader
from homeassistant.config import get_default_config_dir
from homeassistant.config_entries import ConfigEntries
from homeassistant.exceptions import HomeAssistantError
@ -232,6 +232,7 @@ def check(config_dir, secrets=False):
async def async_check_config(config_dir):
"""Check the HA config."""
hass = core.HomeAssistant()
loader.async_setup(hass)
hass.config.config_dir = config_dir
hass.config_entries = ConfigEntries(hass, {})
await ar.async_load(hass)

View file

@ -256,6 +256,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
# Load the registries
entity.async_setup(hass)
loader.async_setup(hass)
if load_registries:
with patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None
@ -1339,16 +1340,10 @@ def mock_integration(
integration._import_platform = mock_import_platform
_LOGGER.info("Adding mock integration: %s", module.DOMAIN)
integration_cache = hass.data.get(loader.DATA_INTEGRATIONS)
if integration_cache is None:
integration_cache = hass.data[loader.DATA_INTEGRATIONS] = {}
loader._async_mount_config_dir(hass)
integration_cache = hass.data[loader.DATA_INTEGRATIONS]
integration_cache[module.DOMAIN] = integration
module_cache = hass.data.get(loader.DATA_COMPONENTS)
if module_cache is None:
module_cache = hass.data[loader.DATA_COMPONENTS] = {}
loader._async_mount_config_dir(hass)
module_cache = hass.data[loader.DATA_COMPONENTS]
module_cache[module.DOMAIN] = module
return integration
@ -1374,15 +1369,8 @@ def mock_platform(
platform_path is in form hue.config_flow.
"""
domain = platform_path.split(".")[0]
integration_cache = hass.data.get(loader.DATA_INTEGRATIONS)
if integration_cache is None:
integration_cache = hass.data[loader.DATA_INTEGRATIONS] = {}
loader._async_mount_config_dir(hass)
module_cache = hass.data.get(loader.DATA_COMPONENTS)
if module_cache is None:
module_cache = hass.data[loader.DATA_COMPONENTS] = {}
loader._async_mount_config_dir(hass)
integration_cache = hass.data[loader.DATA_INTEGRATIONS]
module_cache = hass.data[loader.DATA_COMPONENTS]
if domain not in integration_cache:
mock_integration(hass, MockModule(domain))

View file

@ -304,7 +304,7 @@ async def test_websocket_get_action_capabilities(
return {"extra_fields": vol.Schema({vol.Optional("code"): str})}
return {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"]
module.async_get_action_capabilities = _async_get_action_capabilities
@ -406,7 +406,7 @@ async def test_websocket_get_action_capabilities_bad_action(
await async_setup_component(hass, "device_automation", {})
expected_capabilities = {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"]
module.async_get_action_capabilities = Mock(
side_effect=InvalidDeviceAutomationConfig
@ -459,7 +459,7 @@ async def test_websocket_get_condition_capabilities(
"""List condition capabilities."""
return await toggle_entity.async_get_condition_capabilities(hass, config)
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"]
module.async_get_condition_capabilities = _async_get_condition_capabilities
@ -569,7 +569,7 @@ async def test_websocket_get_condition_capabilities_bad_condition(
await async_setup_component(hass, "device_automation", {})
expected_capabilities = {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"]
module.async_get_condition_capabilities = Mock(
side_effect=InvalidDeviceAutomationConfig
@ -747,7 +747,7 @@ async def test_websocket_get_trigger_capabilities(
"""List trigger capabilities."""
return await toggle_entity.async_get_trigger_capabilities(hass, config)
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_get_trigger_capabilities = _async_get_trigger_capabilities
@ -857,7 +857,7 @@ async def test_websocket_get_trigger_capabilities_bad_trigger(
await async_setup_component(hass, "device_automation", {})
expected_capabilities = {}
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_get_trigger_capabilities = Mock(
side_effect=InvalidDeviceAutomationConfig
@ -912,7 +912,7 @@ async def test_automation_with_device_action(
) -> None:
"""Test automation with a device action."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"]
module.async_call_action_from_config = AsyncMock()
@ -949,7 +949,7 @@ async def test_automation_with_dynamically_validated_action(
) -> None:
"""Test device automation with an action which is dynamically validated."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"]
module.async_validate_action_config = AsyncMock()
@ -1003,7 +1003,7 @@ async def test_automation_with_device_condition(
) -> None:
"""Test automation with a device condition."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"]
module.async_condition_from_config = Mock()
@ -1037,7 +1037,7 @@ async def test_automation_with_dynamically_validated_condition(
) -> None:
"""Test device automation with a condition which is dynamically validated."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_condition"]
module.async_validate_condition_config = AsyncMock()
@ -1102,7 +1102,7 @@ async def test_automation_with_device_trigger(
) -> None:
"""Test automation with a device trigger."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_attach_trigger = AsyncMock()
@ -1136,7 +1136,7 @@ async def test_automation_with_dynamically_validated_trigger(
) -> None:
"""Test device automation with a trigger which is dynamically validated."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_attach_trigger = AsyncMock()
module.async_validate_trigger_config = AsyncMock(wraps=lambda hass, config: config)
@ -1457,7 +1457,7 @@ async def test_automation_with_unknown_device(
) -> None:
"""Test device automation with a trigger with an unknown device."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_validate_trigger_config = AsyncMock()
@ -1492,7 +1492,7 @@ async def test_automation_with_device_wrong_domain(
) -> None:
"""Test device automation where the device doesn't have the right config entry."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_validate_trigger_config = AsyncMock()
@ -1534,7 +1534,7 @@ async def test_automation_with_device_component_not_loaded(
) -> None:
"""Test device automation where the device's config entry is not loaded."""
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_trigger"]
module.async_validate_trigger_config = AsyncMock()
module.async_attach_trigger = AsyncMock()

View file

@ -1810,7 +1810,7 @@ async def test_execute_script_with_dynamically_validated_action(
ws_client = await hass_ws_client(hass)
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
module_cache = hass.data[loader.DATA_COMPONENTS]
module = module_cache["fake_integration.device_action"]
module.async_call_action_from_config = AsyncMock()
module.async_validate_action_config = AsyncMock(