Migrate restore_state helper to use registry loading pattern (#93773)

* Migrate restore_state helper to use registry loading pattern

As more entities have started using restore_state over time, it
has become a startup bottleneck as each entity being added is
creating a task to load restore state data that is already loaded
since it is a singleton

We now use the same pattern as the registry helpers

* fix refactoring error -- guess I am tired

* fixes

* fix tests

* fix more

* fix more

* fix zha tests

* fix zha tests

* comments

* fix error

* add missing coverage

* s/DATA_RESTORE_STATE_TASK/DATA_RESTORE_STATE/g
This commit is contained in:
J. Nick Koston 2023-05-30 20:48:17 -05:00 committed by GitHub
parent b91c6911d9
commit fba826ae9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 147 additions and 86 deletions

View file

@ -32,6 +32,7 @@ from .helpers import (
entity_registry,
issue_registry,
recorder,
restore_state,
template,
)
from .helpers.dispatcher import async_dispatcher_send
@ -248,6 +249,7 @@ async def load_registries(hass: core.HomeAssistant) -> None:
issue_registry.async_load(hass),
hass.async_add_executor_job(_cache_uname_processor),
template.async_load_custom_templates(hass),
restore_state.async_load(hass),
)

View file

@ -2,7 +2,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from datetime import datetime, timedelta
import logging
from typing import Any, cast
@ -18,10 +17,9 @@ from . import start
from .entity import Entity
from .event import async_track_time_interval
from .json import JSONEncoder
from .singleton import singleton
from .storage import Store
DATA_RESTORE_STATE_TASK = "restore_state_task"
DATA_RESTORE_STATE = "restore_state"
_LOGGER = logging.getLogger(__name__)
@ -96,31 +94,25 @@ class StoredState:
)
async def async_load(hass: HomeAssistant) -> None:
"""Load the restore state task."""
hass.data[DATA_RESTORE_STATE] = await RestoreStateData.async_get_instance(hass)
@callback
def async_get(hass: HomeAssistant) -> RestoreStateData:
"""Get the restore state data helper."""
return cast(RestoreStateData, hass.data[DATA_RESTORE_STATE])
class RestoreStateData:
"""Helper class for managing the helper saved data."""
@staticmethod
@singleton(DATA_RESTORE_STATE_TASK)
async def async_get_instance(hass: HomeAssistant) -> RestoreStateData:
"""Get the singleton instance of this data helper."""
"""Get the instance of this data helper."""
data = RestoreStateData(hass)
try:
stored_states = await data.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None
if stored_states is None:
_LOGGER.debug("Not creating cache - no saved states found")
data.last_states = {}
else:
data.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(data.last_states))
await data.async_load()
async def hass_start(hass: HomeAssistant) -> None:
"""Start the restore state task."""
@ -133,8 +125,7 @@ class RestoreStateData:
@classmethod
async def async_save_persistent_states(cls, hass: HomeAssistant) -> None:
"""Dump states now."""
data = await cls.async_get_instance(hass)
await data.async_dump_states()
await async_get(hass).async_dump_states()
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class."""
@ -145,6 +136,25 @@ class RestoreStateData:
self.last_states: dict[str, StoredState] = {}
self.entities: dict[str, RestoreEntity] = {}
async def async_load(self) -> None:
"""Load the instance of this data helper."""
try:
stored_states = await self.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
stored_states = None
if stored_states is None:
_LOGGER.debug("Not creating cache - no saved states found")
self.last_states = {}
else:
self.last_states = {
item["state"]["entity_id"]: StoredState.from_dict(item)
for item in stored_states
if valid_entity_id(item["state"]["entity_id"])
}
_LOGGER.debug("Created cache with %s", list(self.last_states))
@callback
def async_get_stored_states(self) -> list[StoredState]:
"""Get the set of states which should be stored.
@ -288,21 +298,18 @@ class RestoreEntity(Entity):
async def async_internal_added_to_hass(self) -> None:
"""Register this entity as a restorable entity."""
_, data = await asyncio.gather(
super().async_internal_added_to_hass(),
RestoreStateData.async_get_instance(self.hass),
)
data.async_restore_entity_added(self)
await super().async_internal_added_to_hass()
async_get(self.hass).async_restore_entity_added(self)
async def async_internal_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
_, data = await asyncio.gather(
super().async_internal_will_remove_from_hass(),
RestoreStateData.async_get_instance(self.hass),
async_get(self.hass).async_restore_entity_removed(
self.entity_id, self.extra_restore_state_data
)
data.async_restore_entity_removed(self.entity_id, self.extra_restore_state_data)
await super().async_internal_will_remove_from_hass()
async def _async_get_restored_data(self) -> StoredState | None:
@callback
def _async_get_restored_data(self) -> StoredState | None:
"""Get data stored for an entity, if any."""
if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet
@ -310,20 +317,17 @@ class RestoreEntity(Entity):
"Cannot get last state. Entity not added to hass"
)
return None
data = await RestoreStateData.async_get_instance(self.hass)
if self.entity_id not in data.last_states:
return None
return data.last_states[self.entity_id]
return async_get(self.hass).last_states.get(self.entity_id)
async def async_get_last_state(self) -> State | None:
"""Get the entity state from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None:
if (stored_state := self._async_get_restored_data()) is None:
return None
return stored_state.state
async def async_get_last_extra_data(self) -> ExtraStoredData | None:
"""Get the entity specific state data from the previous run."""
if (stored_state := await self._async_get_restored_data()) is None:
if (stored_state := self._async_get_restored_data()) is None:
return None
return stored_state.extra_data

View file

@ -61,6 +61,7 @@ from homeassistant.helpers import (
issue_registry as ir,
recorder as recorder_helper,
restore_state,
restore_state as rs,
storage,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -251,12 +252,20 @@ async def async_test_home_assistant(event_loop, load_registries=True):
# Load the registries
entity.async_setup(hass)
if load_registries:
with patch("homeassistant.helpers.storage.Store.async_load", return_value=None):
with patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None
), patch(
"homeassistant.helpers.restore_state.RestoreStateData.async_setup_dump",
return_value=None,
), patch(
"homeassistant.helpers.restore_state.start.async_at_start"
):
await asyncio.gather(
ar.async_load(hass),
dr.async_load(hass),
er.async_load(hass),
ir.async_load(hass),
rs.async_load(hass),
)
hass.data[bootstrap.DATA_REGISTRIES_LOADED] = None
@ -1010,7 +1019,7 @@ def init_recorder_component(hass, add_config=None, db_url="sqlite://"):
def mock_restore_cache(hass: HomeAssistant, states: Sequence[State]) -> None:
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
key = restore_state.DATA_RESTORE_STATE
data = restore_state.RestoreStateData(hass)
now = dt_util.utcnow()
@ -1037,7 +1046,7 @@ def mock_restore_cache_with_extra_data(
hass: HomeAssistant, states: Sequence[tuple[State, Mapping[str, Any]]]
) -> None:
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_STATE_TASK
key = restore_state.DATA_RESTORE_STATE
data = restore_state.RestoreStateData(hass)
now = dt_util.utcnow()
@ -1060,6 +1069,26 @@ def mock_restore_cache_with_extra_data(
hass.data[key] = data
async def async_mock_restore_state_shutdown_restart(
hass: HomeAssistant,
) -> restore_state.RestoreStateData:
"""Mock shutting down and saving restore state and restoring."""
data = restore_state.async_get(hass)
await data.async_dump_states()
await async_mock_load_restore_state_from_storage(hass)
return data
async def async_mock_load_restore_state_from_storage(
hass: HomeAssistant,
) -> None:
"""Mock loading restore state from storage.
hass_storage must already be mocked.
"""
await restore_state.async_get(hass).async_load()
class MockEntity(entity.Entity):
"""Mock Entity class."""

View file

@ -34,7 +34,10 @@ from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from tests.common import mock_restore_cache_with_extra_data
from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)
class MockDefaultNumberEntity(NumberEntity):
@ -635,7 +638,7 @@ async def test_restore_number_save_state(
await hass.async_block_till_done()
# Trigger saving state
await hass.async_stop()
await async_mock_restore_state_shutdown_restart(hass)
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]

View file

@ -35,7 +35,10 @@ from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from tests.common import mock_restore_cache_with_extra_data
from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)
@pytest.mark.parametrize(
@ -397,7 +400,7 @@ async def test_restore_sensor_save_state(
await hass.async_block_till_done()
# Trigger saving state
await hass.async_stop()
await async_mock_restore_state_shutdown_restart(hass)
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]

View file

@ -20,7 +20,10 @@ from homeassistant.core import HomeAssistant, ServiceCall, State
from homeassistant.helpers.restore_state import STORAGE_KEY as RESTORE_STATE_KEY
from homeassistant.setup import async_setup_component
from tests.common import mock_restore_cache_with_extra_data
from tests.common import (
async_mock_restore_state_shutdown_restart,
mock_restore_cache_with_extra_data,
)
class MockTextEntity(TextEntity):
@ -141,7 +144,7 @@ async def test_restore_number_save_state(
await hass.async_block_till_done()
# Trigger saving state
await hass.async_stop()
await async_mock_restore_state_shutdown_restart(hass)
assert len(hass_storage[RESTORE_STATE_KEY]["data"]) == 1
state = hass_storage[RESTORE_STATE_KEY]["data"][0]["state"]

View file

@ -47,11 +47,7 @@ from homeassistant.const import (
from homeassistant.core import Context, CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.restore_state import (
DATA_RESTORE_STATE_TASK,
RestoreStateData,
StoredState,
)
from homeassistant.helpers.restore_state import StoredState, async_get
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
@ -838,12 +834,9 @@ async def test_restore_idle(hass: HomeAssistant) -> None:
utc_now,
)
data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()
entity = Timer.from_storage(
{
@ -878,12 +871,9 @@ async def test_restore_paused(hass: HomeAssistant) -> None:
utc_now,
)
data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()
entity = Timer.from_storage(
{
@ -922,12 +912,9 @@ async def test_restore_active_resume(hass: HomeAssistant) -> None:
utc_now,
)
data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()
entity = Timer.from_storage(
{
@ -973,12 +960,9 @@ async def test_restore_active_finished_outside_grace(hass: HomeAssistant) -> Non
utc_now,
)
data = await RestoreStateData.async_get_instance(hass)
await hass.async_block_till_done()
data = async_get(hass)
await data.store.async_save([stored_state.as_dict()])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
await data.async_load()
entity = Timer.from_storage(
{

View file

@ -21,6 +21,8 @@ from .common import (
)
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from tests.common import async_mock_load_restore_state_from_storage
DEVICE_IAS = {
1: {
SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID,
@ -186,6 +188,7 @@ async def test_binary_sensor_migration_not_migrated(
entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone"
core_rs(entity_id, state=restored_state, attributes={}) # migration sensor state
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(DEVICE_IAS)
zha_device = await zha_device_restored(zigpy_device)
@ -208,6 +211,7 @@ async def test_binary_sensor_migration_already_migrated(
entity_id = "binary_sensor.fakemanufacturer_fakemodel_iaszone"
core_rs(entity_id, state=STATE_OFF, attributes={"migrated_to_cache": True})
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(DEVICE_IAS)
@ -243,6 +247,7 @@ async def test_onoff_binary_sensor_restore_state(
entity_id = "binary_sensor.fakemanufacturer_fakemodel_opening"
core_rs(entity_id, state=restored_state, attributes={})
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(DEVICE_ONOFF)
zha_device = await zha_device_restored(zigpy_device)

View file

@ -26,6 +26,8 @@ from homeassistant.util import dt as dt_util
from .common import async_enable_traffic, find_entity_id, send_attributes_report
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
from tests.common import async_mock_load_restore_state_from_storage
@pytest.fixture(autouse=True)
def select_select_only():
@ -176,6 +178,7 @@ async def test_select_restore_state(
entity_id = "select.fakemanufacturer_fakemodel_default_siren_tone"
core_rs(entity_id, state="Burglar")
await async_mock_load_restore_state_from_storage(hass)
zigpy_device = zigpy_device_mock(
{

View file

@ -47,6 +47,8 @@ from .common import (
)
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from tests.common import async_mock_load_restore_state_from_storage
ENTITY_ID_PREFIX = "sensor.fakemanufacturer_fakemodel_{}"
@ -530,6 +532,7 @@ def core_rs(hass_storage):
],
)
async def test_temp_uom(
hass: HomeAssistant,
uom,
raw_temp,
expected,
@ -544,6 +547,7 @@ async def test_temp_uom(
entity_id = "sensor.fake1026_fakemodel1026_004f3202_temperature"
if restore:
core_rs(entity_id, uom, state=(expected - 2))
await async_mock_load_restore_state_from_storage(hass)
hass = await hass_ms(
CONF_UNIT_SYSTEM_METRIC

View file

@ -8,11 +8,13 @@ from homeassistant.core import CoreState, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.restore_state import (
DATA_RESTORE_STATE_TASK,
DATA_RESTORE_STATE,
STORAGE_KEY,
RestoreEntity,
RestoreStateData,
StoredState,
async_get,
async_load,
)
from homeassistant.util import dt as dt_util
@ -28,12 +30,25 @@ async def test_caching_data(hass: HomeAssistant) -> None:
StoredState(State("input_boolean.b2", "on"), None, now),
]
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
hass.data.pop(DATA_RESTORE_STATE)
with patch(
"homeassistant.helpers.restore_state.Store.async_load",
side_effect=HomeAssistantError,
):
# Failure to load should not be treated as fatal
await async_load(hass)
data = async_get(hass)
assert data.last_states == {}
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity()
entity.hass = hass
@ -55,12 +70,14 @@ async def test_caching_data(hass: HomeAssistant) -> None:
async def test_periodic_write(hass: HomeAssistant) -> None:
"""Test that we write periodiclly but not after stop."""
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
hass.data.pop(DATA_RESTORE_STATE)
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity()
entity.hass = hass
@ -101,12 +118,14 @@ async def test_periodic_write(hass: HomeAssistant) -> None:
async def test_save_persistent_states(hass: HomeAssistant) -> None:
"""Test that we cancel the currently running job, save the data, and verify the perdiodic job continues."""
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([])
# Emulate a fresh load
hass.data.pop(DATA_RESTORE_STATE_TASK)
hass.data.pop(DATA_RESTORE_STATE)
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity()
entity.hass = hass
@ -166,13 +185,15 @@ async def test_hass_starting(hass: HomeAssistant) -> None:
StoredState(State("input_boolean.b2", "on"), None, now),
]
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
await hass.async_block_till_done()
await data.store.async_save([state.as_dict() for state in stored_states])
# Emulate a fresh load
hass.state = CoreState.not_running
hass.data.pop(DATA_RESTORE_STATE_TASK)
hass.data.pop(DATA_RESTORE_STATE)
await async_load(hass)
data = async_get(hass)
entity = RestoreEntity()
entity.hass = hass
@ -223,7 +244,7 @@ async def test_dump_data(hass: HomeAssistant) -> None:
entity.entity_id = "input_boolean.b1"
await entity.async_internal_added_to_hass()
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
now = dt_util.utcnow()
data.last_states = {
"input_boolean.b0": StoredState(State("input_boolean.b0", "off"), None, now),
@ -297,7 +318,7 @@ async def test_dump_error(hass: HomeAssistant) -> None:
entity.entity_id = "input_boolean.b1"
await entity.async_internal_added_to_hass()
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
with patch(
"homeassistant.helpers.restore_state.Store.async_save",
@ -335,7 +356,7 @@ async def test_state_saved_on_remove(hass: HomeAssistant) -> None:
"input_boolean.b0", "on", {"complicated": {"value": {1, 2, now}}}
)
data = await RestoreStateData.async_get_instance(hass)
data = async_get(hass)
# No last states should currently be saved
assert not data.last_states