Collection of typing improvements in common test helpers (#85509)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Franck Nijhof 2023-01-13 15:12:11 +01:00 committed by GitHub
parent 6baa905448
commit 4c2b20db68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from collections import OrderedDict
from collections.abc import Awaitable, Callable, Collection
from collections.abc import Awaitable, Callable, Collection, Mapping, Sequence
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
import functools as ft
@ -16,13 +16,13 @@ import threading
import time
from time import monotonic
import types
from typing import Any
from typing import Any, NoReturn
from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
import voluptuous as vol
from homeassistant import auth, bootstrap, config_entries, core as ha, loader
from homeassistant import auth, bootstrap, config_entries, loader
from homeassistant.auth import (
auth_store,
models as auth_models,
@ -42,7 +42,15 @@ from homeassistant.const import (
STATE_OFF,
STATE_ON,
)
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, ServiceCall, State
from homeassistant.core import (
BLOCK_LOG_TIMEOUT,
CoreState,
Event,
HomeAssistant,
ServiceCall,
State,
callback,
)
from homeassistant.helpers import (
area_registry,
device_registry,
@ -57,7 +65,7 @@ from homeassistant.helpers import (
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.typing import ConfigType
from homeassistant.helpers.typing import ConfigType, StateType
from homeassistant.setup import setup_component
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as date_util
@ -161,7 +169,7 @@ def get_test_home_assistant():
# pylint: disable=protected-access
async def async_test_home_assistant(event_loop, load_registries=True):
"""Return a Home Assistant object pointing at test config dir."""
hass = ha.HomeAssistant()
hass = HomeAssistant()
store = auth_store.AuthStore(hass)
hass.auth = auth.AuthManager(hass, store, {}, {})
ensure_auth_manager_loaded(hass.auth)
@ -308,7 +316,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
await hass.async_block_till_done()
hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None
hass.state = ha.CoreState.running
hass.state = CoreState.running
# Mock async_start
orig_start = hass.async_start
@ -321,7 +329,7 @@ async def async_test_home_assistant(event_loop, load_registries=True):
hass.async_start = mock_async_start
@ha.callback
@callback
def clear_instance(event):
"""Clear global instance."""
INSTANCES.remove(hass)
@ -337,7 +345,7 @@ def async_mock_service(
"""Set up a fake service & return a calls log list to this service."""
calls = []
@ha.callback
@callback
def mock_service_log(call): # pylint: disable=unnecessary-lambda
"""Mock service call."""
calls.append(call)
@ -350,7 +358,7 @@ def async_mock_service(
mock_service = threadsafe_callback_factory(async_mock_service)
@ha.callback
@callback
def async_mock_intent(hass, intent_typ):
"""Set up a fake intent handler."""
intents = []
@ -368,7 +376,7 @@ def async_mock_intent(hass, intent_typ):
return intents
@ha.callback
@callback
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
"""Fire the MQTT message."""
# Local import to avoid processing MQTT modules when running a testcase
@ -384,7 +392,7 @@ def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
@ha.callback
@callback
def async_fire_time_changed_exact(
hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
) -> None:
@ -403,7 +411,7 @@ def async_fire_time_changed_exact(
_async_fire_time_changed(hass, utc_datetime, fire_all)
@ha.callback
@callback
def async_fire_time_changed(
hass: HomeAssistant, datetime_: datetime | None = None, fire_all: bool = False
) -> None:
@ -432,7 +440,7 @@ def async_fire_time_changed(
_async_fire_time_changed(hass, utc_datetime, fire_all)
@ha.callback
@callback
def _async_fire_time_changed(
hass: HomeAssistant, utc_datetime: datetime | None, fire_all: bool
) -> None:
@ -491,7 +499,7 @@ def mock_state_change_event(
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
@ha.callback
@callback
def mock_component(hass: HomeAssistant, component: str) -> None:
"""Mock a component is setup."""
if component in hass.config.components:
@ -624,7 +632,7 @@ async def register_auth_provider(
return provider
@ha.callback
@callback
def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded."""
store = auth_mgr._store
@ -995,7 +1003,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"):
)
def mock_restore_cache(hass, states):
def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
data = restore_state.RestoreStateData(hass)
@ -1020,7 +1028,9 @@ def mock_restore_cache(hass, states):
hass.data[key] = data
def mock_restore_cache_with_extra_data(hass, states):
def mock_restore_cache_with_extra_data(
hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]]
) -> None:
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
data = restore_state.RestoreStateData(hass)
@ -1048,7 +1058,7 @@ def mock_restore_cache_with_extra_data(hass, states):
class MockEntity(entity.Entity):
"""Mock Entity class."""
def __init__(self, **values):
def __init__(self, **values: Any) -> None:
"""Initialize an entity."""
self._values = values
@ -1056,86 +1066,86 @@ class MockEntity(entity.Entity):
self.entity_id = values["entity_id"]
@property
def available(self):
def available(self) -> bool:
"""Return True if entity is available."""
return self._handle("available")
@property
def capability_attributes(self):
def capability_attributes(self) -> Mapping[str, Any] | None:
"""Info about capabilities."""
return self._handle("capability_attributes")
@property
def device_class(self):
def device_class(self) -> str | None:
"""Info how device should be classified."""
return self._handle("device_class")
@property
def device_info(self):
def device_info(self) -> entity.DeviceInfo | None:
"""Info how it links to a device."""
return self._handle("device_info")
@property
def entity_category(self):
def entity_category(self) -> entity.EntityCategory | None:
"""Return the entity category."""
return self._handle("entity_category")
@property
def has_entity_name(self):
def has_entity_name(self) -> bool:
"""Return the has_entity_name name flag."""
return self._handle("has_entity_name")
@property
def entity_registry_enabled_default(self):
def entity_registry_enabled_default(self) -> bool:
"""Return if the entity should be enabled when first added to the entity registry."""
return self._handle("entity_registry_enabled_default")
@property
def entity_registry_visible_default(self):
def entity_registry_visible_default(self) -> bool:
"""Return if the entity should be visible when first added to the entity registry."""
return self._handle("entity_registry_visible_default")
@property
def icon(self):
def icon(self) -> str | None:
"""Return the suggested icon."""
return self._handle("icon")
@property
def name(self):
def name(self) -> str | None:
"""Return the name of the entity."""
return self._handle("name")
@property
def should_poll(self):
def should_poll(self) -> bool:
"""Return the ste of the polling."""
return self._handle("should_poll")
@property
def state(self):
def state(self) -> StateType:
"""Return the state of the entity."""
return self._handle("state")
@property
def supported_features(self):
def supported_features(self) -> int | None:
"""Info about supported features."""
return self._handle("supported_features")
@property
def translation_key(self):
def translation_key(self) -> str | None:
"""Return the translation key."""
return self._handle("translation_key")
@property
def unique_id(self):
def unique_id(self) -> str | None:
"""Return the unique ID of the entity."""
return self._handle("unique_id")
@property
def unit_of_measurement(self):
def unit_of_measurement(self) -> str | None:
"""Info on the units the entity state is in."""
return self._handle("unit_of_measurement")
def _handle(self, attr):
def _handle(self, attr: str) -> Any:
"""Return attribute value."""
if attr in self._values:
return self._values[attr]
@ -1202,7 +1212,7 @@ def mock_storage(data=None):
yield data
async def flush_store(store):
async def flush_store(store: storage.Store) -> None:
"""Make sure all delayed writes of a store are written."""
if store._data is None:
return
@ -1212,12 +1222,14 @@ async def flush_store(store):
await store._async_handle_write_data()
async def get_system_health_info(hass, domain):
async def get_system_health_info(hass: HomeAssistant, domain: str) -> dict[str, Any]:
"""Get system health info."""
return await hass.data["system_health"][domain].info_callback(hass)
def mock_integration(hass, module, built_in=True):
def mock_integration(
hass: HomeAssistant, module: MockModule, built_in: bool = True
) -> loader.Integration:
"""Mock an integration."""
integration = loader.Integration(
hass,
@ -1228,7 +1240,7 @@ def mock_integration(hass, module, built_in=True):
module.mock_manifest(),
)
def mock_import_platform(platform_name):
def mock_import_platform(platform_name: str) -> NoReturn:
raise ImportError(
f"Mocked unable to import platform '{platform_name}'",
name=f"{integration.pkg_path}.{platform_name}",
@ -1243,7 +1255,9 @@ def mock_integration(hass, module, built_in=True):
return integration
def mock_entity_platform(hass, platform_path, module):
def mock_entity_platform(
hass: HomeAssistant, platform_path: str, module: MockPlatform | None
) -> None:
"""Mock a entity platform.
platform_path is in form light.hue. Will create platform
@ -1253,7 +1267,9 @@ def mock_entity_platform(hass, platform_path, module):
mock_platform(hass, f"{platform_name}.{domain}", module)
def mock_platform(hass, platform_path, module=None):
def mock_platform(
hass: HomeAssistant, platform_path: str, module: Mock | MockPlatform | None = None
) -> None:
"""Mock a platform.
platform_path is in form hue.config_flow.
@ -1269,12 +1285,12 @@ def mock_platform(hass, platform_path, module=None):
module_cache[platform_path] = module or Mock()
def async_capture_events(hass, event_name):
def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]:
"""Create a helper that captures events."""
events = []
@ha.callback
def capture_events(event):
@callback
def capture_events(event: Event) -> None:
events.append(event)
hass.bus.async_listen(event_name, capture_events)
@ -1282,13 +1298,13 @@ def async_capture_events(hass, event_name):
return events
@ha.callback
def async_mock_signal(hass, signal):
@callback
def async_mock_signal(hass: HomeAssistant, signal: str) -> list[tuple[Any]]:
"""Catch all dispatches to a signal."""
calls = []
@ha.callback
def mock_signal_handler(*args):
@callback
def mock_signal_handler(*args: Any) -> None:
"""Mock service call."""
calls.append(args)
@ -1297,7 +1313,7 @@ def async_mock_signal(hass, signal):
return calls
def assert_lists_same(a, b):
def assert_lists_same(a: list[Any], b: list[Any]) -> None:
"""Compare two lists, ignoring order.
Check both that all items in a are in b and that all items in b are in a,
@ -1322,17 +1338,17 @@ class _HA_ANY:
_other = _SENTINEL
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Test equal."""
self._other = other
return True
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
"""Test not equal."""
self._other = other
return False
def __repr__(self):
def __repr__(self) -> str:
"""Return repr() other to not show up in pytest quality diffs."""
if self._other is _SENTINEL:
return "<ANY>"
@ -1342,7 +1358,7 @@ class _HA_ANY:
ANY = _HA_ANY()
def raise_contains_mocks(val):
def raise_contains_mocks(val: Any) -> None:
"""Raise for mocks."""
if isinstance(val, Mock):
raise ValueError