Add some typing to common test helpers (#80337)

This commit is contained in:
Franck Nijhof 2022-10-14 18:23:49 +02:00 committed by GitHub
parent 4ebf9df901
commit e3af2cb6b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -20,6 +20,7 @@ from typing import Any
from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
import voluptuous as vol
from homeassistant import auth, config_entries, core as ha, loader
from homeassistant.auth import (
@ -42,7 +43,7 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, ServiceCall, State
from homeassistant.helpers import (
area_registry,
device_registry,
@ -57,6 +58,7 @@ from homeassistant.helpers import (
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import setup_component
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as date_util
@ -328,7 +330,9 @@ async def async_test_home_assistant(loop, load_registries=True):
return hass
def async_mock_service(hass, domain, service, schema=None):
def async_mock_service(
hass: HomeAssistant, domain: str, service: str, schema: vol.Schema | None = None
) -> list[ServiceCall]:
"""Set up a fake service & return a calls log list to this service."""
calls = []
@ -417,18 +421,20 @@ def get_fixture_path(filename: str, integration: str | None = None) -> pathlib.P
if integration is None:
return pathlib.Path(__file__).parent.joinpath("fixtures", filename)
else:
return pathlib.Path(__file__).parent.joinpath(
"components", integration, "fixtures", filename
)
return pathlib.Path(__file__).parent.joinpath(
"components", integration, "fixtures", filename
)
def load_fixture(filename, integration=None):
def load_fixture(filename: str, integration: str | None = None) -> str:
"""Load a fixture."""
return get_fixture_path(filename, integration).read_text()
def mock_state_change_event(hass, new_state, old_state=None):
def mock_state_change_event(
hass: HomeAssistant, new_state: State, old_state: State | None = None
) -> None:
"""Mock state change envent."""
event_data = {"entity_id": new_state.entity_id, "new_state": new_state}
@ -439,7 +445,7 @@ def mock_state_change_event(hass, new_state, old_state=None):
@ha.callback
def mock_component(hass, component):
def mock_component(hass: HomeAssistant, component: str) -> None:
"""Mock a component is setup."""
if component in hass.config.components:
AssertionError(f"Integration {component} is already setup")
@ -447,7 +453,10 @@ def mock_component(hass, component):
hass.config.components.add(component)
def mock_registry(hass, mock_entries=None):
def mock_registry(
hass: HomeAssistant,
mock_entries: dict[str, entity_registry.RegistryEntry] | None = None,
) -> entity_registry.EntityRegistry:
"""Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass)
if mock_entries is None:
@ -460,7 +469,9 @@ def mock_registry(hass, mock_entries=None):
return registry
def mock_area_registry(hass, mock_entries=None):
def mock_area_registry(
hass: HomeAssistant, mock_entries: dict[str, area_registry.AreaEntry] | None = None
) -> area_registry.AreaRegistry:
"""Mock the Area Registry."""
registry = area_registry.AreaRegistry(hass)
registry.areas = mock_entries or OrderedDict()
@ -469,7 +480,10 @@ def mock_area_registry(hass, mock_entries=None):
return registry
def mock_device_registry(hass, mock_entries=None):
def mock_device_registry(
hass: HomeAssistant,
mock_entries: dict[str, device_registry.DeviceEntry] | None = None,
) -> device_registry.DeviceRegistry:
"""Mock the Device Registry."""
registry = device_registry.DeviceRegistry(hass)
registry.devices = device_registry.DeviceRegistryItems()
@ -545,7 +559,9 @@ class MockUser(auth_models.User):
self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
async def register_auth_provider(hass, config):
async def register_auth_provider(
hass: HomeAssistant, config: ConfigType
) -> auth_providers.AuthProvider:
"""Register an auth provider."""
provider = await auth_providers.auth_provider_from_config(
hass, hass.auth._store, config