From 1522e673516b8b807540d1cd87c564ef985c3639 Mon Sep 17 00:00:00 2001 From: Johann Kellerman Date: Sun, 5 Mar 2017 01:19:01 +0200 Subject: [PATCH] Restore for automation entities (#6254) * Restore for automation entities * coroutine * no clue what i'm doing now * Still passes nicely in py 3.4 --- .../components/automation/__init__.py | 13 +++- tests/common.py | 17 ++--- tests/components/automation/test_init.py | 76 +++++++++++++++---- 3 files changed, 80 insertions(+), 26 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 0e734d7214d0..a5fc52c448ea 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -21,6 +21,7 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import extract_domain_configs, script, condition from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.restore_state import async_get_last_state from homeassistant.loader import get_platform from homeassistant.util.dt import utcnow import homeassistant.helpers.config_validation as cv @@ -265,9 +266,15 @@ class AutomationEntity(ToggleEntity): @asyncio.coroutine def async_added_to_hass(self) -> None: - """Startup if initial_state.""" - if self._initial_state: - yield from self.async_enable() + """Startup with initial state or previous state.""" + state = yield from async_get_last_state(self.hass, self.entity_id) + if state is None: + if self._initial_state: + yield from self.async_enable() + else: + self._last_triggered = state.attributes.get('last_triggered') + if state.state == STATE_ON: + yield from self.async_enable() @asyncio.coroutine def async_turn_on(self, **kwargs) -> None: diff --git a/tests/common.py b/tests/common.py index 34cd97656956..840dfd50caa3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -131,6 +131,7 @@ def async_test_home_assistant(loop): @ha.callback def clear_instance(event): + """Clear global instance.""" global INST_COUNT INST_COUNT -= 1 @@ -140,20 +141,18 @@ def async_test_home_assistant(loop): def mock_service(hass, domain, service): - """Setup a fake service. - - Return a list that logs all calls to fake service. - """ + """Setup a fake service & return a list that logs calls to this service.""" calls = [] - # pylint: disable=redefined-outer-name - @ha.callback - def mock_service(call): + @asyncio.coroutine + def mock_service_log(call): # pylint: disable=unnecessary-lambda """"Mocked service call.""" calls.append(call) - # pylint: disable=unnecessary-lambda - hass.services.register(domain, service, mock_service) + if hass.loop.__dict__.get("_thread_ident", 0) == threading.get_ident(): + hass.services.async_register(domain, service, mock_service_log) + else: + hass.services.register(domain, service, mock_service_log) return calls diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index fa7658f34075..9dc080890116 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1,17 +1,19 @@ """The tests for the automation component.""" -import unittest +import asyncio from datetime import timedelta +import unittest from unittest.mock import patch -from homeassistant.core import callback -from homeassistant.bootstrap import setup_component +from homeassistant.core import State +from homeassistant.bootstrap import setup_component, async_setup_component import homeassistant.components.automation as automation -from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.const import ATTR_ENTITY_ID, STATE_ON, STATE_OFF from homeassistant.exceptions import HomeAssistantError import homeassistant.util.dt as dt_util -from tests.common import get_test_home_assistant, assert_setup_component, \ - fire_time_changed, mock_component +from tests.common import ( + assert_setup_component, get_test_home_assistant, fire_time_changed, + mock_component, mock_service, mock_restore_cache) # pylint: disable=invalid-name @@ -22,14 +24,7 @@ class TestAutomation(unittest.TestCase): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() mock_component(self.hass, 'group') - self.calls = [] - - @callback - def record_call(service): - """Helper to record calls.""" - self.calls.append(service) - - self.hass.services.register('test', 'automation', record_call) + self.calls = mock_service(self.hass, 'test', 'automation') def tearDown(self): """Stop everything that was started.""" @@ -572,3 +567,56 @@ class TestAutomation(unittest.TestCase): self.hass.bus.fire('test_event') self.hass.block_till_done() assert len(self.calls) == 2 + + +@asyncio.coroutine +def test_automation_restore_state(hass): + """Ensure states are restored on startup.""" + time = dt_util.utcnow() + + mock_restore_cache(hass, ( + State('automation.hello', STATE_ON), + State('automation.bye', STATE_OFF, {'last_triggered': time}), + )) + + config = {automation.DOMAIN: [{ + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event_hello', + }, + 'action': {'service': 'test.automation'} + }, { + 'alias': 'bye', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event_bye', + }, + 'action': {'service': 'test.automation'} + }]} + + assert (yield from async_setup_component(hass, automation.DOMAIN, config)) + + state = hass.states.get('automation.hello') + assert state + assert state.state == STATE_ON + + state = hass.states.get('automation.bye') + assert state + assert state.state == STATE_OFF + assert state.attributes.get('last_triggered') == time + + calls = mock_service(hass, 'test', 'automation') + + assert automation.is_on(hass, 'automation.bye') is False + + hass.bus.async_fire('test_event_bye') + yield from hass.async_block_till_done() + assert len(calls) == 0 + + assert automation.is_on(hass, 'automation.hello') + + hass.bus.async_fire('test_event_hello') + yield from hass.async_block_till_done() + + assert len(calls) == 1