Restore states through a JSON store instead of recorder (#17270)

* Restore states through a JSON store

* Accept entity_id directly in restore state helper

* Keep states stored between runs for a limited time

* Remove warning
This commit is contained in:
Adam Mills 2018-11-28 07:16:43 -05:00 committed by Paulus Schoutsen
parent a039c3209b
commit 5c3a4e3d10
46 changed files with 493 additions and 422 deletions

View file

@ -21,7 +21,7 @@ from homeassistant.const import (
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import track_point_in_time
import homeassistant.util.dt as dt_util
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -116,7 +116,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
)])
class ManualAlarm(alarm.AlarmControlPanel):
class ManualAlarm(alarm.AlarmControlPanel, RestoreEntity):
"""
Representation of an alarm status.
@ -310,7 +310,7 @@ class ManualAlarm(alarm.AlarmControlPanel):
async def async_added_to_hass(self):
"""Run when entity about to be added to hass."""
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
if state:
self._state = state.state
self._state_ts = state.last_updated

View file

@ -108,8 +108,7 @@ class MqttAlarm(MqttAvailability, MqttDiscoveryUpdate,
async def async_added_to_hass(self):
"""Subscribe mqtt events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):

View file

@ -21,7 +21,7 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import extract_domain_configs, script, condition
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util.dt import utcnow
import homeassistant.helpers.config_validation as cv
@ -182,7 +182,7 @@ async def async_setup(hass, config):
return True
class AutomationEntity(ToggleEntity):
class AutomationEntity(ToggleEntity, RestoreEntity):
"""Entity to show status of entity."""
def __init__(self, automation_id, name, async_attach_triggers, cond_func,
@ -227,12 +227,13 @@ class AutomationEntity(ToggleEntity):
async def async_added_to_hass(self) -> None:
"""Startup with initial state or previous state."""
await super().async_added_to_hass()
if self._initial_state is not None:
enable_automation = self._initial_state
_LOGGER.debug("Automation %s initial state %s from config "
"initial_state", self.entity_id, enable_automation)
else:
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
if state:
enable_automation = state.state == STATE_ON
self._last_triggered = state.attributes.get('last_triggered')
@ -291,6 +292,7 @@ class AutomationEntity(ToggleEntity):
async def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from HASS."""
await super().async_will_remove_from_hass()
await self.async_turn_off()
async def async_enable(self):

View file

@ -102,8 +102,7 @@ class MqttBinarySensor(MqttAvailability, MqttDiscoveryUpdate,
async def async_added_to_hass(self):
"""Subscribe mqtt events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):

View file

@ -23,7 +23,7 @@ from homeassistant.helpers import condition
from homeassistant.helpers.event import (
async_track_state_change, async_track_time_interval)
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -96,7 +96,7 @@ async def async_setup_platform(hass, config, async_add_entities,
precision)])
class GenericThermostat(ClimateDevice):
class GenericThermostat(ClimateDevice, RestoreEntity):
"""Representation of a Generic Thermostat device."""
def __init__(self, hass, name, heater_entity_id, sensor_entity_id,
@ -155,8 +155,9 @@ class GenericThermostat(ClimateDevice):
async def async_added_to_hass(self):
"""Run when entity about to be added."""
await super().async_added_to_hass()
# Check If we have an old state
old_state = await async_get_last_state(self.hass, self.entity_id)
old_state = await self.async_get_last_state()
if old_state is not None:
# If we have no initial temperature, restore
if self._target_temp is None:

View file

@ -221,8 +221,7 @@ class MqttClimate(MqttAvailability, MqttDiscoveryUpdate, ClimateDevice):
async def async_added_to_hass(self):
"""Handle being added to home assistant."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):

View file

@ -10,9 +10,8 @@ import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -86,7 +85,7 @@ async def async_setup(hass, config):
return True
class Counter(Entity):
class Counter(RestoreEntity):
"""Representation of a counter."""
def __init__(self, object_id, name, initial, restore, step, icon):
@ -128,10 +127,11 @@ class Counter(Entity):
async def async_added_to_hass(self):
"""Call when entity about to be added to Home Assistant."""
await super().async_added_to_hass()
# __init__ will set self._state to self._initial, only override
# if needed.
if self._restore:
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
if state is not None:
self._state = int(state.state)

View file

@ -205,8 +205,7 @@ class MqttCover(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
async def async_added_to_hass(self):
"""Subscribe MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):

View file

@ -22,9 +22,8 @@ from homeassistant.components.zone.zone import async_active_zone
from homeassistant.config import load_yaml_config_file, async_log_exception
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import GPSType, ConfigType, HomeAssistantType
import homeassistant.helpers.config_validation as cv
from homeassistant import util
@ -396,7 +395,7 @@ class DeviceTracker:
await asyncio.wait(tasks, loop=self.hass.loop)
class Device(Entity):
class Device(RestoreEntity):
"""Represent a tracked device."""
host_name = None # type: str
@ -564,7 +563,8 @@ class Device(Entity):
async def async_added_to_hass(self):
"""Add an entity."""
state = await async_get_last_state(self.hass, self.entity_id)
await super().async_added_to_hass()
state = await self.async_get_last_state()
if not state:
return
self._state = state.state

View file

@ -151,8 +151,7 @@ class MqttFan(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):

View file

@ -38,20 +38,6 @@ SIGNIFICANT_DOMAINS = ('thermostat', 'climate')
IGNORE_DOMAINS = ('zone', 'scene',)
def last_recorder_run(hass):
"""Retrieve the last closed recorder run from the database."""
from homeassistant.components.recorder.models import RecorderRuns
with session_scope(hass=hass) as session:
res = (session.query(RecorderRuns)
.filter(RecorderRuns.end.isnot(None))
.order_by(RecorderRuns.end.desc()).first())
if res is None:
return None
session.expunge(res)
return res
def get_significant_states(hass, start_time, end_time=None, entity_ids=None,
filters=None, include_start_time_state=True):
"""

View file

@ -15,7 +15,7 @@ from homeassistant.loader import bind_hass
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
DOMAIN = 'input_boolean'
@ -84,7 +84,7 @@ async def async_setup(hass, config):
return True
class InputBoolean(ToggleEntity):
class InputBoolean(ToggleEntity, RestoreEntity):
"""Representation of a boolean input."""
def __init__(self, object_id, name, initial, icon):
@ -117,10 +117,11 @@ class InputBoolean(ToggleEntity):
async def async_added_to_hass(self):
"""Call when entity about to be added to hass."""
# If not None, we got an initial value.
await super().async_added_to_hass()
if self._state is not None:
return
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
self._state = state and state.state == STATE_ON
async def async_turn_on(self, **kwargs):

View file

@ -11,9 +11,8 @@ import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
@ -97,7 +96,7 @@ async def async_setup(hass, config):
return True
class InputDatetime(Entity):
class InputDatetime(RestoreEntity):
"""Representation of a datetime input."""
def __init__(self, object_id, name, has_date, has_time, icon, initial):
@ -112,6 +111,7 @@ class InputDatetime(Entity):
async def async_added_to_hass(self):
"""Run when entity about to be added."""
await super().async_added_to_hass()
restore_val = None
# Priority 1: Initial State
@ -120,7 +120,7 @@ class InputDatetime(Entity):
# Priority 2: Old state
if restore_val is None:
old_state = await async_get_last_state(self.hass, self.entity_id)
old_state = await self.async_get_last_state()
if old_state is not None:
restore_val = old_state.state

View file

@ -11,9 +11,8 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_NAME, CONF_MODE)
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -123,7 +122,7 @@ async def async_setup(hass, config):
return True
class InputNumber(Entity):
class InputNumber(RestoreEntity):
"""Representation of a slider."""
def __init__(self, object_id, name, initial, minimum, maximum, step, icon,
@ -178,10 +177,11 @@ class InputNumber(Entity):
async def async_added_to_hass(self):
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
if self._current_value is not None:
return
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
value = state and float(state.state)
# Check against None because value can be 0

View file

@ -10,9 +10,8 @@ import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, CONF_ICON, CONF_NAME
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -116,7 +115,7 @@ async def async_setup(hass, config):
return True
class InputSelect(Entity):
class InputSelect(RestoreEntity):
"""Representation of a select input."""
def __init__(self, object_id, name, initial, options, icon):
@ -129,10 +128,11 @@ class InputSelect(Entity):
async def async_added_to_hass(self):
"""Run when entity about to be added."""
await super().async_added_to_hass()
if self._current_option is not None:
return
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
if not state or state.state not in self._options:
self._current_option = self._options[0]
else:

View file

@ -11,9 +11,8 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, CONF_ICON, CONF_NAME, CONF_MODE)
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -104,7 +103,7 @@ async def async_setup(hass, config):
return True
class InputText(Entity):
class InputText(RestoreEntity):
"""Represent a text box."""
def __init__(self, object_id, name, initial, minimum, maximum, icon,
@ -157,10 +156,11 @@ class InputText(Entity):
async def async_added_to_hass(self):
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
if self._current_value is not None:
return
state = await async_get_last_state(self.hass, self.entity_id)
state = await self.async_get_last_state()
value = state and state.state
# Check against None because value can be 0

View file

@ -18,7 +18,7 @@ from homeassistant.components.light import (
import homeassistant.helpers.config_validation as cv
from homeassistant.util.color import (
color_temperature_mired_to_kelvin, color_hs_to_RGB)
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
REQUIREMENTS = ['limitlessled==1.1.3']
@ -157,7 +157,7 @@ def state(new_state):
return decorator
class LimitlessLEDGroup(Light):
class LimitlessLEDGroup(Light, RestoreEntity):
"""Representation of a LimitessLED group."""
def __init__(self, group, config):
@ -189,7 +189,8 @@ class LimitlessLEDGroup(Light):
async def async_added_to_hass(self):
"""Handle entity about to be added to hass event."""
last_state = await async_get_last_state(self.hass, self.entity_id)
await super().async_added_to_hass()
last_state = await self.async_get_last_state()
if last_state:
self._is_on = (last_state.state == STATE_ON)
self._brightness = last_state.attributes.get('brightness')

View file

@ -22,7 +22,7 @@ from homeassistant.components.mqtt import (
CONF_AVAILABILITY_TOPIC, CONF_COMMAND_TOPIC, CONF_PAYLOAD_AVAILABLE,
CONF_PAYLOAD_NOT_AVAILABLE, CONF_QOS, CONF_RETAIN, CONF_STATE_TOPIC,
MqttAvailability, MqttDiscoveryUpdate)
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
import homeassistant.helpers.config_validation as cv
import homeassistant.util.color as color_util
@ -166,7 +166,7 @@ async def async_setup_entity_basic(hass, config, async_add_entities,
)])
class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light):
class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light, RestoreEntity):
"""Representation of a MQTT light."""
def __init__(self, name, unique_id, effect_list, topic, templates,
@ -237,8 +237,7 @@ class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light):
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
templates = {}
for key, tpl in list(self._templates.items()):
@ -248,7 +247,7 @@ class MqttLight(MqttAvailability, MqttDiscoveryUpdate, Light):
tpl.hass = self.hass
templates[key] = tpl.async_render_with_possible_json_value
last_state = await async_get_last_state(self.hass, self.entity_id)
last_state = await self.async_get_last_state()
@callback
def state_received(topic, payload, qos):

View file

@ -25,7 +25,7 @@ from homeassistant.const import (
CONF_RGB, CONF_WHITE_VALUE, CONF_XY, STATE_ON)
from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
import homeassistant.util.color as color_util
@ -121,7 +121,8 @@ async def async_setup_entity_json(hass: HomeAssistantType, config: ConfigType,
)])
class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light):
class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light,
RestoreEntity):
"""Representation of a MQTT JSON light."""
def __init__(self, name, unique_id, effect_list, topic, qos, retain,
@ -183,10 +184,9 @@ class MqttLightJson(MqttAvailability, MqttDiscoveryUpdate, Light):
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
last_state = await async_get_last_state(self.hass, self.entity_id)
last_state = await self.async_get_last_state()
@callback
def state_received(topic, payload, qos):

View file

@ -21,7 +21,7 @@ from homeassistant.components.mqtt import (
MqttAvailability)
import homeassistant.helpers.config_validation as cv
import homeassistant.util.color as color_util
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -102,7 +102,7 @@ async def async_setup_entity_template(hass, config, async_add_entities,
)])
class MqttTemplate(MqttAvailability, Light):
class MqttTemplate(MqttAvailability, Light, RestoreEntity):
"""Representation of a MQTT Template light."""
def __init__(self, hass, name, effect_list, topics, templates, optimistic,
@ -153,7 +153,7 @@ class MqttTemplate(MqttAvailability, Light):
"""Subscribe to MQTT events."""
await super().async_added_to_hass()
last_state = await async_get_last_state(self.hass, self.entity_id)
last_state = await self.async_get_last_state()
@callback
def state_received(topic, payload, qos):

View file

@ -111,8 +111,7 @@ class MqttLock(MqttAvailability, MqttDiscoveryUpdate, LockDevice):
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
@callback
def message_received(topic, payload, qos):

View file

@ -840,6 +840,7 @@ class MqttAvailability(Entity):
This method must be run in the event loop and returns a coroutine.
"""
await super().async_added_to_hass()
await self._availability_subscribe_topics()
async def availability_discovery_update(self, config: dict):
@ -900,6 +901,8 @@ class MqttDiscoveryUpdate(Entity):
async def async_added_to_hass(self) -> None:
"""Subscribe to discovery updates."""
await super().async_added_to_hass()
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.components.mqtt.discovery import (
ALREADY_DISCOVERED, MQTT_DISCOVERY_UPDATED)

View file

@ -28,7 +28,6 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import generate_filter
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util
from homeassistant.loader import bind_hass
from . import migration, purge
from .const import DATA_INSTANCE
@ -83,12 +82,6 @@ CONFIG_SCHEMA = vol.Schema({
}, extra=vol.ALLOW_EXTRA)
@bind_hass
async def wait_connection_ready(hass):
"""Wait till the connection is ready."""
return await hass.data[DATA_INSTANCE].async_db_ready
def run_information(hass, point_in_time: Optional[datetime] = None):
"""Return information about current run.

View file

@ -10,9 +10,8 @@ import voluptuous as vol
from homeassistant.components.sensor import DOMAIN, PLATFORM_SCHEMA
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import track_time_change
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
import homeassistant.util.dt as dt_util
REQUIREMENTS = ['fastdotcom==0.0.3']
@ -51,7 +50,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
hass.services.register(DOMAIN, 'update_fastdotcom', update)
class SpeedtestSensor(Entity):
class SpeedtestSensor(RestoreEntity):
"""Implementation of a FAst.com sensor."""
def __init__(self, speedtest_data):
@ -86,7 +85,8 @@ class SpeedtestSensor(Entity):
async def async_added_to_hass(self):
"""Handle entity which will be added."""
state = await async_get_last_state(self.hass, self.entity_id)
await super().async_added_to_hass()
state = await self.async_get_last_state()
if not state:
return
self._state = state.state

View file

@ -119,8 +119,7 @@ class MqttSensor(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
await self._subscribe_topics()
async def discovery_update(self, discovery_payload):

View file

@ -11,9 +11,8 @@ import voluptuous as vol
from homeassistant.components.sensor import DOMAIN, PLATFORM_SCHEMA
from homeassistant.const import ATTR_ATTRIBUTION, CONF_MONITORED_CONDITIONS
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import track_time_change
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
import homeassistant.util.dt as dt_util
REQUIREMENTS = ['speedtest-cli==2.0.2']
@ -76,7 +75,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
hass.services.register(DOMAIN, 'update_speedtest', update)
class SpeedtestSensor(Entity):
class SpeedtestSensor(RestoreEntity):
"""Implementation of a speedtest.net sensor."""
def __init__(self, speedtest_data, sensor_type):
@ -137,7 +136,8 @@ class SpeedtestSensor(Entity):
async def async_added_to_hass(self):
"""Handle all entity which are about to be added."""
state = await async_get_last_state(self.hass, self.entity_id)
await super().async_added_to_hass()
state = await self.async_get_last_state()
if not state:
return
self._state = state.state

View file

@ -24,7 +24,7 @@ from homeassistant.components import mqtt, switch
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.typing import HomeAssistantType, ConfigType
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -102,8 +102,9 @@ async def _async_setup_entity(hass, config, async_add_entities,
async_add_entities([newswitch])
# pylint: disable=too-many-ancestors
class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
SwitchDevice):
SwitchDevice, RestoreEntity):
"""Representation of a switch that can be toggled using MQTT."""
def __init__(self, name, icon,
@ -136,8 +137,7 @@ class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
async def async_added_to_hass(self):
"""Subscribe to MQTT events."""
await MqttAvailability.async_added_to_hass(self)
await MqttDiscoveryUpdate.async_added_to_hass(self)
await super().async_added_to_hass()
@callback
def state_message_received(topic, payload, qos):
@ -161,8 +161,7 @@ class MqttSwitch(MqttAvailability, MqttDiscoveryUpdate, MqttEntityDeviceInfo,
self._qos)
if self._optimistic:
last_state = await async_get_last_state(self.hass,
self.entity_id)
last_state = await self.async_get_last_state()
if last_state:
self._state = last_state.state == STATE_ON

View file

@ -13,7 +13,7 @@ from homeassistant.components import pilight
from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA)
from homeassistant.const import (CONF_NAME, CONF_ID, CONF_SWITCHES, CONF_STATE,
CONF_PROTOCOL, STATE_ON)
from homeassistant.helpers.restore_state import async_get_last_state
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -97,7 +97,7 @@ class _ReceiveHandle:
switch.set_state(turn_on=turn_on, send_code=self.echo)
class PilightSwitch(SwitchDevice):
class PilightSwitch(SwitchDevice, RestoreEntity):
"""Representation of a Pilight switch."""
def __init__(self, hass, name, code_on, code_off, code_on_receive,
@ -123,7 +123,8 @@ class PilightSwitch(SwitchDevice):
async def async_added_to_hass(self):
"""Call when entity about to be added to hass."""
state = await async_get_last_state(self._hass, self.entity_id)
await super().async_added_to_hass()
state = await self.async_get_last_state()
if state:
self._state = state.state == STATE_ON

View file

@ -12,9 +12,9 @@ import voluptuous as vol
import homeassistant.util.dt as dt_util
import homeassistant.helpers.config_validation as cv
from homeassistant.const import (ATTR_ENTITY_ID, CONF_ICON, CONF_NAME)
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.restore_state import RestoreEntity
_LOGGER = logging.getLogger(__name__)
@ -97,7 +97,7 @@ async def async_setup(hass, config):
return True
class Timer(Entity):
class Timer(RestoreEntity):
"""Representation of a timer."""
def __init__(self, hass, object_id, name, icon, duration):
@ -146,8 +146,7 @@ class Timer(Entity):
if self._state is not None:
return
restore_state = self._hass.helpers.restore_state
state = await restore_state.async_get_last_state(self.entity_id)
state = await self.async_get_last_state()
self._state = state and state.state == state
async def async_start(self, duration):

View file

@ -363,10 +363,7 @@ class Entity:
async def async_remove(self):
"""Remove entity from Home Assistant."""
will_remove = getattr(self, 'async_will_remove_from_hass', None)
if will_remove:
await will_remove() # pylint: disable=not-callable
await self.async_will_remove_from_hass()
if self._on_remove is not None:
while self._on_remove:
@ -390,6 +387,12 @@ class Entity:
self.hass.async_create_task(readd())
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
def __eq__(self, other):
"""Return the comparison."""
if not isinstance(other, self.__class__):

View file

@ -346,8 +346,7 @@ class EntityPlatform:
self.entities[entity_id] = entity
entity.async_on_remove(lambda: self.entities.pop(entity_id))
if hasattr(entity, 'async_added_to_hass'):
await entity.async_added_to_hass()
await entity.async_added_to_hass()
await entity.async_update_ha_state()

View file

@ -2,97 +2,174 @@
import asyncio
import logging
from datetime import timedelta
from typing import Any, Dict, List, Set, Optional # noqa pylint_disable=unused-import
import async_timeout
from homeassistant.core import HomeAssistant, CoreState, callback
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.loader import bind_hass
from homeassistant.components.history import get_states, last_recorder_run
from homeassistant.components.recorder import (
wait_connection_ready, DOMAIN as _RECORDER)
from homeassistant.core import HomeAssistant, callback, State, CoreState
from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP)
import homeassistant.util.dt as dt_util
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.storage import Store # noqa pylint_disable=unused-import
DATA_RESTORE_STATE_TASK = 'restore_state_task'
RECORDER_TIMEOUT = 10
DATA_RESTORE_CACHE = 'restore_state_cache'
_LOCK = 'restore_lock'
_LOGGER = logging.getLogger(__name__)
STORAGE_KEY = 'core.restore_state'
STORAGE_VERSION = 1
# How long between periodically saving the current states to disk
STATE_DUMP_INTERVAL = timedelta(minutes=15)
# How long should a saved state be preserved if the entity no longer exists
STATE_EXPIRATION = timedelta(days=7)
class RestoreStateData():
"""Helper class for managing the helper saved data."""
@classmethod
async def async_get_instance(
cls, hass: HomeAssistant) -> 'RestoreStateData':
"""Get the singleton instance of this data helper."""
task = hass.data.get(DATA_RESTORE_STATE_TASK)
if task is None:
async def load_instance(hass: HomeAssistant) -> 'RestoreStateData':
"""Set up the restore state helper."""
data = cls(hass)
try:
states = await data.store.async_load()
except HomeAssistantError as exc:
_LOGGER.error("Error loading last states", exc_info=exc)
states = None
if states is None:
_LOGGER.debug('Not creating cache - no saved states found')
data.last_states = {}
else:
data.last_states = {
state['entity_id']: State.from_dict(state)
for state in states}
_LOGGER.debug(
'Created cache with %s', list(data.last_states))
if hass.state == CoreState.running:
data.async_setup_dump()
else:
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_START, data.async_setup_dump)
return data
task = hass.data[DATA_RESTORE_STATE_TASK] = hass.async_create_task(
load_instance(hass))
return await task
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the restore state data class."""
self.hass = hass # type: HomeAssistant
self.store = Store(hass, STORAGE_VERSION, STORAGE_KEY,
encoder=JSONEncoder) # type: Store
self.last_states = {} # type: Dict[str, State]
self.entity_ids = set() # type: Set[str]
def async_get_states(self) -> List[State]:
"""Get the set of states which should be stored.
This includes the states of all registered entities, as well as the
stored states from the previous run, which have not been created as
entities on this run, and have not expired.
"""
all_states = self.hass.states.async_all()
current_entity_ids = set(state.entity_id for state in all_states)
# Start with the currently registered states
states = [state for state in all_states
if state.entity_id in self.entity_ids]
expiration_time = dt_util.utcnow() - STATE_EXPIRATION
for entity_id, state in self.last_states.items():
# Don't save old states that have entities in the current run
if entity_id in current_entity_ids:
continue
# Don't save old states that have expired
if state.last_updated < expiration_time:
continue
states.append(state)
return states
async def async_dump_states(self) -> None:
"""Save the current state machine to storage."""
_LOGGER.debug("Dumping states")
try:
await self.store.async_save([
state.as_dict() for state in self.async_get_states()])
except HomeAssistantError as exc:
_LOGGER.error("Error saving current states", exc_info=exc)
def _load_restore_cache(hass: HomeAssistant):
"""Load the restore cache to be used by other components."""
@callback
def remove_cache(event):
"""Remove the states cache."""
hass.data.pop(DATA_RESTORE_CACHE, None)
def async_setup_dump(self, *args: Any) -> None:
"""Set up the restore state listeners."""
# Dump the initial states now. This helps minimize the risk of having
# old states loaded by overwritting the last states once home assistant
# has started and the old states have been read.
self.hass.async_create_task(self.async_dump_states())
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache)
# Dump states periodically
async_track_time_interval(
self.hass, lambda *_: self.hass.async_create_task(
self.async_dump_states()), STATE_DUMP_INTERVAL)
last_run = last_recorder_run(hass)
# Dump states when stopping hass
self.hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, lambda *_: self.hass.async_create_task(
self.async_dump_states()))
if last_run is None or last_run.end is None:
_LOGGER.debug('Not creating cache - no suitable last run found: %s',
last_run)
hass.data[DATA_RESTORE_CACHE] = {}
return
@callback
def async_register_entity(self, entity_id: str) -> None:
"""Store this entity's state when hass is shutdown."""
self.entity_ids.add(entity_id)
last_end_time = last_run.end - timedelta(seconds=1)
# Unfortunately the recorder_run model do not return offset-aware time
last_end_time = last_end_time.replace(tzinfo=dt_util.UTC)
_LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time)
states = get_states(hass, last_end_time, run=last_run)
# Cache the states
hass.data[DATA_RESTORE_CACHE] = {
state.entity_id: state for state in states}
_LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE]))
@callback
def async_unregister_entity(self, entity_id: str) -> None:
"""Unregister this entity from saving state."""
self.entity_ids.remove(entity_id)
@bind_hass
async def async_get_last_state(hass, entity_id: str):
"""Restore state."""
if DATA_RESTORE_CACHE in hass.data:
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
class RestoreEntity(Entity):
"""Mixin class for restoring previous entity state."""
if _RECORDER not in hass.config.components:
return None
async def async_added_to_hass(self) -> None:
"""Register this entity as a restorable entity."""
_, data = await asyncio.gather(
super().async_added_to_hass(),
RestoreStateData.async_get_instance(self.hass),
)
data.async_register_entity(self.entity_id)
if hass.state not in (CoreState.starting, CoreState.not_running):
_LOGGER.debug("Cache for %s can only be loaded during startup, not %s",
entity_id, hass.state)
return None
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
_, data = await asyncio.gather(
super().async_will_remove_from_hass(),
RestoreStateData.async_get_instance(self.hass),
)
data.async_unregister_entity(self.entity_id)
try:
with async_timeout.timeout(RECORDER_TIMEOUT, loop=hass.loop):
connected = await wait_connection_ready(hass)
except asyncio.TimeoutError:
return None
if not connected:
return None
if _LOCK not in hass.data:
hass.data[_LOCK] = asyncio.Lock(loop=hass.loop)
async with hass.data[_LOCK]:
if DATA_RESTORE_CACHE not in hass.data:
await hass.async_add_job(
_load_restore_cache, hass)
return hass.data.get(DATA_RESTORE_CACHE, {}).get(entity_id)
async def async_restore_state(entity, extract_info):
"""Call entity.async_restore_state with cached info."""
if entity.hass.state not in (CoreState.starting, CoreState.not_running):
_LOGGER.debug("Not restoring state for %s: Hass is not starting: %s",
entity.entity_id, entity.hass.state)
return
state = await async_get_last_state(entity.hass, entity.entity_id)
if not state:
return
await entity.async_restore_state(**extract_info(state))
async def async_get_last_state(self) -> Optional[State]:
"""Get the entity state from the previous run."""
if self.hass is None or self.entity_id is None:
# Return None if this entity isn't added to hass yet
_LOGGER.warning("Cannot get last state. Entity not added to hass")
return None
data = await RestoreStateData.async_get_instance(self.hass)
return data.last_states.get(self.entity_id)

View file

@ -1,13 +1,14 @@
"""Helper to help store data."""
import asyncio
from json import JSONEncoder
import logging
import os
from typing import Dict, Optional, Callable, Any
from typing import Dict, List, Optional, Callable, Union
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback
from homeassistant.loader import bind_hass
from homeassistant.util import json
from homeassistant.util import json as json_util
from homeassistant.helpers.event import async_call_later
STORAGE_DIR = '.storage'
@ -16,7 +17,7 @@ _LOGGER = logging.getLogger(__name__)
@bind_hass
async def async_migrator(hass, old_path, store, *,
old_conf_load_func=json.load_json,
old_conf_load_func=json_util.load_json,
old_conf_migrate_func=None):
"""Migrate old data to a store and then load data.
@ -46,7 +47,8 @@ async def async_migrator(hass, old_path, store, *,
class Store:
"""Class to help storing data."""
def __init__(self, hass, version: int, key: str, private: bool = False):
def __init__(self, hass, version: int, key: str, private: bool = False, *,
encoder: JSONEncoder = None):
"""Initialize storage class."""
self.version = version
self.key = key
@ -57,13 +59,14 @@ class Store:
self._unsub_stop_listener = None
self._write_lock = asyncio.Lock(loop=hass.loop)
self._load_task = None
self._encoder = encoder
@property
def path(self):
"""Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key)
async def async_load(self) -> Optional[Dict[str, Any]]:
async def async_load(self) -> Optional[Union[Dict, List]]:
"""Load data.
If the expected version does not match the given version, the migrate
@ -88,7 +91,7 @@ class Store:
data['data'] = data.pop('data_func')()
else:
data = await self.hass.async_add_executor_job(
json.load_json, self.path)
json_util.load_json, self.path)
if data == {}:
return None
@ -103,7 +106,7 @@ class Store:
self._load_task = None
return stored
async def async_save(self, data):
async def async_save(self, data: Union[Dict, List]) -> None:
"""Save data."""
self._data = {
'version': self.version,
@ -178,7 +181,7 @@ class Store:
try:
await self.hass.async_add_executor_job(
self._write_data, self.path, data)
except (json.SerializationError, json.WriteError) as err:
except (json_util.SerializationError, json_util.WriteError) as err:
_LOGGER.error('Error writing config for %s: %s', self.key, err)
def _write_data(self, path: str, data: Dict):
@ -187,7 +190,7 @@ class Store:
os.makedirs(os.path.dirname(path))
_LOGGER.debug('Writing data for %s', self.key)
json.save_json(path, data, self._private)
json_util.save_json(path, data, self._private, encoder=self._encoder)
async def _async_migrate_func(self, old_version, old_data):
"""Migrate to the new version."""

View file

@ -1,6 +1,6 @@
"""JSON utility functions."""
import logging
from typing import Union, List, Dict
from typing import Union, List, Dict, Optional
import json
import os
@ -41,7 +41,8 @@ def load_json(filename: str, default: Union[List, Dict, None] = None) \
def save_json(filename: str, data: Union[List, Dict],
private: bool = False) -> None:
private: bool = False, *,
encoder: Optional[json.JSONEncoder] = None) -> None:
"""Save JSON data to a file.
Returns True on success.
@ -49,7 +50,7 @@ def save_json(filename: str, data: Union[List, Dict],
tmp_filename = ""
tmp_path = os.path.split(filename)[0]
try:
json_data = json.dumps(data, sort_keys=True, indent=4)
json_data = json.dumps(data, sort_keys=True, indent=4, cls=encoder)
# Modern versions of Python tempfile create this file with mode 0o600
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8',
dir=tmp_path, delete=False) as fdesc:

View file

@ -114,8 +114,7 @@ def get_test_home_assistant():
# pylint: disable=protected-access
@asyncio.coroutine
def async_test_home_assistant(loop):
async def async_test_home_assistant(loop):
"""Return a Home Assistant object pointing at test config dir."""
hass = ha.HomeAssistant(loop)
hass.config.async_load = Mock()
@ -168,13 +167,12 @@ def async_test_home_assistant(loop):
# Mock async_start
orig_start = hass.async_start
@asyncio.coroutine
def mock_async_start():
async def mock_async_start():
"""Start the mocking."""
# We only mock time during tests and we want to track tasks
with patch('homeassistant.core._async_create_timer'), \
patch.object(hass, 'async_stop_track_tasks'):
yield from orig_start()
await orig_start()
hass.async_start = mock_async_start
@ -715,14 +713,20 @@ def init_recorder_component(hass, add_config=None):
def mock_restore_cache(hass, states):
"""Mock the DATA_RESTORE_CACHE."""
key = restore_state.DATA_RESTORE_CACHE
hass.data[key] = {
key = restore_state.DATA_RESTORE_STATE_TASK
data = restore_state.RestoreStateData(hass)
data.last_states = {
state.entity_id: state for state in states}
_LOGGER.debug('Restore cache: %s', hass.data[key])
assert len(hass.data[key]) == len(states), \
_LOGGER.debug('Restore cache: %s', data.last_states)
assert len(data.last_states) == len(states), \
"Duplicate entity_id? {}".format(states)
hass.state = ha.CoreState.starting
mock_component(hass, recorder.DOMAIN)
async def get_restore_state_data() -> restore_state.RestoreStateData:
return data
# Patch the singleton task in hass.data to return our new RestoreStateData
hass.data[key] = hass.async_create_task(get_restore_state_data())
class MockDependency:
@ -846,9 +850,10 @@ def mock_storage(data=None):
def mock_write_data(store, path, data_to_write):
"""Mock version of write data."""
# To ensure that the data can be serialized
_LOGGER.info('Writing data to %s: %s', store.key, data_to_write)
data[store.key] = json.loads(json.dumps(data_to_write))
# To ensure that the data can be serialized
data[store.key] = json.loads(json.dumps(
data_to_write, cls=store._encoder))
with patch('homeassistant.helpers.storage.Store._async_load',
side_effect=mock_async_load, autospec=True), \

View file

@ -6,10 +6,8 @@ from unittest.mock import patch
import requests
from aiohttp.hdrs import CONTENT_TYPE
from homeassistant import setup, const, core
import homeassistant.components as core_components
from homeassistant import setup, const
from homeassistant.components import emulated_hue, http
from homeassistant.util.async_ import run_coroutine_threadsafe
from tests.common import get_test_instance_port, get_test_home_assistant
@ -20,29 +18,6 @@ BRIDGE_URL_BASE = 'http://127.0.0.1:{}'.format(BRIDGE_SERVER_PORT) + '{}'
JSON_HEADERS = {CONTENT_TYPE: const.CONTENT_TYPE_JSON}
def setup_hass_instance(emulated_hue_config):
"""Set up the Home Assistant instance to test."""
hass = get_test_home_assistant()
# We need to do this to get access to homeassistant/turn_(on,off)
run_coroutine_threadsafe(
core_components.async_setup(hass, {core.DOMAIN: {}}), hass.loop
).result()
setup.setup_component(
hass, http.DOMAIN,
{http.DOMAIN: {http.CONF_SERVER_PORT: HTTP_SERVER_PORT}})
setup.setup_component(hass, emulated_hue.DOMAIN, emulated_hue_config)
return hass
def start_hass_instance(hass):
"""Start the Home Assistant instance to test."""
hass.start()
class TestEmulatedHue(unittest.TestCase):
"""Test the emulated Hue component."""
@ -53,11 +28,6 @@ class TestEmulatedHue(unittest.TestCase):
"""Set up the class."""
cls.hass = hass = get_test_home_assistant()
# We need to do this to get access to homeassistant/turn_(on,off)
run_coroutine_threadsafe(
core_components.async_setup(hass, {core.DOMAIN: {}}), hass.loop
).result()
setup.setup_component(
hass, http.DOMAIN,
{http.DOMAIN: {http.CONF_SERVER_PORT: HTTP_SERVER_PORT}})

View file

@ -585,7 +585,7 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mqtt_mock):
'effect': 'random',
'color_temp': 100,
'white_value': 50})
with patch('homeassistant.components.light.mqtt.schema_basic'
with patch('homeassistant.helpers.restore_state.RestoreEntity'
'.async_get_last_state',
return_value=mock_coro(fake_state)):
with assert_setup_component(1, light.DOMAIN):

View file

@ -279,7 +279,7 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mqtt_mock):
'color_temp': 100,
'white_value': 50})
with patch('homeassistant.components.light.mqtt.schema_json'
with patch('homeassistant.helpers.restore_state.RestoreEntity'
'.async_get_last_state',
return_value=mock_coro(fake_state)):
assert await async_setup_component(hass, light.DOMAIN, {

View file

@ -245,7 +245,7 @@ async def test_optimistic(hass, mqtt_mock):
'color_temp': 100,
'white_value': 50})
with patch('homeassistant.components.light.mqtt.schema_template'
with patch('homeassistant.helpers.restore_state.RestoreEntity'
'.async_get_last_state',
return_value=mock_coro(fake_state)):
with assert_setup_component(1, light.DOMAIN):

View file

@ -1,6 +1,5 @@
"""The tests for the Recorder component."""
# pylint: disable=protected-access
import asyncio
from unittest.mock import patch, call
import pytest
@ -9,7 +8,7 @@ from sqlalchemy.pool import StaticPool
from homeassistant.bootstrap import async_setup_component
from homeassistant.components.recorder import (
wait_connection_ready, migration, const, models)
migration, const, models)
from tests.components.recorder import models_original
@ -23,26 +22,24 @@ def create_engine_test(*args, **kwargs):
return engine
@asyncio.coroutine
def test_schema_update_calls(hass):
async def test_schema_update_calls(hass):
"""Test that schema migrations occur in correct order."""
with patch('sqlalchemy.create_engine', new=create_engine_test), \
patch('homeassistant.components.recorder.migration._apply_update') as \
update:
yield from async_setup_component(hass, 'recorder', {
await async_setup_component(hass, 'recorder', {
'recorder': {
'db_url': 'sqlite://'
}
})
yield from wait_connection_ready(hass)
await hass.async_block_till_done()
update.assert_has_calls([
call(hass.data[const.DATA_INSTANCE].engine, version+1, 0) for version
in range(0, models.SCHEMA_VERSION)])
@asyncio.coroutine
def test_schema_migrate(hass):
async def test_schema_migrate(hass):
"""Test the full schema migration logic.
We're just testing that the logic can execute successfully here without
@ -52,12 +49,12 @@ def test_schema_migrate(hass):
with patch('sqlalchemy.create_engine', new=create_engine_test), \
patch('homeassistant.components.recorder.Recorder._setup_run') as \
setup_run:
yield from async_setup_component(hass, 'recorder', {
await async_setup_component(hass, 'recorder', {
'recorder': {
'db_url': 'sqlite://'
}
})
yield from wait_connection_ready(hass)
await hass.async_block_till_done()
assert setup_run.called

View file

@ -57,7 +57,8 @@ async def test_sending_mqtt_commands_and_optimistic(hass, mock_publish):
"""Test the sending MQTT commands in optimistic mode."""
fake_state = ha.State('switch.test', 'on')
with patch('homeassistant.components.switch.mqtt.async_get_last_state',
with patch('homeassistant.helpers.restore_state.RestoreEntity'
'.async_get_last_state',
return_value=mock_coro(fake_state)):
assert await async_setup_component(hass, switch.DOMAIN, {
switch.DOMAIN: {

View file

@ -519,7 +519,6 @@ async def test_fetch_period_api(hass, hass_client):
"""Test the fetch period view for history."""
await hass.async_add_job(init_recorder_component, hass)
await async_setup_component(hass, 'history', {})
await hass.components.recorder.wait_connection_ready()
await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
client = await hass_client()
response = await client.get(

View file

@ -575,7 +575,6 @@ async def test_logbook_view(hass, aiohttp_client):
"""Test the logbook view."""
await hass.async_add_job(init_recorder_component, hass)
await async_setup_component(hass, 'logbook', {})
await hass.components.recorder.wait_connection_ready()
await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
client = await aiohttp_client(hass.http.app)
response = await client.get(
@ -587,7 +586,6 @@ async def test_logbook_view_period_entity(hass, aiohttp_client):
"""Test the logbook view with period and entity."""
await hass.async_add_job(init_recorder_component, hass)
await async_setup_component(hass, 'logbook', {})
await hass.components.recorder.wait_connection_ready()
await hass.async_add_job(hass.data[recorder.DATA_INSTANCE].block_till_done)
entity_id_test = 'switch.test'

View file

@ -1,60 +1,52 @@
"""The tests for the Restore component."""
import asyncio
from datetime import timedelta
from unittest.mock import patch, MagicMock
from datetime import datetime
from homeassistant.setup import setup_component
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import CoreState, split_entity_id, State
import homeassistant.util.dt as dt_util
from homeassistant.components import input_boolean, recorder
from homeassistant.core import CoreState, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.restore_state import (
async_get_last_state, DATA_RESTORE_CACHE)
from homeassistant.components.recorder.models import RecorderRuns, States
RestoreStateData, RestoreEntity, DATA_RESTORE_STATE_TASK)
from homeassistant.util import dt as dt_util
from tests.common import (
get_test_home_assistant, mock_coro, init_recorder_component,
mock_component)
from asynctest import patch
from tests.common import mock_coro
@asyncio.coroutine
def test_caching_data(hass):
async def test_caching_data(hass):
"""Test that we cache data."""
mock_component(hass, 'recorder')
hass.state = CoreState.starting
states = [
State('input_boolean.b0', 'on'),
State('input_boolean.b1', 'on'),
State('input_boolean.b2', 'on'),
]
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=MagicMock(end=dt_util.utcnow())), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states), \
patch('homeassistant.helpers.restore_state.wait_connection_ready',
return_value=mock_coro(True)):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
data = await RestoreStateData.async_get_instance(hass)
await data.store.async_save([state.as_dict() for state in states])
assert DATA_RESTORE_CACHE in hass.data
assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states}
# Emulate a fresh load
hass.data[DATA_RESTORE_STATE_TASK] = None
entity = RestoreEntity()
entity.hass = hass
entity.entity_id = 'input_boolean.b1'
# Mock that only b1 is present this run
with patch('homeassistant.helpers.restore_state.Store.async_save'
) as mock_write_data:
state = await entity.async_get_last_state()
assert state is not None
assert state.entity_id == 'input_boolean.b1'
assert state.state == 'on'
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from hass.async_block_till_done()
assert DATA_RESTORE_CACHE not in hass.data
assert mock_write_data.called
@asyncio.coroutine
def test_hass_running(hass):
"""Test that cache cannot be accessed while hass is running."""
mock_component(hass, 'recorder')
async def test_hass_starting(hass):
"""Test that we cache data."""
hass.state = CoreState.starting
states = [
State('input_boolean.b0', 'on'),
@ -62,129 +54,144 @@ def test_hass_running(hass):
State('input_boolean.b2', 'on'),
]
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=MagicMock(end=dt_util.utcnow())), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states), \
patch('homeassistant.helpers.restore_state.wait_connection_ready',
return_value=mock_coro(True)):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
assert state is None
data = await RestoreStateData.async_get_instance(hass)
await data.store.async_save([state.as_dict() for state in states])
# Emulate a fresh load
hass.data[DATA_RESTORE_STATE_TASK] = None
@asyncio.coroutine
def test_not_connected(hass):
"""Test that cache cannot be accessed if db connection times out."""
mock_component(hass, 'recorder')
hass.state = CoreState.starting
entity = RestoreEntity()
entity.hass = hass
entity.entity_id = 'input_boolean.b1'
states = [State('input_boolean.b1', 'on')]
# Mock that only b1 is present this run
states = [
State('input_boolean.b1', 'on'),
]
with patch('homeassistant.helpers.restore_state.Store.async_save'
) as mock_write_data, patch.object(
hass.states, 'async_all', return_value=states):
state = await entity.async_get_last_state()
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=MagicMock(end=dt_util.utcnow())), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states), \
patch('homeassistant.helpers.restore_state.wait_connection_ready',
return_value=mock_coro(False)):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
assert state is None
@asyncio.coroutine
def test_no_last_run_found(hass):
"""Test that cache cannot be accessed if no last run found."""
mock_component(hass, 'recorder')
hass.state = CoreState.starting
states = [State('input_boolean.b1', 'on')]
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=None), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states), \
patch('homeassistant.helpers.restore_state.wait_connection_ready',
return_value=mock_coro(True)):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
assert state is None
@asyncio.coroutine
def test_cache_timeout(hass):
"""Test that cache timeout returns none."""
mock_component(hass, 'recorder')
hass.state = CoreState.starting
states = [State('input_boolean.b1', 'on')]
@asyncio.coroutine
def timeout_coro():
raise asyncio.TimeoutError()
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=MagicMock(end=dt_util.utcnow())), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states), \
patch('homeassistant.helpers.restore_state.wait_connection_ready',
return_value=timeout_coro()):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
assert state is None
def _add_data_in_last_run(hass, entities):
"""Add test data in the last recorder_run."""
# pylint: disable=protected-access
t_now = dt_util.utcnow() - timedelta(minutes=10)
t_min_1 = t_now - timedelta(minutes=20)
t_min_2 = t_now - timedelta(minutes=30)
with recorder.session_scope(hass=hass) as session:
session.add(RecorderRuns(
start=t_min_2,
end=t_now,
created=t_min_2
))
for entity_id, state in entities.items():
session.add(States(
entity_id=entity_id,
domain=split_entity_id(entity_id)[0],
state=state,
attributes='{}',
last_changed=t_min_1,
last_updated=t_min_1,
created=t_min_1))
def test_filling_the_cache():
"""Test filling the cache from the DB."""
test_entity_id1 = 'input_boolean.b1'
test_entity_id2 = 'input_boolean.b2'
hass = get_test_home_assistant()
hass.state = CoreState.starting
init_recorder_component(hass)
_add_data_in_last_run(hass, {
test_entity_id1: 'on',
test_entity_id2: 'off',
})
hass.block_till_done()
setup_component(hass, input_boolean.DOMAIN, {
input_boolean.DOMAIN: {
'b1': None,
'b2': None,
}})
hass.start()
state = hass.states.get('input_boolean.b1')
assert state
assert state is not None
assert state.entity_id == 'input_boolean.b1'
assert state.state == 'on'
state = hass.states.get('input_boolean.b2')
assert state
assert state.state == 'off'
# Assert that no data was written yet, since hass is still starting.
assert not mock_write_data.called
hass.stop()
# Finish hass startup
with patch('homeassistant.helpers.restore_state.Store.async_save'
) as mock_write_data:
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
await hass.async_block_till_done()
# Assert that this session states were written
assert mock_write_data.called
async def test_dump_data(hass):
"""Test that we cache data."""
states = [
State('input_boolean.b0', 'on'),
State('input_boolean.b1', 'on'),
State('input_boolean.b2', 'on'),
]
entity = Entity()
entity.hass = hass
entity.entity_id = 'input_boolean.b0'
await entity.async_added_to_hass()
entity = RestoreEntity()
entity.hass = hass
entity.entity_id = 'input_boolean.b1'
await entity.async_added_to_hass()
data = await RestoreStateData.async_get_instance(hass)
data.last_states = {
'input_boolean.b0': State('input_boolean.b0', 'off'),
'input_boolean.b1': State('input_boolean.b1', 'off'),
'input_boolean.b2': State('input_boolean.b2', 'off'),
'input_boolean.b3': State('input_boolean.b3', 'off'),
'input_boolean.b4': State(
'input_boolean.b4', 'off', last_updated=datetime(
1985, 10, 26, 1, 22, tzinfo=dt_util.UTC)),
}
with patch('homeassistant.helpers.restore_state.Store.async_save'
) as mock_write_data, patch.object(
hass.states, 'async_all', return_value=states):
await data.async_dump_states()
assert mock_write_data.called
args = mock_write_data.mock_calls[0][1]
written_states = args[0]
# b0 should not be written, since it didn't extend RestoreEntity
# b1 should be written, since it is present in the current run
# b2 should not be written, since it is not registered with the helper
# b3 should be written, since it is still not expired
# b4 should not be written, since it is now expired
assert len(written_states) == 2
assert written_states[0]['entity_id'] == 'input_boolean.b1'
assert written_states[0]['state'] == 'on'
assert written_states[1]['entity_id'] == 'input_boolean.b3'
assert written_states[1]['state'] == 'off'
# Test that removed entities are not persisted
await entity.async_will_remove_from_hass()
with patch('homeassistant.helpers.restore_state.Store.async_save'
) as mock_write_data, patch.object(
hass.states, 'async_all', return_value=states):
await data.async_dump_states()
assert mock_write_data.called
args = mock_write_data.mock_calls[0][1]
written_states = args[0]
assert len(written_states) == 1
assert written_states[0]['entity_id'] == 'input_boolean.b3'
assert written_states[0]['state'] == 'off'
async def test_dump_error(hass):
"""Test that we cache data."""
states = [
State('input_boolean.b0', 'on'),
State('input_boolean.b1', 'on'),
State('input_boolean.b2', 'on'),
]
entity = Entity()
entity.hass = hass
entity.entity_id = 'input_boolean.b0'
await entity.async_added_to_hass()
entity = RestoreEntity()
entity.hass = hass
entity.entity_id = 'input_boolean.b1'
await entity.async_added_to_hass()
data = await RestoreStateData.async_get_instance(hass)
with patch('homeassistant.helpers.restore_state.Store.async_save',
return_value=mock_coro(exception=HomeAssistantError)
) as mock_write_data, patch.object(
hass.states, 'async_all', return_value=states):
await data.async_dump_states()
assert mock_write_data.called
async def test_load_error(hass):
"""Test that we cache data."""
entity = RestoreEntity()
entity.hass = hass
entity.entity_id = 'input_boolean.b1'
with patch('homeassistant.helpers.storage.Store.async_load',
return_value=mock_coro(exception=HomeAssistantError)):
state = await entity.async_get_last_state()
assert state is None

View file

@ -1,7 +1,8 @@
"""Tests for the storage helper."""
import asyncio
from datetime import timedelta
from unittest.mock import patch
import json
from unittest.mock import patch, Mock
import pytest
@ -31,6 +32,21 @@ async def test_loading(hass, store):
assert data == MOCK_DATA
async def test_custom_encoder(hass):
"""Test we can save and load data."""
class JSONEncoder(json.JSONEncoder):
"""Mock JSON encoder."""
def default(self, o):
"""Mock JSON encode method."""
return "9"
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, encoder=JSONEncoder)
await store.async_save(Mock())
data = await store.async_load()
assert data == "9"
async def test_loading_non_existing(hass, store):
"""Test we can save and load data."""
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):

View file

@ -1,14 +1,17 @@
"""Test Home Assistant json utility functions."""
from json import JSONEncoder
import os
import unittest
import sys
from tempfile import mkdtemp
from homeassistant.util.json import (SerializationError,
load_json, save_json)
from homeassistant.util.json import (
SerializationError, load_json, save_json)
from homeassistant.exceptions import HomeAssistantError
import pytest
from unittest.mock import Mock
# Test data that can be saved as JSON
TEST_JSON_A = {"a": 1, "B": "two"}
TEST_JSON_B = {"a": "one", "B": 2}
@ -74,3 +77,17 @@ class TestJSON(unittest.TestCase):
fh.write(TEST_BAD_SERIALIED)
with pytest.raises(HomeAssistantError):
load_json(fname)
def test_custom_encoder(self):
"""Test serializing with a custom encoder."""
class MockJSONEncoder(JSONEncoder):
"""Mock JSON encoder."""
def default(self, o):
"""Mock JSON encode method."""
return "9"
fname = self._path_for("test6")
save_json(fname, Mock(), encoder=MockJSONEncoder)
data = load_json(fname)
self.assertEqual(data, "9")