From fa65783f393068cd44da283fc5c6a76bf8a996b3 Mon Sep 17 00:00:00 2001 From: Gianluca Barbaro Date: Thu, 13 Apr 2017 16:38:09 +0200 Subject: [PATCH] MQTT: Managing binary payloads (#6976) * Managing binary payloads Hello, background: I wrote a HA camera component that gets the image from a binary payload. I'm testing it with Zanzito (https://play.google.com/store/apps/details?id=it.barbaro.zanzito) and it works apparently well: it gets the image and correctly displays it in the front-end. But I had to make the changes I'm proposing here: the message was being blocked because the utf-8 decoding failed. As far as I know, the utf-8 encoding is required for the topic, not for the payload. What I did here was try the utf-8 decoding, but even if unsuccessful, it dispatches the message anyway. Is there anything else I'm missing? thanks Gianluca * Update __init__.py * Update __init__.py * Update __init__.py * git test - ignore * Should work * minor fixes * updated mqtt/services.yaml * added two tests, modified threaded subscribe * removing polymer * requested changes * requested changes - minor fix * security wrap around payload_file_path * services.yaml updated * removed file publishing * minor fix --- homeassistant/components/mqtt/__init__.py | 57 ++++++++++++--------- tests/common.py | 5 +- tests/components/mqtt/test_init.py | 60 ++++++++++------------- 3 files changed, 64 insertions(+), 58 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 0e8c666d147f..2b6774939daa 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -201,7 +201,8 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): @asyncio.coroutine -def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): +def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, + encoding='utf-8'): """Subscribe to an MQTT topic.""" @callback def async_mqtt_topic_subscriber(dp_topic, dp_payload, dp_qos): @@ -209,7 +210,21 @@ def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): if not _match_topic(topic, dp_topic): return - hass.async_run_job(msg_callback, dp_topic, dp_payload, dp_qos) + if encoding is not None: + try: + payload = dp_payload.decode(encoding) + _LOGGER.debug("Received message on %s: %s", + dp_topic, payload) + except (AttributeError, UnicodeDecodeError): + _LOGGER.error("Illegal payload encoding %s from " + "MQTT topic: %s, Payload: %s", + encoding, dp_topic, dp_payload) + return + else: + _LOGGER.debug("Received binary message on %s", dp_topic) + payload = dp_payload + + hass.async_run_job(msg_callback, dp_topic, payload, dp_qos) async_remove = async_dispatcher_connect( hass, SIGNAL_MQTT_MESSAGE_RECEIVED, async_mqtt_topic_subscriber) @@ -218,10 +233,12 @@ def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): return async_remove -def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS): +def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, + encoding='utf-8'): """Subscribe to an MQTT topic.""" async_remove = run_coroutine_threadsafe( - async_subscribe(hass, topic, msg_callback, qos), + async_subscribe(hass, topic, msg_callback, + qos, encoding), hass.loop ).result() @@ -372,16 +389,16 @@ def async_setup(hass, config): payload_template = call.data.get(ATTR_PAYLOAD_TEMPLATE) qos = call.data[ATTR_QOS] retain = call.data[ATTR_RETAIN] - try: - if payload_template is not None: + if payload_template is not None: + try: payload = \ template.Template(payload_template, hass).async_render() - except template.jinja2.TemplateError as exc: - _LOGGER.error( - "Unable to publish to '%s': rendering payload template of " - "'%s' failed because %s", - msg_topic, payload_template, exc) - return + except template.jinja2.TemplateError as exc: + _LOGGER.error( + "Unable to publish to '%s': rendering payload template of " + "'%s' failed because %s", + msg_topic, payload_template, exc) + return yield from hass.data[DATA_MQTT].async_publish( msg_topic, payload, qos, retain) @@ -564,18 +581,10 @@ class MQTT(object): def _mqtt_on_message(self, _mqttc, _userdata, msg): """Message received callback.""" - try: - payload = msg.payload.decode('utf-8') - except (AttributeError, UnicodeDecodeError): - _LOGGER.error("Illegal utf-8 unicode payload from " - "MQTT topic: %s, Payload: %s", msg.topic, - msg.payload) - else: - _LOGGER.info("Received message on %s: %s", msg.topic, payload) - dispatcher_send( - self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, payload, - msg.qos - ) + dispatcher_send( + self.hass, SIGNAL_MQTT_MESSAGE_RECEIVED, msg.topic, msg.payload, + msg.qos + ) def _mqtt_on_unsubscribe(self, _mqttc, _userdata, mid, granted_qos): """Unsubscribe successful callback.""" diff --git a/tests/common.py b/tests/common.py index 03a4de235d7c..a66273448790 100644 --- a/tests/common.py +++ b/tests/common.py @@ -170,8 +170,11 @@ def mock_service(hass, domain, service): @ha.callback def async_fire_mqtt_message(hass, topic, payload, qos=0): """Fire the MQTT message.""" + if isinstance(payload, str): + payload = payload.encode('utf-8') async_dispatcher_send( - hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, topic, payload, qos) + hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, topic, + payload, qos) def fire_mqtt_message(hass, topic, payload, qos=0): diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index f387c7c0bd73..0017674e82f5 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -209,6 +209,31 @@ class TestMQTT(unittest.TestCase): self.hass.block_till_done() self.assertEqual(0, len(self.calls)) + def test_subscribe_binary_topic(self): + """Test the subscription to a binary topic.""" + mqtt.subscribe(self.hass, 'test-topic', self.record_calls, + 0, None) + + fire_mqtt_message(self.hass, 'test-topic', 0x9a) + + self.hass.block_till_done() + self.assertEqual(1, len(self.calls)) + self.assertEqual('test-topic', self.calls[0][0]) + self.assertEqual(0x9a, self.calls[0][1]) + + def test_receiving_non_utf8_message_gets_logged(self): + """Test receiving a non utf8 encoded message.""" + mqtt.subscribe(self.hass, 'test-topic', self.record_calls) + + with self.assertLogs(level='ERROR') as test_handle: + fire_mqtt_message(self.hass, 'test-topic', 0x9a) + self.hass.block_till_done() + self.assertIn( + "ERROR:homeassistant.components.mqtt:Illegal payload " + "encoding utf-8 from MQTT " + "topic: test-topic, Payload: 154", + test_handle.output[0]) + class TestMQTTCallbacks(unittest.TestCase): """Test the MQTT callbacks.""" @@ -255,7 +280,8 @@ class TestMQTTCallbacks(unittest.TestCase): self.assertEqual(1, len(calls)) last_event = calls[0] - self.assertEqual('Hello World!', last_event['payload']) + self.assertEqual(bytearray('Hello World!', 'utf-8'), + last_event['payload']) self.assertEqual(message.topic, last_event['topic']) self.assertEqual(message.qos, last_event['qos']) @@ -298,38 +324,6 @@ class TestMQTTCallbacks(unittest.TestCase): self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic') self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one') - def test_receiving_non_utf8_message_gets_logged(self): - """Test receiving a non utf8 encoded message.""" - calls = [] - - @callback - def record(topic, payload, qos): - """Helper to record calls.""" - data = { - 'topic': topic, - 'payload': payload, - 'qos': qos, - } - calls.append(data) - - async_dispatcher_connect( - self.hass, mqtt.SIGNAL_MQTT_MESSAGE_RECEIVED, record) - - payload = 0x9a - topic = 'test_topic' - MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload']) - message = MQTTMessage(topic, 1, payload) - with self.assertLogs(level='ERROR') as test_handle: - self.hass.data['mqtt']._mqtt_on_message( - None, - {'hass': self.hass}, - message) - self.hass.block_till_done() - self.assertIn( - "ERROR:homeassistant.components.mqtt:Illegal utf-8 unicode " - "payload from MQTT topic: %s, Payload: " % topic, - test_handle.output[0]) - @asyncio.coroutine def test_setup_embedded_starts_with_no_config(hass):