MQTT convert to async (#6064)

* Migrate mqtt to async

* address paulus comment / convert it complet async

* adress paulus comment / remove future

* Automation triggers should be async

* Fix MQTT async calls

* Show that event helpers are callbacks

* Fix tests

* Lint
This commit is contained in:
Pascal Vizeli 2017-02-18 23:17:18 +01:00 committed by Paulus Schoutsen
parent fa2c1dafdf
commit e1cbd6b4c0
25 changed files with 356 additions and 231 deletions

View file

@ -412,7 +412,7 @@ def _async_process_trigger(hass, config, trigger_configs, name, action):
if platform is None:
return None
remove = platform.async_trigger(hass, conf, action)
remove = yield from platform.async_trigger(hass, conf, action)
if not remove:
_LOGGER.error("Error setting up trigger %s", name)

View file

@ -4,6 +4,7 @@ Offer event listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#event-trigger
"""
import asyncio
import logging
import voluptuous as vol
@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({
})
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for events based on configuration."""
event_type = config.get(CONF_EVENT_TYPE)

View file

@ -4,6 +4,7 @@ Trigger an automation when a LiteJet switch is released.
For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/automation.litejet/
"""
import asyncio
import logging
import voluptuous as vol
@ -32,6 +33,7 @@ TRIGGER_SCHEMA = vol.Schema({
})
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for events based on configuration."""
number = config.get(CONF_NUMBER)

View file

@ -4,6 +4,7 @@ Offer MQTT listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#mqtt-trigger
"""
import asyncio
import json
import voluptuous as vol
@ -24,6 +25,7 @@ TRIGGER_SCHEMA = vol.Schema({
})
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for state changes based on configuration."""
topic = config.get(CONF_TOPIC)
@ -49,4 +51,6 @@ def async_trigger(hass, config, action):
'trigger': data
})
return mqtt.async_subscribe(hass, topic, mqtt_automation_listener)
remove = yield from mqtt.async_subscribe(
hass, topic, mqtt_automation_listener)
return remove

View file

@ -4,6 +4,7 @@ Offer numeric state listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#numeric-state-trigger
"""
import asyncio
import logging
import voluptuous as vol
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
_LOGGER = logging.getLogger(__name__)
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID)

View file

@ -4,6 +4,7 @@ Offer state listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#state-trigger
"""
import asyncio
import voluptuous as vol
from homeassistant.core import callback
@ -34,6 +35,7 @@ TRIGGER_SCHEMA = vol.All(
)
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID)
@ -97,6 +99,7 @@ def async_trigger(hass, config, action):
unsub = async_track_state_change(
hass, entity_id, state_automation_listener, from_state, to_state)
@callback
def async_remove():
"""Remove state listeners async."""
unsub()

View file

@ -4,6 +4,7 @@ Offer sun based automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#sun-trigger
"""
import asyncio
from datetime import timedelta
import logging
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({
})
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for events based on configuration."""
event = config.get(CONF_EVENT)

View file

@ -4,6 +4,7 @@ Offer template automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#template-trigger
"""
import asyncio
import logging
import voluptuous as vol
@ -22,6 +23,7 @@ TRIGGER_SCHEMA = IF_ACTION_SCHEMA = vol.Schema({
})
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for state changes based on configuration."""
value_template = config.get(CONF_VALUE_TEMPLATE)

View file

@ -4,6 +4,7 @@ Offer time listening automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#time-trigger
"""
import asyncio
import logging
import voluptuous as vol
@ -29,6 +30,7 @@ TRIGGER_SCHEMA = vol.All(vol.Schema({
CONF_SECONDS, CONF_AFTER))
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for state changes based on configuration."""
if CONF_AFTER in config:

View file

@ -4,6 +4,7 @@ Offer zone automation rules.
For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#zone-trigger
"""
import asyncio
import voluptuous as vol
from homeassistant.core import callback
@ -26,6 +27,7 @@ TRIGGER_SCHEMA = vol.Schema({
})
@asyncio.coroutine
def async_trigger(hass, config, action):
"""Listen for state changes based on configuration."""
entity_id = config.get(CONF_ENTITY_ID)

View file

@ -4,6 +4,7 @@ Support for MQTT message handling.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/
"""
import asyncio
import logging
import os
import socket
@ -12,11 +13,12 @@ import time
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.bootstrap import async_prepare_setup_platform
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import template, config_validation as cv
from homeassistant.helpers.event import threaded_listener_factory
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, CONF_VALUE_TEMPLATE,
CONF_USERNAME, CONF_PASSWORD, CONF_PORT, CONF_PROTOCOL, CONF_PAYLOAD)
@ -26,7 +28,7 @@ _LOGGER = logging.getLogger(__name__)
DOMAIN = 'mqtt'
MQTT_CLIENT = None
DATA_MQTT = 'mqtt'
SERVICE_PUBLISH = 'publish'
EVENT_MQTT_MESSAGE_RECEIVED = 'mqtt_message_received'
@ -183,11 +185,11 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
@callback
@asyncio.coroutine
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
@callback
def mqtt_topic_subscriber(event):
def async_mqtt_topic_subscriber(event):
"""Match subscribed MQTT topic."""
if not _match_topic(topic, event.data[ATTR_TOPIC]):
return
@ -195,61 +197,82 @@ def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
hass.async_run_job(msg_callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS])
async_remove = hass.bus.async_listen(EVENT_MQTT_MESSAGE_RECEIVED,
mqtt_topic_subscriber)
# Future: track subscriber count and unsubscribe in remove
MQTT_CLIENT.subscribe(topic, qos)
async_remove = hass.bus.async_listen(
EVENT_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber)
yield from hass.data[DATA_MQTT].async_subscribe(topic, qos)
return async_remove
# pylint: disable=invalid-name
subscribe = threaded_listener_factory(async_subscribe)
def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic."""
async_remove = run_coroutine_threadsafe(
async_subscribe(hass, topic, msg_callback, qos),
hass.loop
).result()
def remove():
"""Remove listener convert."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove
def _setup_server(hass, config):
"""Try to start embedded MQTT broker."""
@asyncio.coroutine
def _async_setup_server(hass, config):
"""Try to start embedded MQTT broker.
This method is a coroutine.
"""
conf = config.get(DOMAIN, {})
# Only setup if embedded config passed in or no broker specified
if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
return None
server = prepare_setup_platform(hass, config, DOMAIN, 'server')
server = yield from async_prepare_setup_platform(
hass, config, DOMAIN, 'server')
if server is None:
_LOGGER.error("Unable to load embedded server")
return None
success, broker_config = server.start(hass, conf.get(CONF_EMBEDDED))
success, broker_config = \
yield from server.async_start(hass, conf.get(CONF_EMBEDDED))
return success and broker_config
def _setup_discovery(hass, config):
"""Try to start the discovery of MQTT devices."""
@asyncio.coroutine
def _async_setup_discovery(hass, config):
"""Try to start the discovery of MQTT devices.
This method is a coroutine.
"""
conf = config.get(DOMAIN, {})
discovery = prepare_setup_platform(hass, config, DOMAIN, 'discovery')
discovery = yield from async_prepare_setup_platform(
hass, config, DOMAIN, 'discovery')
if discovery is None:
_LOGGER.error("Unable to load MQTT discovery")
return None
success = discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config)
success = yield from discovery.async_start(
hass, conf[CONF_DISCOVERY_PREFIX], config)
return success
def setup(hass, config):
@asyncio.coroutine
def async_setup(hass, config):
"""Start the MQTT protocol service."""
conf = config.get(DOMAIN, {})
client_id = conf.get(CONF_CLIENT_ID)
keepalive = conf.get(CONF_KEEPALIVE)
broker_config = _setup_server(hass, config)
broker_config = yield from _async_setup_server(hass, config)
if CONF_BROKER in conf:
broker = conf[CONF_BROKER]
@ -283,27 +306,31 @@ def setup(hass, config):
will_message = conf.get(CONF_WILL_MESSAGE)
birth_message = conf.get(CONF_BIRTH_MESSAGE)
global MQTT_CLIENT
try:
MQTT_CLIENT = MQTT(hass, broker, port, client_id, keepalive,
username, password, certificate, client_key,
client_cert, tls_insecure, protocol, will_message,
birth_message)
hass.data[DATA_MQTT] = MQTT(
hass, broker, port, client_id, keepalive, username, password,
certificate, client_key, client_cert, tls_insecure, protocol,
will_message, birth_message)
except socket.error:
_LOGGER.exception("Can't connect to the broker. "
"Please check your settings and the broker itself")
return False
def stop_mqtt(event):
@asyncio.coroutine
def async_stop_mqtt(event):
"""Stop MQTT component."""
MQTT_CLIENT.stop()
yield from hass.data[DATA_MQTT].async_stop()
def start_mqtt(event):
@asyncio.coroutine
def async_start_mqtt(event):
"""Launch MQTT component when Home Assistant starts up."""
MQTT_CLIENT.start()
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_mqtt)
yield from hass.data[DATA_MQTT].async_start()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
def publish_service(call):
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, async_start_mqtt)
@asyncio.coroutine
def async_publish_service(call):
"""Handle MQTT publish service calls."""
msg_topic = call.data[ATTR_TOPIC]
payload = call.data.get(ATTR_PAYLOAD)
@ -312,26 +339,28 @@ def setup(hass, config):
retain = call.data[ATTR_RETAIN]
try:
if payload_template is not None:
payload = template.Template(payload_template, hass).render()
payload = \
template.Template(payload_template, hass).async_render()
except template.jinja2.TemplateError as exc:
_LOGGER.error(
"Unable to publish to '%s': rendering payload template of "
"'%s' failed because %s",
msg_topic, payload_template, exc)
return
MQTT_CLIENT.publish(msg_topic, payload, qos, retain)
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_mqtt)
yield from hass.data[DATA_MQTT].async_publish(
msg_topic, payload, qos, retain)
descriptions = load_yaml_config_file(
os.path.join(os.path.dirname(__file__), 'services.yaml'))
descriptions = yield from hass.loop.run_in_executor(
None, load_yaml_config_file, os.path.join(
os.path.dirname(__file__), 'services.yaml'))
hass.services.register(DOMAIN, SERVICE_PUBLISH, publish_service,
descriptions.get(SERVICE_PUBLISH),
schema=MQTT_PUBLISH_SCHEMA)
hass.services.async_register(
DOMAIN, SERVICE_PUBLISH, async_publish_service,
descriptions.get(SERVICE_PUBLISH), schema=MQTT_PUBLISH_SCHEMA)
if conf.get(CONF_DISCOVERY):
_setup_discovery(hass, config)
yield from _async_setup_discovery(hass, config)
return True
@ -349,6 +378,7 @@ class MQTT(object):
self.topics = {}
self.progress = {}
self.birth_message = birth_message
self._mqttc = None
if protocol == PROTOCOL_31:
proto = mqtt.MQTTv31
@ -364,8 +394,8 @@ class MQTT(object):
self._mqttc.username_pw_set(username, password)
if certificate is not None:
self._mqttc.tls_set(certificate, certfile=client_cert,
keyfile=client_key)
self._mqttc.tls_set(
certificate, certfile=client_cert, keyfile=client_key)
if tls_insecure is not None:
self._mqttc.tls_insecure_set(tls_insecure)
@ -375,40 +405,69 @@ class MQTT(object):
self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message
if will_message:
self._mqttc.will_set(will_message.get(ATTR_TOPIC),
will_message.get(ATTR_PAYLOAD),
will_message.get(ATTR_QOS),
will_message.get(ATTR_RETAIN))
self._mqttc.connect(broker, port, keepalive)
def publish(self, topic, payload, qos, retain):
"""Publish a MQTT message."""
self._mqttc.publish(topic, payload, qos, retain)
self._mqttc.connect_async(broker, port, keepalive)
def start(self):
"""Run the MQTT client."""
self._mqttc.loop_start()
def async_publish(self, topic, payload, qos, retain):
"""Publish a MQTT message.
def stop(self):
"""Stop the MQTT client."""
self._mqttc.disconnect()
self._mqttc.loop_stop()
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(
None, self._mqttc.publish, topic, payload, qos, retain)
def subscribe(self, topic, qos):
"""Subscribe to a topic."""
assert isinstance(topic, str)
def async_start(self):
"""Run the MQTT client.
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(None, self._mqttc.loop_start)
def async_stop(self):
"""Stop the MQTT client.
This method must be run in the event loop and returns a coroutine.
"""
def stop(self):
"""Stop the MQTT client."""
self._mqttc.disconnect()
self._mqttc.loop_stop()
return self.hass.loop.run_in_executor(None, stop)
@asyncio.coroutine
def async_subscribe(self, topic, qos):
"""Subscribe to a topic.
This method is a coroutine.
"""
if not isinstance(topic, str):
raise HomeAssistantError("topic need to be a string!")
if topic in self.topics:
return
result, mid = self._mqttc.subscribe(topic, qos)
result, mid = yield from self.hass.loop.run_in_executor(
None, self._mqttc.subscribe, topic, qos)
_raise_on_error(result)
self.progress[mid] = topic
self.topics[topic] = None
def unsubscribe(self, topic):
"""Unsubscribe from topic."""
result, mid = self._mqttc.unsubscribe(topic)
@asyncio.coroutine
def async_unsubscribe(self, topic):
"""Unsubscribe from topic.
This method is a coroutine.
"""
result, mid = yield from self.hass.loop.run_in_executor(
None, self._mqttc.unsubscribe, topic)
_raise_on_error(result)
self.progress[mid] = topic
@ -437,12 +496,14 @@ class MQTT(object):
for topic, qos in old_topics.items():
# qos is None if we were in process of subscribing
if qos is not None:
self.subscribe(topic, qos)
self.hass.add_job(self.async_subscribe, topic, qos)
if self.birth_message:
self.publish(self.birth_message.get(ATTR_TOPIC),
self.birth_message.get(ATTR_PAYLOAD),
self.birth_message.get(ATTR_QOS),
self.birth_message.get(ATTR_RETAIN))
self.hass.add_job(self.async_publish(
self.birth_message.get(ATTR_TOPIC),
self.birth_message.get(ATTR_PAYLOAD),
self.birth_message.get(ATTR_QOS),
self.birth_message.get(ATTR_RETAIN)))
def _mqtt_on_subscribe(self, _mqttc, _userdata, mid, granted_qos):
"""Subscribe successful callback."""

View file

@ -9,7 +9,6 @@ import json
import logging
import re
from homeassistant.core import callback
import homeassistant.components.mqtt as mqtt
from homeassistant.components.mqtt import DOMAIN
from homeassistant.helpers.discovery import async_load_platform
@ -24,7 +23,7 @@ TOPIC_MATCHER = re.compile(
SUPPORTED_COMPONENTS = ['binary_sensor', 'sensor']
@callback
@asyncio.coroutine
def async_start(hass, discovery_topic, hass_config):
"""Initialization of MQTT Discovery."""
@asyncio.coroutine
@ -56,7 +55,7 @@ def async_start(hass, discovery_topic, hass_config):
yield from async_load_platform(
hass, component, DOMAIN, payload, hass_config)
mqtt.async_subscribe(hass, discovery_topic + '/#',
async_device_message_received, 0)
yield from mqtt.async_subscribe(
hass, discovery_topic + '/#', async_device_message_received, 0)
return True

View file

@ -4,15 +4,14 @@ Support for a local MQTT broker.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/#use-the-embedded-broker
"""
import asyncio
import logging
import tempfile
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_coroutine_threadsafe
REQUIREMENTS = ['hbmqtt==0.8']
DEPENDENCIES = ['http']
@ -29,8 +28,12 @@ HBMQTT_CONFIG_SCHEMA = vol.Any(None, vol.Schema({
}, extra=vol.ALLOW_EXTRA))
def start(hass, server_config):
"""Initialize MQTT Server."""
@asyncio.coroutine
def async_start(hass, server_config):
"""Initialize MQTT Server.
This method is a coroutine.
"""
from hbmqtt.broker import Broker, BrokerException
try:
@ -42,19 +45,20 @@ def start(hass, server_config):
client_config = None
broker = Broker(server_config, hass.loop)
run_coroutine_threadsafe(broker.start(), hass.loop).result()
yield from broker.start()
except BrokerException:
logging.getLogger(__name__).exception('Error initializing MQTT server')
return False, None
finally:
passwd.close()
@callback
def shutdown_mqtt_server(event):
@asyncio.coroutine
def async_shutdown_mqtt_server(event):
"""Shut down the MQTT server."""
hass.async_add_job(broker.shutdown())
yield from broker.shutdown()
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, shutdown_mqtt_server)
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, async_shutdown_mqtt_server)
return True, client_config

View file

@ -34,6 +34,7 @@ def threaded_listener_factory(async_factory):
return factory
@callback
def async_track_state_change(hass, entity_ids, action, from_state=None,
to_state=None):
"""Track specific state changes.
@ -84,6 +85,7 @@ def async_track_state_change(hass, entity_ids, action, from_state=None,
track_state_change = threaded_listener_factory(async_track_state_change)
@callback
def async_track_template(hass, template, action, variables=None):
"""Add a listener that track state changes with template condition."""
from . import condition
@ -111,6 +113,7 @@ def async_track_template(hass, template, action, variables=None):
track_template = threaded_listener_factory(async_track_template)
@callback
def async_track_point_in_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in time."""
utc_point_in_time = dt_util.as_utc(point_in_time)
@ -127,6 +130,7 @@ def async_track_point_in_time(hass, action, point_in_time):
track_point_in_time = threaded_listener_factory(async_track_point_in_time)
@callback
def async_track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC
@ -160,6 +164,7 @@ track_point_in_utc_time = threaded_listener_factory(
async_track_point_in_utc_time)
@callback
def async_track_time_interval(hass, action, interval):
"""Add a listener that fires repetitively at every timedelta interval."""
remove = None
@ -189,6 +194,7 @@ def async_track_time_interval(hass, action, interval):
track_time_interval = threaded_listener_factory(async_track_time_interval)
@callback
def async_track_sunrise(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunrise daily."""
from homeassistant.components import sun
@ -225,6 +231,7 @@ def async_track_sunrise(hass, action, offset=None):
track_sunrise = threaded_listener_factory(async_track_sunrise)
@callback
def async_track_sunset(hass, action, offset=None):
"""Add a listener that will fire a specified offset from sunset daily."""
from homeassistant.components import sun
@ -261,6 +268,7 @@ def async_track_sunset(hass, action, offset=None):
track_sunset = threaded_listener_factory(async_track_sunset)
@callback
def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None,
local=False):
@ -305,6 +313,7 @@ def async_track_utc_time_change(hass, action, year=None, month=None, day=None,
track_utc_time_change = threaded_listener_factory(async_track_utc_time_change)
@callback
def async_track_time_change(hass, action, year=None, month=None, day=None,
hour=None, minute=None, second=None):
"""Add a listener that will fire if UTC time matches a pattern."""

View file

@ -86,7 +86,13 @@ def async_test_home_assistant(loop):
loop._thread_ident = threading.get_ident()
hass = ha.HomeAssistant(loop)
hass.async_track_tasks()
def async_add_job(target, *args):
if isinstance(target, MagicMock):
return
hass._async_add_job_tracking(target, *args)
hass.async_add_job = async_add_job
hass.config.location_name = 'test home'
hass.config.config_dir = get_test_config_dir()

View file

@ -111,7 +111,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_arm_home(self.hass)
self.hass.block_till_done()
self.assertEqual(('alarm/command', 'ARM_HOME', 0, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
def test_arm_home_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code."""
@ -146,7 +146,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_arm_away(self.hass)
self.hass.block_till_done()
self.assertEqual(('alarm/command', 'ARM_AWAY', 0, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
def test_arm_away_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code."""
@ -181,7 +181,7 @@ class TestAlarmControlPanelMQTT(unittest.TestCase):
alarm_control_panel.alarm_disarm(self.hass)
self.hass.block_till_done()
self.assertEqual(('alarm/command', 'DISARM', 0, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
def test_disarm_not_publishes_mqtt_with_invalid_code(self):
"""Test not publishing of MQTT messages with invalid code."""

View file

@ -118,7 +118,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'OPEN', 0, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test')
self.assertEqual(STATE_OPEN, state.state)
@ -126,7 +126,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'CLOSE', 0, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test')
self.assertEqual(STATE_CLOSED, state.state)
@ -150,7 +150,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'OPEN', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state)
@ -174,7 +174,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'CLOSE', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state)
@ -198,7 +198,7 @@ class TestCoverMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'STOP', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('cover.test')
self.assertEqual(STATE_UNKNOWN, state.state)

View file

@ -74,6 +74,7 @@ light:
"""
import unittest
from unittest import mock
from homeassistant.bootstrap import setup_component
from homeassistant.const import STATE_ON, STATE_OFF, ATTR_ASSUMED_STATE
@ -328,7 +329,7 @@ class TestLightMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'on', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state)
@ -336,27 +337,20 @@ class TestLightMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'off', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state)
self.mock_publish.reset_mock()
light.turn_on(self.hass, 'light.test', rgb_color=[75, 75, 75],
brightness=50)
self.hass.block_till_done()
# Calls are threaded so we need to reorder them
bright_call, rgb_call, state_call = \
sorted((call[1] for call in self.mock_publish.mock_calls[-3:]),
key=lambda call: call[0])
self.assertEqual(('test_light_rgb/set', 'on', 2, False),
state_call)
self.assertEqual(('test_light_rgb/rgb/set', '75,75,75', 2, False),
rgb_call)
self.assertEqual(('test_light_rgb/brightness/set', 50, 2, False),
bright_call)
self.mock_publish().async_publish.assert_has_calls([
mock.call('test_light_rgb/set', 'on', 2, False),
mock.call('test_light_rgb/rgb/set', '75,75,75', 2, False),
mock.call('test_light_rgb/brightness/set', 50, 2, False),
], any_order=True)
state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state)

View file

@ -172,7 +172,7 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', '{"state": "ON"}', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state)
@ -180,7 +180,7 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', '{"state": "OFF"}', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state)
@ -189,11 +189,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(50, message_json["brightness"])
self.assertEqual(75, message_json["color"]["r"])
self.assertEqual(75, message_json["color"]["g"])
@ -228,11 +228,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(5, message_json["flash"])
self.assertEqual("ON", message_json["state"])
@ -240,11 +240,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(15, message_json["flash"])
self.assertEqual("ON", message_json["state"])
@ -268,11 +268,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(10, message_json["transition"])
self.assertEqual("ON", message_json["state"])
@ -281,11 +281,11 @@ class TestLightMQTTJSON(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# Get the sent message
message_json = json.loads(self.mock_publish.mock_calls[-1][1][1])
message_json = json.loads(self.mock_publish.mock_calls[-2][1][1])
self.assertEqual(10, message_json["transition"])
self.assertEqual("OFF", message_json["state"])

View file

@ -196,7 +196,7 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'on,,--', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test')
self.assertEqual(STATE_ON, state.state)
@ -205,7 +205,7 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('test_light_rgb/set', 'off', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('light.test')
self.assertEqual(STATE_OFF, state.state)
@ -215,12 +215,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(2, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(2, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload
payload = self.mock_publish.mock_calls[-1][1][1]
payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,50,75-75-75', payload)
# check the state
@ -253,12 +253,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload
payload = self.mock_publish.mock_calls[-1][1][1]
payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,short', payload)
# long flash
@ -266,12 +266,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload
payload = self.mock_publish.mock_calls[-1][1][1]
payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,long', payload)
def test_transition(self):
@ -296,12 +296,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload
payload = self.mock_publish.mock_calls[-1][1][1]
payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('on,10', payload)
# transition off
@ -309,12 +309,12 @@ class TestLightMQTTTemplate(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual('test_light_rgb/set',
self.mock_publish.mock_calls[-1][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-1][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-1][1][3])
self.mock_publish.mock_calls[-2][1][0])
self.assertEqual(0, self.mock_publish.mock_calls[-2][1][2])
self.assertEqual(False, self.mock_publish.mock_calls[-2][1][3])
# check the payload
payload = self.mock_publish.mock_calls[-1][1][1]
payload = self.mock_publish.mock_calls[-2][1][1]
self.assertEqual('off,4', payload)
def test_invalid_values(self): \

View file

@ -73,7 +73,7 @@ class TestLockMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'LOCK', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('lock.test')
self.assertEqual(STATE_LOCKED, state.state)
@ -81,7 +81,7 @@ class TestLockMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'UNLOCK', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('lock.test')
self.assertEqual(STATE_UNLOCKED, state.state)

View file

@ -12,9 +12,10 @@ def test_subscribing_config_topic(hass, mqtt_mock):
"""Test setting up discovery."""
hass_config = {}
discovery_topic = 'homeassistant'
async_start(hass, discovery_topic, hass_config)
assert mqtt_mock.subscribe.called
call_args = mqtt_mock.subscribe.mock_calls[0][1]
yield from async_start(hass, discovery_topic, hass_config)
assert mqtt_mock.async_subscribe.called
call_args = mqtt_mock.async_subscribe.mock_calls[0][1]
assert call_args[0] == discovery_topic + '/#'
assert call_args[1] == 0
@ -24,7 +25,7 @@ def test_subscribing_config_topic(hass, mqtt_mock):
def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
"""Test sending in invalid JSON."""
mock_load_platform.return_value = mock_coro()
async_start(hass, 'homeassistant', {})
yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/not_config',
'{}')
@ -37,7 +38,7 @@ def test_invalid_topic(mock_load_platform, hass, mqtt_mock):
def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
"""Test sending in invalid JSON."""
mock_load_platform.return_value = mock_coro()
async_start(hass, 'homeassistant', {})
yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
'not json')
@ -51,7 +52,7 @@ def test_invalid_json(mock_load_platform, hass, mqtt_mock, caplog):
def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
"""Test sending in invalid JSON."""
mock_load_platform.return_value = mock_coro()
async_start(hass, 'homeassistant', {})
yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/climate/bla/config', '{}')
yield from hass.async_block_till_done()
@ -62,7 +63,7 @@ def test_only_valid_components(mock_load_platform, hass, mqtt_mock, caplog):
@asyncio.coroutine
def test_correct_config_discovery(hass, mqtt_mock, caplog):
"""Test sending in invalid JSON."""
async_start(hass, 'homeassistant', {})
yield from async_start(hass, 'homeassistant', {})
async_fire_mqtt_message(hass, 'homeassistant/binary_sensor/bla/config',
'{ "name": "Beer" }')

View file

@ -1,5 +1,6 @@
"""The tests for the MQTT component."""
from collections import namedtuple
import asyncio
from collections import namedtuple, OrderedDict
import unittest
from unittest import mock
import socket
@ -7,14 +8,29 @@ import socket
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.bootstrap import setup_component
from homeassistant.bootstrap import setup_component, async_setup_component
import homeassistant.components.mqtt as mqtt
from homeassistant.const import (
EVENT_CALL_SERVICE, ATTR_DOMAIN, ATTR_SERVICE, EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STOP)
from tests.common import (
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message)
get_test_home_assistant, mock_mqtt_component, fire_mqtt_message, mock_coro)
@asyncio.coroutine
def mock_mqtt_client(hass, config=None):
"""Mock the MQTT paho client."""
if config is None:
config = {
mqtt.CONF_BROKER: 'mock-broker'
}
with mock.patch('paho.mqtt.client.Client') as mock_client:
yield from async_setup_component(hass, mqtt.DOMAIN, {
mqtt.DOMAIN: config
})
return mock_client()
# pylint: disable=invalid-name
@ -40,7 +56,7 @@ class TestMQTT(unittest.TestCase):
""""Test if client start on HA launch."""
self.hass.bus.fire(EVENT_HOMEASSISTANT_START)
self.hass.block_till_done()
self.assertTrue(mqtt.MQTT_CLIENT.start.called)
self.assertTrue(self.hass.data['mqtt'].async_start.called)
def test_client_stops_on_home_assistant_start(self):
"""Test if client stops on HA launch."""
@ -48,7 +64,7 @@ class TestMQTT(unittest.TestCase):
self.hass.block_till_done()
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
self.hass.block_till_done()
self.assertTrue(mqtt.MQTT_CLIENT.stop.called)
self.assertTrue(self.hass.data['mqtt'].async_stop.called)
@mock.patch('paho.mqtt.client.Client')
def test_setup_fails_if_no_connect_broker(self, _):
@ -69,14 +85,17 @@ class TestMQTT(unittest.TestCase):
"""Test setting up embedded server with no config."""
client_config = ('localhost', 1883, 'user', 'pass', None, '3.1.1')
with mock.patch('homeassistant.components.mqtt.server.start',
return_value=(True, client_config)) as _start:
with mock.patch('homeassistant.components.mqtt.server.async_start',
return_value=mock_coro(
return_value=(True, client_config))
) as _start:
self.hass.config.components = set()
assert setup_component(self.hass, mqtt.DOMAIN,
{mqtt.DOMAIN: {}})
assert _start.call_count == 1
# Test with `embedded: None`
_start.return_value = mock_coro(return_value=(True, client_config))
self.hass.config.components = set()
assert setup_component(self.hass, mqtt.DOMAIN,
{mqtt.DOMAIN: {'embedded': None}})
@ -105,7 +124,7 @@ class TestMQTT(unittest.TestCase):
ATTR_SERVICE: mqtt.SERVICE_PUBLISH
})
self.hass.block_till_done()
self.assertTrue(not mqtt.MQTT_CLIENT.publish.called)
self.assertTrue(not self.hass.data['mqtt'].async_publish.called)
def test_service_call_with_template_payload_renders_template(self):
"""Test the service call with rendered template.
@ -114,8 +133,9 @@ class TestMQTT(unittest.TestCase):
"""
mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}")
self.hass.block_till_done()
self.assertTrue(mqtt.MQTT_CLIENT.publish.called)
self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][1], "2")
self.assertTrue(self.hass.data['mqtt'].async_publish.called)
self.assertEqual(
self.hass.data['mqtt'].async_publish.call_args[0][1], "2")
def test_service_call_with_payload_doesnt_render_template(self):
"""Test the service call with unrendered template.
@ -129,7 +149,7 @@ class TestMQTT(unittest.TestCase):
mqtt.ATTR_PAYLOAD: payload,
mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template
}, blocking=True)
self.assertFalse(mqtt.MQTT_CLIENT.publish.called)
self.assertFalse(self.hass.data['mqtt'].async_publish.called)
def test_service_call_with_ascii_qos_retain_flags(self):
"""Test the service call with args that can be misinterpreted.
@ -142,9 +162,10 @@ class TestMQTT(unittest.TestCase):
mqtt.ATTR_QOS: '2',
mqtt.ATTR_RETAIN: 'no'
}, blocking=True)
self.assertTrue(mqtt.MQTT_CLIENT.publish.called)
self.assertEqual(mqtt.MQTT_CLIENT.publish.call_args[0][2], 2)
self.assertFalse(mqtt.MQTT_CLIENT.publish.call_args[0][3])
self.assertTrue(self.hass.data['mqtt'].async_publish.called)
self.assertEqual(
self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
def test_subscribe_topic(self):
"""Test the subscription of a topic."""
@ -231,15 +252,12 @@ class TestMQTTCallbacks(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
# mock_mqtt_component(self.hass)
with mock.patch('paho.mqtt.client.Client'):
self.hass.config.components = set()
assert setup_component(self.hass, mqtt.DOMAIN, {
mqtt.DOMAIN: {
mqtt.CONF_BROKER: 'mock-broker',
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
mqtt.ATTR_PAYLOAD: 'birth'}
}
})
@ -261,7 +279,8 @@ class TestMQTTCallbacks(unittest.TestCase):
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
mqtt.MQTT_CLIENT._mqtt_on_message(None, {'hass': self.hass}, message)
self.hass.data['mqtt']._mqtt_on_message(
None, {'hass': self.hass}, message)
self.hass.block_till_done()
self.assertEqual(1, len(calls))
@ -273,68 +292,36 @@ class TestMQTTCallbacks(unittest.TestCase):
def test_mqtt_failed_connection_results_in_disconnect(self):
"""Test if connection failure leads to disconnect."""
for result_code in range(1, 6):
mqtt.MQTT_CLIENT._mqttc = mock.MagicMock()
mqtt.MQTT_CLIENT._mqtt_on_connect(None, {'topics': {}}, 0,
result_code)
self.assertTrue(mqtt.MQTT_CLIENT._mqttc.disconnect.called)
def test_mqtt_subscribes_topics_on_connect(self):
"""Test subscription to topic on connect."""
from collections import OrderedDict
prev_topics = OrderedDict()
prev_topics['topic/test'] = 1,
prev_topics['home/sensor'] = 2,
prev_topics['still/pending'] = None
mqtt.MQTT_CLIENT.topics = prev_topics
mqtt.MQTT_CLIENT.progress = {1: 'still/pending'}
# Return values for subscribe calls (rc, mid)
mqtt.MQTT_CLIENT._mqttc.subscribe.side_effect = ((0, 2), (0, 3))
mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0)
self.assertFalse(mqtt.MQTT_CLIENT._mqttc.disconnect.called)
expected = [(topic, qos) for topic, qos in prev_topics.items()
if qos is not None]
self.assertEqual(
expected,
[call[1] for call in mqtt.MQTT_CLIENT._mqttc.subscribe.mock_calls])
self.assertEqual({
1: 'still/pending',
2: 'topic/test',
3: 'home/sensor',
}, mqtt.MQTT_CLIENT.progress)
def test_mqtt_birth_message_on_connect(self): \
# pylint: disable=no-self-use
"""Test birth message on connect."""
mqtt.MQTT_CLIENT._mqtt_on_connect(None, None, 0, 0)
mqtt.MQTT_CLIENT._mqttc.publish.assert_called_with('birth', 'birth', 0,
False)
self.hass.data['mqtt']._mqttc = mock.MagicMock()
self.hass.data['mqtt']._mqtt_on_connect(
None, {'topics': {}}, 0, result_code)
self.assertTrue(self.hass.data['mqtt']._mqttc.disconnect.called)
def test_mqtt_disconnect_tries_no_reconnect_on_stop(self):
"""Test the disconnect tries."""
mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 0)
self.assertFalse(mqtt.MQTT_CLIENT._mqttc.reconnect.called)
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 0)
self.assertFalse(self.hass.data['mqtt']._mqttc.reconnect.called)
@mock.patch('homeassistant.components.mqtt.time.sleep')
def test_mqtt_disconnect_tries_reconnect(self, mock_sleep):
"""Test the re-connect tries."""
mqtt.MQTT_CLIENT.topics = {
self.hass.data['mqtt'].topics = {
'test/topic': 1,
'test/progress': None
}
mqtt.MQTT_CLIENT.progress = {
self.hass.data['mqtt'].progress = {
1: 'test/progress'
}
mqtt.MQTT_CLIENT._mqttc.reconnect.side_effect = [1, 1, 1, 0]
mqtt.MQTT_CLIENT._mqtt_on_disconnect(None, None, 1)
self.assertTrue(mqtt.MQTT_CLIENT._mqttc.reconnect.called)
self.assertEqual(4, len(mqtt.MQTT_CLIENT._mqttc.reconnect.mock_calls))
self.hass.data['mqtt']._mqttc.reconnect.side_effect = [1, 1, 1, 0]
self.hass.data['mqtt']._mqtt_on_disconnect(None, None, 1)
self.assertTrue(self.hass.data['mqtt']._mqttc.reconnect.called)
self.assertEqual(
4, len(self.hass.data['mqtt']._mqttc.reconnect.mock_calls))
self.assertEqual([1, 2, 4],
[call[1][0] for call in mock_sleep.mock_calls])
self.assertEqual({'test/topic': 1}, mqtt.MQTT_CLIENT.topics)
self.assertEqual({}, mqtt.MQTT_CLIENT.progress)
self.assertEqual({'test/topic': 1}, self.hass.data['mqtt'].topics)
self.assertEqual({}, self.hass.data['mqtt'].progress)
def test_invalid_mqtt_topics(self):
"""Test invalid topics."""
@ -356,7 +343,7 @@ class TestMQTTCallbacks(unittest.TestCase):
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage(topic, 1, payload)
with self.assertLogs(level='ERROR') as test_handle:
mqtt.MQTT_CLIENT._mqtt_on_message(
self.hass.data['mqtt']._mqtt_on_message(
None,
{'hass': self.hass},
message)
@ -365,3 +352,47 @@ class TestMQTTCallbacks(unittest.TestCase):
"ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode "
"payload from MQTT topic: %s, Payload: " % topic,
test_handle.output[0])
@asyncio.coroutine
def test_birth_message(hass):
"""Test sending birth message."""
mqtt_client = yield from mock_mqtt_client(hass, {
mqtt.CONF_BROKER: 'mock-broker',
mqtt.CONF_BIRTH_MESSAGE: {mqtt.ATTR_TOPIC: 'birth',
mqtt.ATTR_PAYLOAD: 'birth'}
})
calls = []
mqtt_client.publish = lambda *args: calls.append(args)
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done()
assert calls[-1] == ('birth', 'birth', 0, False)
@asyncio.coroutine
def test_mqtt_subscribes_topics_on_connect(hass):
"""Test subscription to topic on connect."""
mqtt_client = yield from mock_mqtt_client(hass)
prev_topics = OrderedDict()
prev_topics['topic/test'] = 1,
prev_topics['home/sensor'] = 2,
prev_topics['still/pending'] = None
hass.data['mqtt'].topics = prev_topics
hass.data['mqtt'].progress = {1: 'still/pending'}
# Return values for subscribe calls (rc, mid)
mqtt_client.subscribe.side_effect = ((0, 2), (0, 3))
hass.add_job = mock.MagicMock()
hass.data['mqtt']._mqtt_on_connect(None, None, 0, 0)
yield from hass.async_block_till_done()
assert not mqtt_client.disconnect.called
expected = [(topic, qos) for topic, qos in prev_topics.items()
if qos is not None]
assert [call[1][1:] for call in hass.add_job.mock_calls] == expected

View file

@ -4,7 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
from homeassistant.bootstrap import setup_component
import homeassistant.components.mqtt as mqtt
from tests.common import get_test_home_assistant
from tests.common import get_test_home_assistant, mock_coro
class TestMQTT:
@ -21,9 +21,8 @@ class TestMQTT:
@patch('passlib.apps.custom_app_context', Mock(return_value=''))
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
@patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe',
Mock(return_value=MagicMock()))
@patch('hbmqtt.broker.Broker', Mock(return_value=MagicMock()))
@patch('hbmqtt.broker.Broker.start', Mock(return_value=mock_coro()))
@patch('homeassistant.components.mqtt.MQTT')
def test_creating_config_with_http_pass(self, mock_mqtt):
"""Test if the MQTT server gets started and subscribe/publish msg."""
@ -46,7 +45,7 @@ class TestMQTT:
assert mock_mqtt.mock_calls[0][1][6] is None
@patch('tempfile.NamedTemporaryFile', Mock(return_value=MagicMock()))
@patch('homeassistant.components.mqtt.server.run_coroutine_threadsafe')
@patch('hbmqtt.broker.Broker.start', return_value=mock_coro())
def test_broker_config_fails(self, mock_run):
"""Test if the MQTT component fails if server fails."""
from hbmqtt.broker import BrokerException

View file

@ -72,7 +72,7 @@ class TestSensorMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'beer on', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('switch.test')
self.assertEqual(STATE_ON, state.state)
@ -80,7 +80,7 @@ class TestSensorMQTT(unittest.TestCase):
self.hass.block_till_done()
self.assertEqual(('command-topic', 'beer off', 2, False),
self.mock_publish.mock_calls[-1][1])
self.mock_publish.mock_calls[-2][1])
state = self.hass.states.get('switch.test')
self.assertEqual(STATE_OFF, state.state)