From e1cbd6b4c09a1630fa343ef069b5ccc78d073df1 Mon Sep 17 00:00:00 2001 From: Pascal Vizeli Date: Sat, 18 Feb 2017 23:17:18 +0100 Subject: [PATCH] MQTT convert to async (#6064) * Migrate mqtt to async * address paulus comment / convert it complet async * adress paulus comment / remove future * Automation triggers should be async * Fix MQTT async calls * Show that event helpers are callbacks * Fix tests * Lint --- .../components/automation/__init__.py | 2 +- homeassistant/components/automation/event.py | 2 + .../components/automation/litejet.py | 2 + homeassistant/components/automation/mqtt.py | 6 +- .../components/automation/numeric_state.py | 2 + homeassistant/components/automation/state.py | 3 + homeassistant/components/automation/sun.py | 2 + .../components/automation/template.py | 2 + homeassistant/components/automation/time.py | 2 + homeassistant/components/automation/zone.py | 2 + homeassistant/components/mqtt/__init__.py | 195 ++++++++++++------ homeassistant/components/mqtt/discovery.py | 7 +- homeassistant/components/mqtt/server.py | 22 +- homeassistant/helpers/event.py | 9 + tests/common.py | 8 +- .../alarm_control_panel/test_mqtt.py | 6 +- tests/components/cover/test_mqtt.py | 10 +- tests/components/light/test_mqtt.py | 24 +-- tests/components/light/test_mqtt_json.py | 44 ++-- tests/components/light/test_mqtt_template.py | 44 ++-- tests/components/lock/test_mqtt.py | 4 +- tests/components/mqtt/test_discovery.py | 15 +- tests/components/mqtt/test_init.py | 163 +++++++++------ tests/components/mqtt/test_server.py | 7 +- tests/components/switch/test_mqtt.py | 4 +- 25 files changed, 356 insertions(+), 231 deletions(-) diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 341e6e902337..bebace6d8272 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -412,7 +412,7 @@ def _async_process_trigger(hass, config, trigger_configs, name, action): if platform is None: return None - remove = platform.async_trigger(hass, conf, action) + remove = yield from platform.async_trigger(hass, conf, action) if not remove: _LOGGER.error("Error setting up trigger %s", name) diff --git a/homeassistant/components/automation/event.py b/homeassistant/components/automation/event.py index a51f9fa81877..21bf243e34fc 100644 --- a/homeassistant/components/automation/event.py +++ b/homeassistant/components/automation/event.py @@ -4,6 +4,7 @@ Offer event listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#event-trigger """ +import asyncio import logging import voluptuous as vol @@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({ }) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for events based on configuration.""" event_type = config.get(CONF_EVENT_TYPE) diff --git a/homeassistant/components/automation/litejet.py b/homeassistant/components/automation/litejet.py index 2b298d4979bd..56109e27f1b9 100644 --- a/homeassistant/components/automation/litejet.py +++ b/homeassistant/components/automation/litejet.py @@ -4,6 +4,7 @@ Trigger an automation when a LiteJet switch is released. For more details about this platform, please refer to the documentation at https://home-assistant.io/components/automation.litejet/ """ +import asyncio import logging import voluptuous as vol @@ -32,6 +33,7 @@ TRIGGER_SCHEMA = vol.Schema({ }) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for events based on configuration.""" number = config.get(CONF_NUMBER) diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index 4818c02d9ff8..fbea2cede380 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -4,6 +4,7 @@ Offer MQTT listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#mqtt-trigger """ +import asyncio import json import voluptuous as vol @@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({ }) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for state changes based on configuration.""" topic = config.get(CONF_TOPIC) @@ -49,4 +51,6 @@ def async_trigger(hass, config, action): 'trigger': data }) - return mqtt.async_subscribe(hass, topic, mqtt_automation_listener) + remove = yield from mqtt.async_subscribe( + hass, topic, mqtt_automation_listener) + return remove diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 9c3ac7d83964..8b3c3e576706 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -4,6 +4,7 @@ Offer numeric state listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#numeric-state-trigger """ +import asyncio import logging import voluptuous as vol @@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({ _LOGGER = logging.getLogger(__name__) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for state changes based on configuration.""" entity_id = config.get(CONF_ENTITY_ID) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 65cca462ed9f..53ac4fba7916 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -4,6 +4,7 @@ Offer state listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#state-trigger """ +import asyncio import voluptuous as vol from homeassistant.core import callback @@ -34,6 +35,7 @@ TRIGGER_SCHEMA = vol.All( ) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for state changes based on configuration.""" entity_id = config.get(CONF_ENTITY_ID) @@ -97,6 +99,7 @@ def async_trigger(hass, config, action): unsub = async_track_state_change( hass, entity_id, state_automation_listener, from_state, to_state) + @callback def async_remove(): """Remove state listeners async.""" unsub() diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index 2baa0726813a..4529b5a8b60c 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -4,6 +4,7 @@ Offer sun based automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#sun-trigger """ +import asyncio from datetime import timedelta import logging @@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({ }) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for events based on configuration.""" event = config.get(CONF_EVENT) diff --git a/homeassistant/components/automation/template.py b/homeassistant/components/automation/template.py index 9727041e7508..a83671d5fa8d 100644 --- a/homeassistant/components/automation/template.py +++ b/homeassistant/components/automation/template.py @@ -4,6 +4,7 @@ Offer template automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#template-trigger """ +import asyncio import logging import voluptuous as vol @@ -22,6 +23,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema({ }) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for state changes based on configuration.""" value_template = config.get(CONF_VALUE_TEMPLATE) diff --git a/homeassistant/components/automation/time.py b/homeassistant/components/automation/time.py index d0315f26de08..e33fd0f6ba92 100644 --- a/homeassistant/components/automation/time.py +++ b/homeassistant/components/automation/time.py @@ -4,6 +4,7 @@ Offer time listening automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#time-trigger """ +import asyncio import logging import voluptuous as vol @@ -29,6 +30,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({ CONF_SECONDS, CONF_AFTER)) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for state changes based on configuration.""" if CONF_AFTER in config: diff --git a/homeassistant/components/automation/zone.py b/homeassistant/components/automation/zone.py index 935dc3cf24c6..8ffc0498317e 100644 --- a/homeassistant/components/automation/zone.py +++ b/homeassistant/components/automation/zone.py @@ -4,6 +4,7 @@ Offer zone automation rules. For more details about this automation rule, please refer to the documentation at https://home-assistant.io/components/automation/#zone-trigger """ +import asyncio import voluptuous as vol from homeassistant.core import callback @@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({ }) +@asyncio.coroutine def async_trigger(hass, config, action): """Listen for state changes based on configuration.""" entity_id = config.get(CONF_ENTITY_ID) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 79c154ef2239..bb5990baeb54 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -4,6 +4,7 @@ Support for MQTT message handling. For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ +import asyncio import logging import os import socket @@ -12,11 +13,12 @@ import time import voluptuous as vol from homeassistant.core import callback -from homeassistant.bootstrap import prepare_setup_platform +from homeassistant.bootstrap import async_prepare_setup_platform from homeassistant.config import load_yaml_config_file from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import template, config_validation as cv -from homeassistant.helpers.event import threaded_listener_factory +from homeassistant.util.async import ( + run_coroutine_threadsafe, run_callback_threadsafe) from homeassistant.const import ( EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_VALUE_TEMPLATE, CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_PAYLOAD) @@ -26,7 +28,7 @@ _LOGGER = logging.getLogger(__name__) DOMAIN = 'mqtt' -MQTT_CLIENT = None +DATA_MQTT = 'mqtt' SERVICE_PUBLISH = 'publish' EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received' @@ -183,11 +185,11 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): hass.services.call(DOMAIN, SERVICE_PUBLISH, data) -@callback +@asyncio.coroutine def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): """Subscribe to an MQTT topic.""" @callback - def mqtt_topic_subscriber(event): + def async_mqtt_topic_subscriber(event): """Match subscribed MQTT topic.""" if not _match_topic(topic, event.data[ATTR_TOPIC]): return @@ -195,61 +197,82 @@ def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): hass.async_run_job(msg_callback, event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) - async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED, - mqtt_topic_subscriber) - - # Future: track subscriber count and unsubscribe in remove - MQTT_CLIENT.subscribe(topic, qos) + async_remove = hass.bus.async_listen( + EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber) + yield from hass.data[DATA_MQTT].async_subscribe(topic, qos) return async_remove -# pylint: disable=invalid-name -subscribe = threaded_listener_factory(async_subscribe) +def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): + """Subscribe to an MQTT topic.""" + async_remove = run_coroutine_threadsafe( + async_subscribe(hass, topic, msg_callback, qos), + hass.loop + ).result() + + def remove(): + """Remove listener convert.""" + run_callback_threadsafe(hass.loop, async_remove).result() + + return remove -def _setup_server(hass, config): - """Try to start embedded MQTT broker.""" +@asyncio.coroutine +def _async_setup_server(hass, config): + """Try to start embedded MQTT broker. + + This method is a coroutine. + """ conf = config.get(DOMAIN, {}) # Only setup if embedded config passed in or no broker specified if CONF_EMBEDDED not in conf and CONF_BROKER in conf: return None - server = prepare_setup_platform(hass, config, DOMAIN, 'server') + server = yield from async_prepare_setup_platform( + hass, config, DOMAIN, 'server') if server is None: _LOGGER.error("Unable to load embedded server") return None - success, broker_config = server.start(hass, conf.get(CONF_EMBEDDED)) + success, broker_config = \ + yield from server.async_start(hass, conf.get(CONF_EMBEDDED)) return success and broker_config -def _setup_discovery(hass, config): - """Try to start the discovery of MQTT devices.""" +@asyncio.coroutine +def _async_setup_discovery(hass, config): + """Try to start the discovery of MQTT devices. + + This method is a coroutine. + """ conf = config.get(DOMAIN, {}) - discovery = prepare_setup_platform(hass, config, DOMAIN, 'discovery') + discovery = yield from async_prepare_setup_platform( + hass, config, DOMAIN, 'discovery') if discovery is None: _LOGGER.error("Unable to load MQTT discovery") return None - success = discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config) + success = yield from discovery.async_start( + hass, conf[CONF_DISCOVERY_PREFIX], config) return success -def setup(hass, config): +@asyncio.coroutine +def async_setup(hass, config): """Start the MQTT protocol service.""" conf = config.get(DOMAIN, {}) client_id = conf.get(CONF_CLIENT_ID) keepalive = conf.get(CONF_KEEPALIVE) - broker_config = _setup_server(hass, config) + broker_config = yield from _async_setup_server(hass, config) if CONF_BROKER in conf: broker = conf[CONF_BROKER] @@ -283,27 +306,31 @@ def setup(hass, config): will_message = conf.get(CONF_WILL_MESSAGE) birth_message = conf.get(CONF_BIRTH_MESSAGE) - global MQTT_CLIENT try: - MQTT_CLIENT = MQTT(hass, broker, port, client_id, keepalive, - username, password, certificate, client_key, - client_cert, tls_insecure, protocol, will_message, - birth_message) + hass.data[DATA_MQTT] = MQTT( + hass, broker, port, client_id, keepalive, username, password, + certificate, client_key, client_cert, tls_insecure, protocol, + will_message, birth_message) except socket.error: _LOGGER.exception("Can't connect to the broker. " "Please check your settings and the broker itself") return False - def stop_mqtt(event): + @asyncio.coroutine + def async_stop_mqtt(event): """Stop MQTT component.""" - MQTT_CLIENT.stop() + yield from hass.data[DATA_MQTT].async_stop() - def start_mqtt(event): + @asyncio.coroutine + def async_start_mqtt(event): """Launch MQTT component when Home Assistant starts up.""" - MQTT_CLIENT.start() - hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_mqtt) + yield from hass.data[DATA_MQTT].async_start() + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt) - def publish_service(call): + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, async_start_mqtt) + + @asyncio.coroutine + def async_publish_service(call): """Handle MQTT publish service calls.""" msg_topic = call.data[ATTR_TOPIC] payload = call.data.get(ATTR_PAYLOAD) @@ -312,26 +339,28 @@ def setup(hass, config): retain = call.data[ATTR_RETAIN] try: if payload_template is not None: - payload = template.Template(payload_template, hass).render() + payload = \ + template.Template(payload_template, hass).async_render() except template.jinja2.TemplateError as exc: _LOGGER.error( "Unable to publish to '%s': rendering payload template of " "'%s' failed because %s", msg_topic, payload_template, exc) return - MQTT_CLIENT.publish(msg_topic, payload, qos, retain) - hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_mqtt) + yield from hass.data[DATA_MQTT].async_publish( + msg_topic, payload, qos, retain) - descriptions = load_yaml_config_file( - os.path.join(os.path.dirname(__file__), 'services.yaml')) + descriptions = yield from hass.loop.run_in_executor( + None, load_yaml_config_file, os.path.join( + os.path.dirname(__file__), 'services.yaml')) - hass.services.register(DOMAIN, SERVICE_PUBLISH, publish_service, - descriptions.get(SERVICE_PUBLISH), - schema=MQTT_PUBLISH_SCHEMA) + hass.services.async_register( + DOMAIN, SERVICE_PUBLISH, async_publish_service, + descriptions.get(SERVICE_PUBLISH), schema=MQTT_PUBLISH_SCHEMA) if conf.get(CONF_DISCOVERY): - _setup_discovery(hass, config) + yield from _async_setup_discovery(hass, config) return True @@ -349,6 +378,7 @@ class MQTT(object): self.topics = {} self.progress = {} self.birth_message = birth_message + self._mqttc = None if protocol == PROTOCOL_31: proto = mqtt.MQTTv31 @@ -364,8 +394,8 @@ class MQTT(object): self._mqttc.username_pw_set(username, password) if certificate is not None: - self._mqttc.tls_set(certificate, certfile=client_cert, - keyfile=client_key) + self._mqttc.tls_set( + certificate, certfile=client_cert, keyfile=client_key) if tls_insecure is not None: self._mqttc.tls_insecure_set(tls_insecure) @@ -375,40 +405,69 @@ class MQTT(object): self._mqttc.on_connect = self._mqtt_on_connect self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_message = self._mqtt_on_message + if will_message: self._mqttc.will_set(will_message.get(ATTR_TOPIC), will_message.get(ATTR_PAYLOAD), will_message.get(ATTR_QOS), will_message.get(ATTR_RETAIN)) - self._mqttc.connect(broker, port, keepalive) - def publish(self, topic, payload, qos, retain): - """Publish a MQTT message.""" - self._mqttc.publish(topic, payload, qos, retain) + self._mqttc.connect_async(broker, port, keepalive) - def start(self): - """Run the MQTT client.""" - self._mqttc.loop_start() + def async_publish(self, topic, payload, qos, retain): + """Publish a MQTT message. - def stop(self): - """Stop the MQTT client.""" - self._mqttc.disconnect() - self._mqttc.loop_stop() + This method must be run in the event loop and returns a coroutine. + """ + return self.hass.loop.run_in_executor( + None, self._mqttc.publish, topic, payload, qos, retain) - def subscribe(self, topic, qos): - """Subscribe to a topic.""" - assert isinstance(topic, str) + def async_start(self): + """Run the MQTT client. + + This method must be run in the event loop and returns a coroutine. + """ + return self.hass.loop.run_in_executor(None, self._mqttc.loop_start) + + def async_stop(self): + """Stop the MQTT client. + + This method must be run in the event loop and returns a coroutine. + """ + def stop(self): + """Stop the MQTT client.""" + self._mqttc.disconnect() + self._mqttc.loop_stop() + + return self.hass.loop.run_in_executor(None, stop) + + @asyncio.coroutine + def async_subscribe(self, topic, qos): + """Subscribe to a topic. + + This method is a coroutine. + """ + if not isinstance(topic, str): + raise HomeAssistantError("topic need to be a string!") if topic in self.topics: return - result, mid = self._mqttc.subscribe(topic, qos) + result, mid = yield from self.hass.loop.run_in_executor( + None, self._mqttc.subscribe, topic, qos) + _raise_on_error(result) self.progress[mid] = topic self.topics[topic] = None - def unsubscribe(self, topic): - """Unsubscribe from topic.""" - result, mid = self._mqttc.unsubscribe(topic) + @asyncio.coroutine + def async_unsubscribe(self, topic): + """Unsubscribe from topic. + + This method is a coroutine. + """ + result, mid = yield from self.hass.loop.run_in_executor( + None, self._mqttc.unsubscribe, topic) + _raise_on_error(result) self.progress[mid] = topic @@ -437,12 +496,14 @@ class MQTT(object): for topic, qos in old_topics.items(): # qos is None if we were in process of subscribing if qos is not None: - self.subscribe(topic, qos) + self.hass.add_job(self.async_subscribe, topic, qos) + if self.birth_message: - self.publish(self.birth_message.get(ATTR_TOPIC), - self.birth_message.get(ATTR_PAYLOAD), - self.birth_message.get(ATTR_QOS), - self.birth_message.get(ATTR_RETAIN)) + self.hass.add_job(self.async_publish( + self.birth_message.get(ATTR_TOPIC), + self.birth_message.get(ATTR_PAYLOAD), + self.birth_message.get(ATTR_QOS), + self.birth_message.get(ATTR_RETAIN))) def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos): """Subscribe successful callback.""" diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index fa29f03c5e5a..a3b120410c55 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -9,7 +9,6 @@ import json import logging import re -from homeassistant.core import callback import homeassistant.components.mqtt as mqtt from homeassistant.components.mqtt import DOMAIN from homeassistant.helpers.discovery import async_load_platform @@ -24,7 +23,7 @@ TOPIC_MATCHER = re.compile( SUPPORTED_COMPONENTS = ['binary_sensor', 'sensor'] -@callback +@asyncio.coroutine def async_start(hass, discovery_topic, hass_config): """Initialization of MQTT Discovery.""" @asyncio.coroutine @@ -56,7 +55,7 @@ def async_start(hass, discovery_topic, hass_config): yield from async_load_platform( hass, component, DOMAIN, payload, hass_config) - mqtt.async_subscribe(hass, discovery_topic + '/#', - async_device_message_received, 0) + yield from mqtt.async_subscribe( + hass, discovery_topic + '/#', async_device_message_received, 0) return True diff --git a/homeassistant/components/mqtt/server.py b/homeassistant/components/mqtt/server.py index 57ad04fd18d0..c51649a3bef2 100644 --- a/homeassistant/components/mqtt/server.py +++ b/homeassistant/components/mqtt/server.py @@ -4,15 +4,14 @@ Support for a local MQTT broker. For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/#use-the-embedded-broker """ +import asyncio import logging import tempfile import voluptuous as vol -from homeassistant.core import callback from homeassistant.const import EVENT_HOMEASSISTANT_STOP import homeassistant.helpers.config_validation as cv -from homeassistant.util.async import run_coroutine_threadsafe REQUIREMENTS = ['hbmqtt==0.8'] DEPENDENCIES = ['http'] @@ -29,8 +28,12 @@ HBMQTT_CONFIG_SCHEMA = vol.Any(None, vol.Schema({ }, extra=vol.ALLOW_EXTRA)) -def start(hass, server_config): - """Initialize MQTT Server.""" +@asyncio.coroutine +def async_start(hass, server_config): + """Initialize MQTT Server. + + This method is a coroutine. + """ from hbmqtt.broker import Broker, BrokerException try: @@ -42,19 +45,20 @@ def start(hass, server_config): client_config = None broker = Broker(server_config, hass.loop) - run_coroutine_threadsafe(broker.start(), hass.loop).result() + yield from broker.start() except BrokerException: logging.getLogger(__name__).exception('Error initializing MQTT server') return False, None finally: passwd.close() - @callback - def shutdown_mqtt_server(event): + @asyncio.coroutine + def async_shutdown_mqtt_server(event): """Shut down the MQTT server.""" - hass.async_add_job(broker.shutdown()) + yield from broker.shutdown() - hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, shutdown_mqtt_server) + hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_STOP, async_shutdown_mqtt_server) return True, client_config diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 1b4819ddf9be..12e031bfc3e9 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -34,6 +34,7 @@ def threaded_listener_factory(async_factory): return factory +@callback def async_track_state_change(hass, entity_ids, action, from_state=None, to_state=None): """Track specific state changes. @@ -84,6 +85,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None, track_state_change = threaded_listener_factory(async_track_state_change) +@callback def async_track_template(hass, template, action, variables=None): """Add a listener that track state changes with template condition.""" from . import condition @@ -111,6 +113,7 @@ def async_track_template(hass, template, action, variables=None): track_template = threaded_listener_factory(async_track_template) +@callback def async_track_point_in_time(hass, action, point_in_time): """Add a listener that fires once after a specific point in time.""" utc_point_in_time = dt_util.as_utc(point_in_time) @@ -127,6 +130,7 @@ def async_track_point_in_time(hass, action, point_in_time): track_point_in_time = threaded_listener_factory(async_track_point_in_time) +@callback def async_track_point_in_utc_time(hass, action, point_in_time): """Add a listener that fires once after a specific point in UTC time.""" # Ensure point_in_time is UTC @@ -160,6 +164,7 @@ track_point_in_utc_time = threaded_listener_factory( async_track_point_in_utc_time) +@callback def async_track_time_interval(hass, action, interval): """Add a listener that fires repetitively at every timedelta interval.""" remove = None @@ -189,6 +194,7 @@ def async_track_time_interval(hass, action, interval): track_time_interval = threaded_listener_factory(async_track_time_interval) +@callback def async_track_sunrise(hass, action, offset=None): """Add a listener that will fire a specified offset from sunrise daily.""" from homeassistant.components import sun @@ -225,6 +231,7 @@ def async_track_sunrise(hass, action, offset=None): track_sunrise = threaded_listener_factory(async_track_sunrise) +@callback def async_track_sunset(hass, action, offset=None): """Add a listener that will fire a specified offset from sunset daily.""" from homeassistant.components import sun @@ -261,6 +268,7 @@ def async_track_sunset(hass, action, offset=None): track_sunset = threaded_listener_factory(async_track_sunset) +@callback def async_track_utc_time_change(hass, action, year=None, month=None, day=None, hour=None, minute=None, second=None, local=False): @@ -305,6 +313,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None, track_utc_time_change = threaded_listener_factory(async_track_utc_time_change) +@callback def async_track_time_change(hass, action, year=None, month=None, day=None, hour=None, minute=None, second=None): """Add a listener that will fire if UTC time matches a pattern.""" diff --git a/tests/common.py b/tests/common.py index ac6856c4a64a..16cc580a5e11 100644 --- a/tests/common.py +++ b/tests/common.py @@ -86,7 +86,13 @@ def async_test_home_assistant(loop): loop._thread_ident = threading.get_ident() hass = ha.HomeAssistant(loop) - hass.async_track_tasks() + + def async_add_job(target, *args): + if isinstance(target, MagicMock): + return + hass._async_add_job_tracking(target, *args) + + hass.async_add_job = async_add_job hass.config.location_name = 'test home' hass.config.config_dir = get_test_config_dir() diff --git a/tests/components/alarm_control_panel/test_mqtt.py b/tests/components/alarm_control_panel/test_mqtt.py index c253c4a49fb0..f1bbb7118486 100644 --- a/tests/components/alarm_control_panel/test_mqtt.py +++ b/tests/components/alarm_control_panel/test_mqtt.py @@ -111,7 +111,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase): alarm_control_panel.alarm_arm_home(self.hass) self.hass.block_till_done() self.assertEqual(('alarm/command', 'ARM_HOME', 0, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) def test_arm_home_not_publishes_mqtt_with_invalid_code(self): """Test not publishing of MQTT messages with invalid code.""" @@ -146,7 +146,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase): alarm_control_panel.alarm_arm_away(self.hass) self.hass.block_till_done() self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) def test_arm_away_not_publishes_mqtt_with_invalid_code(self): """Test not publishing of MQTT messages with invalid code.""" @@ -181,7 +181,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase): alarm_control_panel.alarm_disarm(self.hass) self.hass.block_till_done() self.assertEqual(('alarm/command', 'DISARM', 0, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) def test_disarm_not_publishes_mqtt_with_invalid_code(self): """Test not publishing of MQTT messages with invalid code.""" diff --git a/tests/components/cover/test_mqtt.py b/tests/components/cover/test_mqtt.py index 05fa29263a9d..81518458e0ec 100644 --- a/tests/components/cover/test_mqtt.py +++ b/tests/components/cover/test_mqtt.py @@ -118,7 +118,7 @@ class TestCoverMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'OPEN', 0, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('cover.test') self.assertEqual(STATE_OPEN, state.state) @@ -126,7 +126,7 @@ class TestCoverMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'CLOSE', 0, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('cover.test') self.assertEqual(STATE_CLOSED, state.state) @@ -150,7 +150,7 @@ class TestCoverMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'OPEN', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('cover.test') self.assertEqual(STATE_UNKNOWN, state.state) @@ -174,7 +174,7 @@ class TestCoverMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'CLOSE', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('cover.test') self.assertEqual(STATE_UNKNOWN, state.state) @@ -198,7 +198,7 @@ class TestCoverMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'STOP', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('cover.test') self.assertEqual(STATE_UNKNOWN, state.state) diff --git a/tests/components/light/test_mqtt.py b/tests/components/light/test_mqtt.py index fb6f98f37b4c..4f0d4a273b6d 100644 --- a/tests/components/light/test_mqtt.py +++ b/tests/components/light/test_mqtt.py @@ -74,6 +74,7 @@ light: """ import unittest +from unittest import mock from homeassistant.bootstrap import setup_component from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ASSUMED_STATE @@ -328,7 +329,7 @@ class TestLightMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('test_light_rgb/set', 'on', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('light.test') self.assertEqual(STATE_ON, state.state) @@ -336,27 +337,20 @@ class TestLightMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('test_light_rgb/set', 'off', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('light.test') self.assertEqual(STATE_OFF, state.state) + self.mock_publish.reset_mock() light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75], brightness=50) self.hass.block_till_done() - # Calls are threaded so we need to reorder them - bright_call, rgb_call, state_call = \ - sorted((call[1] for call in self.mock_publish.mock_calls[-3:]), - key=lambda call: call[0]) - - self.assertEqual(('test_light_rgb/set', 'on', 2, False), - state_call) - - self.assertEqual(('test_light_rgb/rgb/set', '75,75,75', 2, False), - rgb_call) - - self.assertEqual(('test_light_rgb/brightness/set', 50, 2, False), - bright_call) + self.mock_publish().async_publish.assert_has_calls([ + mock.call('test_light_rgb/set', 'on', 2, False), + mock.call('test_light_rgb/rgb/set', '75,75,75', 2, False), + mock.call('test_light_rgb/brightness/set', 50, 2, False), + ], any_order=True) state = self.hass.states.get('light.test') self.assertEqual(STATE_ON, state.state) diff --git a/tests/components/light/test_mqtt_json.py b/tests/components/light/test_mqtt_json.py index cbfa0470fed1..4f48181a9176 100755 --- a/tests/components/light/test_mqtt_json.py +++ b/tests/components/light/test_mqtt_json.py @@ -172,7 +172,7 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('test_light_rgb/set', '{"state": "ON"}', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('light.test') self.assertEqual(STATE_ON, state.state) @@ -180,7 +180,7 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('test_light_rgb/set', '{"state": "OFF"}', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('light.test') self.assertEqual(STATE_OFF, state.state) @@ -189,11 +189,11 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # Get the sent message - message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) + message_json = json.loads(self.mock_publish.mock_calls[-2][1][1]) self.assertEqual(50, message_json["brightness"]) self.assertEqual(75, message_json["color"]["r"]) self.assertEqual(75, message_json["color"]["g"]) @@ -228,11 +228,11 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # Get the sent message - message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) + message_json = json.loads(self.mock_publish.mock_calls[-2][1][1]) self.assertEqual(5, message_json["flash"]) self.assertEqual("ON", message_json["state"]) @@ -240,11 +240,11 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # Get the sent message - message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) + message_json = json.loads(self.mock_publish.mock_calls[-2][1][1]) self.assertEqual(15, message_json["flash"]) self.assertEqual("ON", message_json["state"]) @@ -268,11 +268,11 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # Get the sent message - message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) + message_json = json.loads(self.mock_publish.mock_calls[-2][1][1]) self.assertEqual(10, message_json["transition"]) self.assertEqual("ON", message_json["state"]) @@ -281,11 +281,11 @@ class TestLightMQTTJSON(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # Get the sent message - message_json = json.loads(self.mock_publish.mock_calls[-1][1][1]) + message_json = json.loads(self.mock_publish.mock_calls[-2][1][1]) self.assertEqual(10, message_json["transition"]) self.assertEqual("OFF", message_json["state"]) diff --git a/tests/components/light/test_mqtt_template.py b/tests/components/light/test_mqtt_template.py index b2a37c97001f..e097aba92a9c 100755 --- a/tests/components/light/test_mqtt_template.py +++ b/tests/components/light/test_mqtt_template.py @@ -196,7 +196,7 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('test_light_rgb/set', 'on,,--', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('light.test') self.assertEqual(STATE_ON, state.state) @@ -205,7 +205,7 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('test_light_rgb/set', 'off', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('light.test') self.assertEqual(STATE_OFF, state.state) @@ -215,12 +215,12 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # check the payload - payload = self.mock_publish.mock_calls[-1][1][1] + payload = self.mock_publish.mock_calls[-2][1][1] self.assertEqual('on,50,75-75-75', payload) # check the state @@ -253,12 +253,12 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # check the payload - payload = self.mock_publish.mock_calls[-1][1][1] + payload = self.mock_publish.mock_calls[-2][1][1] self.assertEqual('on,short', payload) # long flash @@ -266,12 +266,12 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # check the payload - payload = self.mock_publish.mock_calls[-1][1][1] + payload = self.mock_publish.mock_calls[-2][1][1] self.assertEqual('on,long', payload) def test_transition(self): @@ -296,12 +296,12 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # check the payload - payload = self.mock_publish.mock_calls[-1][1][1] + payload = self.mock_publish.mock_calls[-2][1][1] self.assertEqual('on,10', payload) # transition off @@ -309,12 +309,12 @@ class TestLightMQTTTemplate(unittest.TestCase): self.hass.block_till_done() self.assertEqual('test_light_rgb/set', - self.mock_publish.mock_calls[-1][1][0]) - self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2]) - self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3]) + self.mock_publish.mock_calls[-2][1][0]) + self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2]) + self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3]) # check the payload - payload = self.mock_publish.mock_calls[-1][1][1] + payload = self.mock_publish.mock_calls[-2][1][1] self.assertEqual('off,4', payload) def test_invalid_values(self): \ diff --git a/tests/components/lock/test_mqtt.py b/tests/components/lock/test_mqtt.py index f22729a1e5bb..c858d58dfa71 100644 --- a/tests/components/lock/test_mqtt.py +++ b/tests/components/lock/test_mqtt.py @@ -73,7 +73,7 @@ class TestLockMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'LOCK', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('lock.test') self.assertEqual(STATE_LOCKED, state.state) @@ -81,7 +81,7 @@ class TestLockMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'UNLOCK', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('lock.test') self.assertEqual(STATE_UNLOCKED, state.state) diff --git a/tests/components/mqtt/test_discovery.py b/tests/components/mqtt/test_discovery.py index bf6fa2f26035..389fb37b4896 100644 --- a/tests/components/mqtt/test_discovery.py +++ b/tests/components/mqtt/test_discovery.py @@ -12,9 +12,10 @@ def test_subscribing_config_topic(hass, mqtt_mock): """Test setting up discovery.""" hass_config = {} discovery_topic = 'homeassistant' - async_start(hass, discovery_topic, hass_config) - assert mqtt_mock.subscribe.called - call_args = mqtt_mock.subscribe.mock_calls[0][1] + yield from async_start(hass, discovery_topic, hass_config) + + assert mqtt_mock.async_subscribe.called + call_args = mqtt_mock.async_subscribe.mock_calls[0][1] assert call_args[0] == discovery_topic + '/#' assert call_args[1] == 0 @@ -24,7 +25,7 @@ def test_subscribing_config_topic(hass, mqtt_mock): def test_invalid_topic(mock_load_platform, hass, mqtt_mock): """Test sending in invalid JSON.""" mock_load_platform.return_value = mock_coro() - async_start(hass, 'homeassistant', {}) + yield from async_start(hass, 'homeassistant', {}) async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config', '{}') @@ -37,7 +38,7 @@ def test_invalid_topic(mock_load_platform, hass, mqtt_mock): def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog): """Test sending in invalid JSON.""" mock_load_platform.return_value = mock_coro() - async_start(hass, 'homeassistant', {}) + yield from async_start(hass, 'homeassistant', {}) async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config', 'not json') @@ -51,7 +52,7 @@ def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog): def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog): """Test sending in invalid JSON.""" mock_load_platform.return_value = mock_coro() - async_start(hass, 'homeassistant', {}) + yield from async_start(hass, 'homeassistant', {}) async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}') yield from hass.async_block_till_done() @@ -62,7 +63,7 @@ def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog): @asyncio.coroutine def test_correct_config_discovery(hass, mqtt_mock, caplog): """Test sending in invalid JSON.""" - async_start(hass, 'homeassistant', {}) + yield from async_start(hass, 'homeassistant', {}) async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config', '{ "name": "Beer" }') diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 204f1fe15b4d..2f96ada94aec 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1,5 +1,6 @@ """The tests for the MQTT component.""" -from collections import namedtuple +import asyncio +from collections import namedtuple, OrderedDict import unittest from unittest import mock import socket @@ -7,14 +8,29 @@ import socket import voluptuous as vol from homeassistant.core import callback -from homeassistant.bootstrap import setup_component +from homeassistant.bootstrap import setup_component, async_setup_component import homeassistant.components.mqtt as mqtt from homeassistant.const import ( EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP) from tests.common import ( - get_test_home_assistant, mock_mqtt_component, fire_mqtt_message) + get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro) + + +@asyncio.coroutine +def mock_mqtt_client(hass, config=None): + """Mock the MQTT paho client.""" + if config is None: + config = { + mqtt.CONF_BROKER: 'mock-broker' + } + + with mock.patch('paho.mqtt.client.Client') as mock_client: + yield from async_setup_component(hass, mqtt.DOMAIN, { + mqtt.DOMAIN: config + }) + return mock_client() # pylint: disable=invalid-name @@ -40,7 +56,7 @@ class TestMQTT(unittest.TestCase): """"Test if client start on HA launch.""" self.hass.bus.fire(EVENT_HOMEASSISTANT_START) self.hass.block_till_done() - self.assertTrue(mqtt.MQTT_CLIENT.start.called) + self.assertTrue(self.hass.data['mqtt'].async_start.called) def test_client_stops_on_home_assistant_start(self): """Test if client stops on HA launch.""" @@ -48,7 +64,7 @@ class TestMQTT(unittest.TestCase): self.hass.block_till_done() self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP) self.hass.block_till_done() - self.assertTrue(mqtt.MQTT_CLIENT.stop.called) + self.assertTrue(self.hass.data['mqtt'].async_stop.called) @mock.patch('paho.mqtt.client.Client') def test_setup_fails_if_no_connect_broker(self, _): @@ -69,14 +85,17 @@ class TestMQTT(unittest.TestCase): """Test setting up embedded server with no config.""" client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1') - with mock.patch('homeassistant.components.mqtt.server.start', - return_value=(True, client_config)) as _start: + with mock.patch('homeassistant.components.mqtt.server.async_start', + return_value=mock_coro( + return_value=(True, client_config)) + ) as _start: self.hass.config.components = set() assert setup_component(self.hass, mqtt.DOMAIN, {mqtt.DOMAIN: {}}) assert _start.call_count == 1 # Test with `embedded: None` + _start.return_value = mock_coro(return_value=(True, client_config)) self.hass.config.components = set() assert setup_component(self.hass, mqtt.DOMAIN, {mqtt.DOMAIN: {'embedded': None}}) @@ -105,7 +124,7 @@ class TestMQTT(unittest.TestCase): ATTR_SERVICE: mqtt.SERVICE_PUBLISH }) self.hass.block_till_done() - self.assertTrue(not mqtt.MQTT_CLIENT.publish.called) + self.assertTrue(not self.hass.data['mqtt'].async_publish.called) def test_service_call_with_template_payload_renders_template(self): """Test the service call with rendered template. @@ -114,8 +133,9 @@ class TestMQTT(unittest.TestCase): """ mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}") self.hass.block_till_done() - self.assertTrue(mqtt.MQTT_CLIENT.publish.called) - self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][1], "2") + self.assertTrue(self.hass.data['mqtt'].async_publish.called) + self.assertEqual( + self.hass.data['mqtt'].async_publish.call_args[0][1], "2") def test_service_call_with_payload_doesnt_render_template(self): """Test the service call with unrendered template. @@ -129,7 +149,7 @@ class TestMQTT(unittest.TestCase): mqtt.ATTR_PAYLOAD: payload, mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template }, blocking=True) - self.assertFalse(mqtt.MQTT_CLIENT.publish.called) + self.assertFalse(self.hass.data['mqtt'].async_publish.called) def test_service_call_with_ascii_qos_retain_flags(self): """Test the service call with args that can be misinterpreted. @@ -142,9 +162,10 @@ class TestMQTT(unittest.TestCase): mqtt.ATTR_QOS: '2', mqtt.ATTR_RETAIN: 'no' }, blocking=True) - self.assertTrue(mqtt.MQTT_CLIENT.publish.called) - self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][2], 2) - self.assertFalse(mqtt.MQTT_CLIENT.publish.call_args[0][3]) + self.assertTrue(self.hass.data['mqtt'].async_publish.called) + self.assertEqual( + self.hass.data['mqtt'].async_publish.call_args[0][2], 2) + self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3]) def test_subscribe_topic(self): """Test the subscription of a topic.""" @@ -231,15 +252,12 @@ class TestMQTTCallbacks(unittest.TestCase): def setUp(self): # pylint: disable=invalid-name """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - # mock_mqtt_component(self.hass) with mock.patch('paho.mqtt.client.Client'): self.hass.config.components = set() assert setup_component(self.hass, mqtt.DOMAIN, { mqtt.DOMAIN: { mqtt.CONF_BROKER: 'mock-broker', - mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth', - mqtt.ATTR_PAYLOAD: 'birth'} } }) @@ -261,7 +279,8 @@ class TestMQTTCallbacks(unittest.TestCase): MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8')) - mqtt.MQTT_CLIENT._mqtt_on_message(None, {'hass': self.hass}, message) + self.hass.data['mqtt']._mqtt_on_message( + None, {'hass': self.hass}, message) self.hass.block_till_done() self.assertEqual(1, len(calls)) @@ -273,68 +292,36 @@ class TestMQTTCallbacks(unittest.TestCase): def test_mqtt_failed_connection_results_in_disconnect(self): """Test if connection failure leads to disconnect.""" for result_code in range(1, 6): - mqtt.MQTT_CLIENT._mqttc = mock.MagicMock() - mqtt.MQTT_CLIENT._mqtt_on_connect(None, {'topics': {}}, 0, - result_code) - self.assertTrue(mqtt.MQTT_CLIENT._mqttc.disconnect.called) - - def test_mqtt_subscribes_topics_on_connect(self): - """Test subscription to topic on connect.""" - from collections import OrderedDict - prev_topics = OrderedDict() - prev_topics['topic/test'] = 1, - prev_topics['home/sensor'] = 2, - prev_topics['still/pending'] = None - - mqtt.MQTT_CLIENT.topics = prev_topics - mqtt.MQTT_CLIENT.progress = {1: 'still/pending'} - # Return values for subscribe calls (rc, mid) - mqtt.MQTT_CLIENT._mqttc.subscribe.side_effect = ((0, 2), (0, 3)) - mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0) - self.assertFalse(mqtt.MQTT_CLIENT._mqttc.disconnect.called) - - expected = [(topic, qos) for topic, qos in prev_topics.items() - if qos is not None] - self.assertEqual( - expected, - [call[1] for call in mqtt.MQTT_CLIENT._mqttc.subscribe.mock_calls]) - self.assertEqual({ - 1: 'still/pending', - 2: 'topic/test', - 3: 'home/sensor', - }, mqtt.MQTT_CLIENT.progress) - - def test_mqtt_birth_message_on_connect(self): \ - # pylint: disable=no-self-use - """Test birth message on connect.""" - mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0) - mqtt.MQTT_CLIENT._mqttc.publish.assert_called_with('birth', 'birth', 0, - False) + self.hass.data['mqtt']._mqttc = mock.MagicMock() + self.hass.data['mqtt']._mqtt_on_connect( + None, {'topics': {}}, 0, result_code) + self.assertTrue(self.hass.data['mqtt']._mqttc.disconnect.called) def test_mqtt_disconnect_tries_no_reconnect_on_stop(self): """Test the disconnect tries.""" - mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 0) - self.assertFalse(mqtt.MQTT_CLIENT._mqttc.reconnect.called) + self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0) + self.assertFalse(self.hass.data['mqtt']._mqttc.reconnect.called) @mock.patch('homeassistant.components.mqtt.time.sleep') def test_mqtt_disconnect_tries_reconnect(self, mock_sleep): """Test the re-connect tries.""" - mqtt.MQTT_CLIENT.topics = { + self.hass.data['mqtt'].topics = { 'test/topic': 1, 'test/progress': None } - mqtt.MQTT_CLIENT.progress = { + self.hass.data['mqtt'].progress = { 1: 'test/progress' } - mqtt.MQTT_CLIENT._mqttc.reconnect.side_effect = [1, 1, 1, 0] - mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 1) - self.assertTrue(mqtt.MQTT_CLIENT._mqttc.reconnect.called) - self.assertEqual(4, len(mqtt.MQTT_CLIENT._mqttc.reconnect.mock_calls)) + self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0] + self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1) + self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called) + self.assertEqual( + 4, len(self.hass.data['mqtt']._mqttc.reconnect.mock_calls)) self.assertEqual([1, 2, 4], [call[1][0] for call in mock_sleep.mock_calls]) - self.assertEqual({'test/topic': 1}, mqtt.MQTT_CLIENT.topics) - self.assertEqual({}, mqtt.MQTT_CLIENT.progress) + self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics) + self.assertEqual({}, self.hass.data['mqtt'].progress) def test_invalid_mqtt_topics(self): """Test invalid topics.""" @@ -356,7 +343,7 @@ class TestMQTTCallbacks(unittest.TestCase): MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) message = MQTTMessage(topic, 1, payload) with self.assertLogs(level='ERROR') as test_handle: - mqtt.MQTT_CLIENT._mqtt_on_message( + self.hass.data['mqtt']._mqtt_on_message( None, {'hass': self.hass}, message) @@ -365,3 +352,47 @@ class TestMQTTCallbacks(unittest.TestCase): "ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode " "payload from MQTT topic: %s, Payload: " % topic, test_handle.output[0]) + + +@asyncio.coroutine +def test_birth_message(hass): + """Test sending birth message.""" + mqtt_client = yield from mock_mqtt_client(hass, { + mqtt.CONF_BROKER: 'mock-broker', + mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth', + mqtt.ATTR_PAYLOAD: 'birth'} + }) + calls = [] + mqtt_client.publish = lambda *args: calls.append(args) + hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0) + yield from hass.async_block_till_done() + assert calls[-1] == ('birth', 'birth', 0, False) + + +@asyncio.coroutine +def test_mqtt_subscribes_topics_on_connect(hass): + """Test subscription to topic on connect.""" + mqtt_client = yield from mock_mqtt_client(hass) + + prev_topics = OrderedDict() + prev_topics['topic/test'] = 1, + prev_topics['home/sensor'] = 2, + prev_topics['still/pending'] = None + + hass.data['mqtt'].topics = prev_topics + hass.data['mqtt'].progress = {1: 'still/pending'} + + # Return values for subscribe calls (rc, mid) + mqtt_client.subscribe.side_effect = ((0, 2), (0, 3)) + + hass.add_job = mock.MagicMock() + hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0) + + yield from hass.async_block_till_done() + + assert not mqtt_client.disconnect.called + + expected = [(topic, qos) for topic, qos in prev_topics.items() + if qos is not None] + + assert [call[1][1:] for call in hass.add_job.mock_calls] == expected diff --git a/tests/components/mqtt/test_server.py b/tests/components/mqtt/test_server.py index ceb648c6ef57..f017d257a266 100644 --- a/tests/components/mqtt/test_server.py +++ b/tests/components/mqtt/test_server.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, MagicMock, patch from homeassistant.bootstrap import setup_component import homeassistant.components.mqtt as mqtt -from tests.common import get_test_home_assistant +from tests.common import get_test_home_assistant, mock_coro class TestMQTT: @@ -21,9 +21,8 @@ class TestMQTT: @patch('passlib.apps.custom_app_context', Mock(return_value='')) @patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock())) - @patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe', - Mock(return_value=MagicMock())) @patch('hbmqtt.broker.Broker', Mock(return_value=MagicMock())) + @patch('hbmqtt.broker.Broker.start', Mock(return_value=mock_coro())) @patch('homeassistant.components.mqtt.MQTT') def test_creating_config_with_http_pass(self, mock_mqtt): """Test if the MQTT server gets started and subscribe/publish msg.""" @@ -46,7 +45,7 @@ class TestMQTT: assert mock_mqtt.mock_calls[0][1][6] is None @patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock())) - @patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe') + @patch('hbmqtt.broker.Broker.start', return_value=mock_coro()) def test_broker_config_fails(self, mock_run): """Test if the MQTT component fails if server fails.""" from hbmqtt.broker import BrokerException diff --git a/tests/components/switch/test_mqtt.py b/tests/components/switch/test_mqtt.py index f6f9ffa0f86f..3a5502c81502 100644 --- a/tests/components/switch/test_mqtt.py +++ b/tests/components/switch/test_mqtt.py @@ -72,7 +72,7 @@ class TestSensorMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'beer on', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('switch.test') self.assertEqual(STATE_ON, state.state) @@ -80,7 +80,7 @@ class TestSensorMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(('command-topic', 'beer off', 2, False), - self.mock_publish.mock_calls[-1][1]) + self.mock_publish.mock_calls[-2][1]) state = self.hass.states.get('switch.test') self.assertEqual(STATE_OFF, state.state)