diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 93b0b9a4a9d9..945774da0b44 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import MutableMapping +from collections.abc import Iterable, Mapping from functools import wraps from types import ModuleType from typing import Any @@ -13,9 +13,12 @@ import voluptuous_serialize from homeassistant.components import websocket_api from homeassistant.const import CONF_DEVICE_ID, CONF_DOMAIN, CONF_PLATFORM from homeassistant.core import HomeAssistant -from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.entity_registry import async_entries_for_device -from homeassistant.loader import IntegrationNotFound +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) +from homeassistant.loader import IntegrationNotFound, bind_hass from homeassistant.requirements import async_get_integration_with_requirements from .exceptions import DeviceNotFound, InvalidDeviceAutomationConfig @@ -49,6 +52,16 @@ TYPES = { } +@bind_hass +async def async_get_device_automations( + hass: HomeAssistant, + automation_type: str, + device_ids: Iterable[str] | None = None, +) -> Mapping[str, Any]: + """Return all the device automations for a type optionally limited to specific device ids.""" + return await _async_get_device_automations(hass, automation_type, device_ids) + + async def async_setup(hass, config): """Set up device automation.""" hass.components.websocket_api.async_register_command( @@ -96,7 +109,7 @@ async def async_get_device_automation_platform( async def _async_get_device_automations_from_domain( - hass, domain, automation_type, device_id + hass, domain, automation_type, device_ids, return_exceptions ): """List device automations.""" try: @@ -104,48 +117,67 @@ async def _async_get_device_automations_from_domain( hass, domain, automation_type ) except InvalidDeviceAutomationConfig: - return None + return {} function_name = TYPES[automation_type][1] - return await getattr(platform, function_name)(hass, device_id) - - -async def _async_get_device_automations(hass, automation_type, device_id): - """List device automations.""" - device_registry, entity_registry = await asyncio.gather( - hass.helpers.device_registry.async_get_registry(), - hass.helpers.entity_registry.async_get_registry(), + return await asyncio.gather( + *( + getattr(platform, function_name)(hass, device_id) + for device_id in device_ids + ), + return_exceptions=return_exceptions, ) - domains = set() - automations: list[MutableMapping[str, Any]] = [] - device = device_registry.async_get(device_id) - if device is None: - raise DeviceNotFound +async def _async_get_device_automations( + hass: HomeAssistant, automation_type: str, device_ids: Iterable[str] | None +) -> Mapping[str, list[dict[str, Any]]]: + """List device automations.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + domain_devices: dict[str, set[str]] = {} + device_entities_domains: dict[str, set[str]] = {} + match_device_ids = set(device_ids or device_registry.devices) + combined_results: dict[str, list[dict[str, Any]]] = {} - for entry_id in device.config_entries: - config_entry = hass.config_entries.async_get_entry(entry_id) - domains.add(config_entry.domain) + for entry in entity_registry.entities.values(): + if not entry.disabled_by and entry.device_id in match_device_ids: + device_entities_domains.setdefault(entry.device_id, set()).add(entry.domain) - entity_entries = async_entries_for_device(entity_registry, device_id) - for entity_entry in entity_entries: - domains.add(entity_entry.domain) + for device_id in match_device_ids: + combined_results[device_id] = [] + device = device_registry.async_get(device_id) + if device is None: + raise DeviceNotFound + for entry_id in device.config_entries: + if config_entry := hass.config_entries.async_get_entry(entry_id): + domain_devices.setdefault(config_entry.domain, set()).add(device_id) + for domain in device_entities_domains.get(device_id, []): + domain_devices.setdefault(domain, set()).add(device_id) - device_automations = await asyncio.gather( + # If specific device ids were requested, we allow + # InvalidDeviceAutomationConfig to be thrown, otherwise we skip + # devices that do not have valid triggers + return_exceptions = not bool(device_ids) + + for domain_results in await asyncio.gather( *( _async_get_device_automations_from_domain( - hass, domain, automation_type, device_id + hass, domain, automation_type, domain_device_ids, return_exceptions ) - for domain in domains + for domain, domain_device_ids in domain_devices.items() ) - ) - for device_automation in device_automations: - if device_automation is not None: - automations.extend(device_automation) + ): + for device_results in domain_results: + if device_results is None or isinstance( + device_results, InvalidDeviceAutomationConfig + ): + continue + for automation in device_results: + combined_results[automation["device_id"]].append(automation) - return automations + return combined_results async def _async_get_device_automation_capabilities(hass, automation_type, automation): @@ -207,7 +239,9 @@ def handle_device_errors(func): async def websocket_device_automation_list_actions(hass, connection, msg): """Handle request for device actions.""" device_id = msg["device_id"] - actions = await _async_get_device_automations(hass, "action", device_id) + actions = (await _async_get_device_automations(hass, "action", [device_id])).get( + device_id + ) connection.send_result(msg["id"], actions) @@ -222,7 +256,9 @@ async def websocket_device_automation_list_actions(hass, connection, msg): async def websocket_device_automation_list_conditions(hass, connection, msg): """Handle request for device conditions.""" device_id = msg["device_id"] - conditions = await _async_get_device_automations(hass, "condition", device_id) + conditions = ( + await _async_get_device_automations(hass, "condition", [device_id]) + ).get(device_id) connection.send_result(msg["id"], conditions) @@ -237,7 +273,9 @@ async def websocket_device_automation_list_conditions(hass, connection, msg): async def websocket_device_automation_list_triggers(hass, connection, msg): """Handle request for device triggers.""" device_id = msg["device_id"] - triggers = await _async_get_device_automations(hass, "trigger", device_id) + triggers = (await _async_get_device_automations(hass, "trigger", [device_id])).get( + device_id + ) connection.send_result(msg["id"], triggers) diff --git a/tests/common.py b/tests/common.py index 5de58a08472c..3d5e28be5146 100644 --- a/tests/common.py +++ b/tests/common.py @@ -29,10 +29,9 @@ from homeassistant.auth import ( providers as auth_providers, ) from homeassistant.auth.permissions import system_policies -from homeassistant.components import recorder +from homeassistant.components import device_automation, recorder from homeassistant.components.device_automation import ( # noqa: F401 _async_get_device_automation_capabilities as async_get_device_automation_capabilities, - _async_get_device_automations as async_get_device_automations, ) from homeassistant.components.mqtt.models import ReceiveMessage from homeassistant.config import async_process_component_config @@ -69,6 +68,16 @@ CLIENT_ID = "https://example.com/app" CLIENT_REDIRECT_URI = "https://example.com/app/callback" +async def async_get_device_automations( + hass: HomeAssistant, automation_type: str, device_id: str +) -> Any: + """Get a device automation for a single device id.""" + automations = await device_automation.async_get_device_automations( + hass, automation_type, [device_id] + ) + return automations.get(device_id) + + def threadsafe_callback_factory(func): """Create threadsafe functions out of callbacks. diff --git a/tests/components/device_automation/test_init.py b/tests/components/device_automation/test_init.py index 160e6354b8be..13190ed4b329 100644 --- a/tests/components/device_automation/test_init.py +++ b/tests/components/device_automation/test_init.py @@ -1,6 +1,7 @@ """The test for light device automation.""" import pytest +from homeassistant.components import device_automation import homeassistant.components.automation as automation from homeassistant.components.websocket_api.const import TYPE_RESULT from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON @@ -372,6 +373,76 @@ async def test_websocket_get_no_condition_capabilities( assert capabilities == expected_capabilities +async def test_async_get_device_automations_single_device_trigger( + hass, device_reg, entity_reg +): + """Test we get can fetch the triggers for a device id.""" + await async_setup_component(hass, "device_automation", {}) + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) + result = await device_automation.async_get_device_automations( + hass, "trigger", [device_entry.id] + ) + assert device_entry.id in result + assert len(result[device_entry.id]) == 2 + + +async def test_async_get_device_automations_all_devices_trigger( + hass, device_reg, entity_reg +): + """Test we get can fetch all the triggers when no device id is passed.""" + await async_setup_component(hass, "device_automation", {}) + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) + result = await device_automation.async_get_device_automations(hass, "trigger") + assert device_entry.id in result + assert len(result[device_entry.id]) == 2 + + +async def test_async_get_device_automations_all_devices_condition( + hass, device_reg, entity_reg +): + """Test we get can fetch all the conditions when no device id is passed.""" + await async_setup_component(hass, "device_automation", {}) + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) + result = await device_automation.async_get_device_automations(hass, "condition") + assert device_entry.id in result + assert len(result[device_entry.id]) == 2 + + +async def test_async_get_device_automations_all_devices_action( + hass, device_reg, entity_reg +): + """Test we get can fetch all the actions when no device id is passed.""" + await async_setup_component(hass, "device_automation", {}) + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + entity_reg.async_get_or_create("light", "test", "5678", device_id=device_entry.id) + result = await device_automation.async_get_device_automations(hass, "action") + assert device_entry.id in result + assert len(result[device_entry.id]) == 3 + + async def test_websocket_get_trigger_capabilities( hass, hass_ws_client, device_reg, entity_reg ): diff --git a/tests/components/remote/test_device_action.py b/tests/components/remote/test_device_action.py index 1193764da3a8..48e741a12a48 100644 --- a/tests/components/remote/test_device_action.py +++ b/tests/components/remote/test_device_action.py @@ -2,9 +2,6 @@ import pytest import homeassistant.components.automation as automation -from homeassistant.components.device_automation import ( - _async_get_device_automations as async_get_device_automations, -) from homeassistant.components.remote import DOMAIN from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON from homeassistant.helpers import device_registry @@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component from tests.common import ( MockConfigEntry, + async_get_device_automations, async_mock_service, mock_device_registry, mock_registry, diff --git a/tests/components/switch/test_device_action.py b/tests/components/switch/test_device_action.py index 9f8d821e74b2..2ccfb26d3ef5 100644 --- a/tests/components/switch/test_device_action.py +++ b/tests/components/switch/test_device_action.py @@ -2,9 +2,6 @@ import pytest import homeassistant.components.automation as automation -from homeassistant.components.device_automation import ( - _async_get_device_automations as async_get_device_automations, -) from homeassistant.components.switch import DOMAIN from homeassistant.const import CONF_PLATFORM, STATE_OFF, STATE_ON from homeassistant.helpers import device_registry @@ -12,6 +9,7 @@ from homeassistant.setup import async_setup_component from tests.common import ( MockConfigEntry, + async_get_device_automations, async_mock_service, mock_device_registry, mock_registry, diff --git a/tests/components/zha/test_device_action.py b/tests/components/zha/test_device_action.py index 4a777fcebb61..49fa11de26cf 100644 --- a/tests/components/zha/test_device_action.py +++ b/tests/components/zha/test_device_action.py @@ -8,14 +8,11 @@ import zigpy.zcl.clusters.security as security import zigpy.zcl.foundation as zcl_f import homeassistant.components.automation as automation -from homeassistant.components.device_automation import ( - _async_get_device_automations as async_get_device_automations, -) from homeassistant.components.zha import DOMAIN from homeassistant.helpers import device_registry as dr from homeassistant.setup import async_setup_component -from tests.common import async_mock_service, mock_coro +from tests.common import async_get_device_automations, async_mock_service, mock_coro from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401 SHORT_PRESS = "remote_button_short_press"