Add api to device_automation to return all matching devices (#53361)

This commit is contained in:
J. Nick Koston 2021-08-10 14:21:34 -05:00 committed by GitHub
parent ac29571db3
commit 4bde4504ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 48 deletions

View file

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
from collections.abc import Iterable, Mapping
from functools import wraps
from types import ModuleType
from typing import Any
@ -13,9 +13,12 @@ import voluptuous_serialize
from homeassistant.components import websocket_api
from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_registry import async_entries_for_device
from homeassistant.loader import IntegrationNotFound
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.loader import IntegrationNotFound, bind_hass
from homeassistant.requirements import async_get_integration_with_requirements
from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig
@ -49,6 +52,16 @@ TYPES = {
}
@bind_hass
async def async_get_device_automations(
hass: HomeAssistant,
automation_type: str,
device_ids: Iterable[str] | None = None,
) -> Mapping[str, Any]:
"""Return all the device automations for a type optionally limited to specific device ids."""
return await _async_get_device_automations(hass, automation_type, device_ids)
async def async_setup(hass, config):
"""Set up device automation."""
hass.components.websocket_api.async_register_command(
@ -96,7 +109,7 @@ async def async_get_device_automation_platform(
async def _async_get_device_automations_from_domain(
hass, domain, automation_type, device_id
hass, domain, automation_type, device_ids, return_exceptions
):
"""List device automations."""
try:
@ -104,48 +117,67 @@ async def _async_get_device_automations_from_domain(
hass, domain, automation_type
)
except InvalidDeviceAutomationConfig:
return None
return {}
function_name = TYPES[automation_type][1]
return await getattr(platform, function_name)(hass, device_id)
async def _async_get_device_automations(hass, automation_type, device_id):
"""List device automations."""
device_registry, entity_registry = await asyncio.gather(
hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(),
return await asyncio.gather(
*(
getattr(platform, function_name)(hass, device_id)
for device_id in device_ids
),
return_exceptions=return_exceptions,
)
domains = set()
automations: list[MutableMapping[str, Any]] = []
device = device_registry.async_get(device_id)
if device is None:
raise DeviceNotFound
async def _async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None
) -> Mapping[str, list[dict[str, Any]]]:
"""List device automations."""
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
domain_devices: dict[str, set[str]] = {}
device_entities_domains: dict[str, set[str]] = {}
match_device_ids = set(device_ids or device_registry.devices)
combined_results: dict[str, list[dict[str, Any]]] = {}
for entry_id in device.config_entries:
config_entry = hass.config_entries.async_get_entry(entry_id)
domains.add(config_entry.domain)
for entry in entity_registry.entities.values():
if not entry.disabled_by and entry.device_id in match_device_ids:
device_entities_domains.setdefault(entry.device_id, set()).add(entry.domain)
entity_entries = async_entries_for_device(entity_registry, device_id)
for entity_entry in entity_entries:
domains.add(entity_entry.domain)
for device_id in match_device_ids:
combined_results[device_id] = []
device = device_registry.async_get(device_id)
if device is None:
raise DeviceNotFound
for entry_id in device.config_entries:
if config_entry := hass.config_entries.async_get_entry(entry_id):
domain_devices.setdefault(config_entry.domain, set()).add(device_id)
for domain in device_entities_domains.get(device_id, []):
domain_devices.setdefault(domain, set()).add(device_id)
device_automations = await asyncio.gather(
# If specific device ids were requested, we allow
# InvalidDeviceAutomationConfig to be thrown, otherwise we skip
# devices that do not have valid triggers
return_exceptions = not bool(device_ids)
for domain_results in await asyncio.gather(
*(
_async_get_device_automations_from_domain(
hass, domain, automation_type, device_id
hass, domain, automation_type, domain_device_ids, return_exceptions
)
for domain in domains
for domain, domain_device_ids in domain_devices.items()
)
)
for device_automation in device_automations:
if device_automation is not None:
automations.extend(device_automation)
):
for device_results in domain_results:
if device_results is None or isinstance(
device_results, InvalidDeviceAutomationConfig
):
continue
for automation in device_results:
combined_results[automation["device_id"]].append(automation)
return automations
return combined_results
async def _async_get_device_automation_capabilities(hass, automation_type, automation):
@ -207,7 +239,9 @@ def handle_device_errors(func):
async def websocket_device_automation_list_actions(hass, connection, msg):
"""Handle request for device actions."""
device_id = msg["device_id"]
actions = await _async_get_device_automations(hass, "action", device_id)
actions = (await _async_get_device_automations(hass, "action", [device_id])).get(
device_id
)
connection.send_result(msg["id"], actions)
@ -222,7 +256,9 @@ async def websocket_device_automation_list_actions(hass, connection, msg):
async def websocket_device_automation_list_conditions(hass, connection, msg):
"""Handle request for device conditions."""
device_id = msg["device_id"]
conditions = await _async_get_device_automations(hass, "condition", device_id)
conditions = (
await _async_get_device_automations(hass, "condition", [device_id])
).get(device_id)
connection.send_result(msg["id"], conditions)
@ -237,7 +273,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg):
async def websocket_device_automation_list_triggers(hass, connection, msg):
"""Handle request for device triggers."""
device_id = msg["device_id"]
triggers = await _async_get_device_automations(hass, "trigger", device_id)
triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get(
device_id
)
connection.send_result(msg["id"], triggers)

View file

@ -29,10 +29,9 @@ from homeassistant.auth import (
providers as auth_providers,
)
from homeassistant.auth.permissions import system_policies
from homeassistant.components import recorder
from homeassistant.components import device_automation, recorder
from homeassistant.components.device_automation import ( # noqa: F401
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.config import async_process_component_config
@ -69,6 +68,16 @@ CLIENT_ID = "https://example.com/app"
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
async def async_get_device_automations(
hass: HomeAssistant, automation_type: str, device_id: str
) -> Any:
"""Get a device automation for a single device id."""
automations = await device_automation.async_get_device_automations(
hass, automation_type, [device_id]
)
return automations.get(device_id)
def threadsafe_callback_factory(func):
"""Create threadsafe functions out of callbacks.

View file

@ -1,6 +1,7 @@
"""The test for light device automation."""
import pytest
from homeassistant.components import device_automation
import homeassistant.components.automation as automation
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
@ -372,6 +373,76 @@ async def test_websocket_get_no_condition_capabilities(
assert capabilities == expected_capabilities
async def test_async_get_device_automations_single_device_trigger(
hass, device_reg, entity_reg
):
"""Test we get can fetch the triggers for a device id."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(
hass, "trigger", [device_entry.id]
)
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
async def test_async_get_device_automations_all_devices_trigger(
hass, device_reg, entity_reg
):
"""Test we get can fetch all the triggers when no device id is passed."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "trigger")
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
async def test_async_get_device_automations_all_devices_condition(
hass, device_reg, entity_reg
):
"""Test we get can fetch all the conditions when no device id is passed."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "condition")
assert device_entry.id in result
assert len(result[device_entry.id]) == 2
async def test_async_get_device_automations_all_devices_action(
hass, device_reg, entity_reg
):
"""Test we get can fetch all the actions when no device id is passed."""
await async_setup_component(hass, "device_automation", {})
config_entry = MockConfigEntry(domain="test", data={})
config_entry.add_to_hass(hass)
device_entry = device_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)
entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id)
result = await device_automation.async_get_device_automations(hass, "action")
assert device_entry.id in result
assert len(result[device_entry.id]) == 3
async def test_websocket_get_trigger_capabilities(
hass, hass_ws_client, device_reg, entity_reg
):

View file

@ -2,9 +2,6 @@
import pytest
import homeassistant.components.automation as automation
from homeassistant.components.device_automation import (
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.remote import DOMAIN
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
from homeassistant.helpers import device_registry
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
async_get_device_automations,
async_mock_service,
mock_device_registry,
mock_registry,

View file

@ -2,9 +2,6 @@
import pytest
import homeassistant.components.automation as automation
from homeassistant.components.device_automation import (
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.switch import DOMAIN
from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON
from homeassistant.helpers import device_registry
@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
async_get_device_automations,
async_mock_service,
mock_device_registry,
mock_registry,

View file

@ -8,14 +8,11 @@ import zigpy.zcl.clusters.security as security
import zigpy.zcl.foundation as zcl_f
import homeassistant.components.automation as automation
from homeassistant.components.device_automation import (
_async_get_device_automations as async_get_device_automations,
)
from homeassistant.components.zha import DOMAIN
from homeassistant.helpers import device_registry as dr
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service, mock_coro
from tests.common import async_get_device_automations, async_mock_service, mock_coro
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
SHORT_PRESS = "remote_button_short_press"