diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index f55ab9cbe4a..ac1558b8aa0 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio from asyncio import Event -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Coroutine from pathlib import Path from typing import Any from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -19,6 +19,8 @@ from aioesphomeapi import ( HomeassistantServiceCall, ReconnectLogic, UserService, + VoiceAssistantAudioSettings, + VoiceAssistantEventType, VoiceAssistantFeature, ) import pytest @@ -32,6 +34,11 @@ from homeassistant.components.esphome.const import ( DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS, DOMAIN, ) +from homeassistant.components.esphome.entry_data import RuntimeEntryData +from homeassistant.components.esphome.voice_assistant import ( + VoiceAssistantAPIPipeline, + VoiceAssistantUDPPipeline, +) from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -40,6 +47,8 @@ from . import DASHBOARD_HOST, DASHBOARD_PORT, DASHBOARD_SLUG from tests.common import MockConfigEntry +_ONE_SECOND = 16000 * 2 # 16Khz 16-bit + @pytest.fixture(autouse=True) def mock_bluetooth(enable_bluetooth: None) -> None: @@ -196,6 +205,20 @@ class MockESPHomeDevice: self.home_assistant_state_subscription_callback: Callable[ [str, str | None], None ] + self.voice_assistant_handle_start_callback: Callable[ + [str, int, VoiceAssistantAudioSettings, str | None], + Coroutine[Any, Any, int | None], + ] + self.voice_assistant_handle_stop_callback: Callable[ + [], Coroutine[Any, Any, None] + ] + self.voice_assistant_handle_audio_callback: ( + Callable[ + [bytes], + Coroutine[Any, Any, None], + ] + | None + ) self.device_info = device_info def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None: @@ -255,6 +278,47 @@ class MockESPHomeDevice: """Mock a state subscription.""" self.home_assistant_state_subscription_callback(entity_id, attribute) + def set_subscribe_voice_assistant_callbacks( + self, + handle_start: Callable[ + [str, int, VoiceAssistantAudioSettings, str | None], + Coroutine[Any, Any, int | None], + ], + handle_stop: Callable[[], Coroutine[Any, Any, None]], + handle_audio: ( + Callable[ + [bytes], + Coroutine[Any, Any, None], + ] + | None + ) = None, + ) -> None: + """Set the voice assistant subscription callbacks.""" + self.voice_assistant_handle_start_callback = handle_start + self.voice_assistant_handle_stop_callback = handle_stop + self.voice_assistant_handle_audio_callback = handle_audio + + async def mock_voice_assistant_handle_start( + self, + conversation_id: str, + flags: int, + settings: VoiceAssistantAudioSettings, + wake_word_phrase: str | None, + ) -> int | None: + """Mock voice assistant handle start.""" + return await self.voice_assistant_handle_start_callback( + conversation_id, flags, settings, wake_word_phrase + ) + + async def mock_voice_assistant_handle_stop(self) -> None: + """Mock voice assistant handle stop.""" + await self.voice_assistant_handle_stop_callback() + + async def mock_voice_assistant_handle_audio(self, audio: bytes) -> None: + """Mock voice assistant handle audio.""" + assert self.voice_assistant_handle_audio_callback is not None + await self.voice_assistant_handle_audio_callback(audio) + async def _mock_generic_device_entry( hass: HomeAssistant, @@ -318,8 +382,33 @@ async def _mock_generic_device_entry( """Subscribe to home assistant states.""" mock_device.set_home_assistant_state_subscription_callback(on_state_sub) + def _subscribe_voice_assistant( + *, + handle_start: Callable[ + [str, int, VoiceAssistantAudioSettings, str | None], + Coroutine[Any, Any, int | None], + ], + handle_stop: Callable[[], Coroutine[Any, Any, None]], + handle_audio: ( + Callable[ + [bytes], + Coroutine[Any, Any, None], + ] + | None + ) = None, + ) -> Callable[[], None]: + """Subscribe to voice assistant.""" + mock_device.set_subscribe_voice_assistant_callbacks( + handle_start, handle_stop, handle_audio + ) + + def unsub(): + pass + + return unsub + mock_client.device_info = AsyncMock(return_value=mock_device.device_info) - mock_client.subscribe_voice_assistant = Mock() + mock_client.subscribe_voice_assistant = _subscribe_voice_assistant mock_client.list_entities_services = AsyncMock( return_value=mock_list_entities_services ) @@ -524,3 +613,57 @@ async def mock_esphome_device( ) return _mock_device + + +@pytest.fixture +def mock_voice_assistant_api_pipeline() -> VoiceAssistantAPIPipeline: + """Return the API Pipeline factory.""" + mock_pipeline = Mock(spec=VoiceAssistantAPIPipeline) + + def mock_constructor( + hass: HomeAssistant, + entry_data: RuntimeEntryData, + handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], + handle_finished: Callable[[], None], + api_client: APIClient, + ): + """Fake the constructor.""" + mock_pipeline.hass = hass + mock_pipeline.entry_data = entry_data + mock_pipeline.handle_event = handle_event + mock_pipeline.handle_finished = handle_finished + mock_pipeline.api_client = api_client + return mock_pipeline + + mock_pipeline.side_effect = mock_constructor + with patch( + "homeassistant.components.esphome.voice_assistant.VoiceAssistantAPIPipeline", + new=mock_pipeline, + ): + yield mock_pipeline + + +@pytest.fixture +def mock_voice_assistant_udp_pipeline() -> VoiceAssistantUDPPipeline: + """Return the API Pipeline factory.""" + mock_pipeline = Mock(spec=VoiceAssistantUDPPipeline) + + def mock_constructor( + hass: HomeAssistant, + entry_data: RuntimeEntryData, + handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], + handle_finished: Callable[[], None], + ): + """Fake the constructor.""" + mock_pipeline.hass = hass + mock_pipeline.entry_data = entry_data + mock_pipeline.handle_event = handle_event + mock_pipeline.handle_finished = handle_finished + return mock_pipeline + + mock_pipeline.side_effect = mock_constructor + with patch( + "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPPipeline", + new=mock_pipeline, + ): + yield mock_pipeline diff --git a/tests/components/esphome/test_manager.py b/tests/components/esphome/test_manager.py index 92c21842e78..01f267581f4 100644 --- a/tests/components/esphome/test_manager.py +++ b/tests/components/esphome/test_manager.py @@ -2,7 +2,7 @@ import asyncio from collections.abc import Awaitable, Callable -from unittest.mock import AsyncMock, call +from unittest.mock import AsyncMock, call, patch from aioesphomeapi import ( APIClient, @@ -17,6 +17,7 @@ from aioesphomeapi import ( UserService, UserServiceArg, UserServiceArgType, + VoiceAssistantFeature, ) import pytest @@ -28,6 +29,10 @@ from homeassistant.components.esphome.const import ( DOMAIN, STABLE_BLE_VERSION_STR, ) +from homeassistant.components.esphome.voice_assistant import ( + VoiceAssistantAPIPipeline, + VoiceAssistantUDPPipeline, +) from homeassistant.const import ( CONF_HOST, CONF_PASSWORD, @@ -39,7 +44,7 @@ from homeassistant.data_entry_flow import FlowResultType from homeassistant.helpers import device_registry as dr, issue_registry as ir from homeassistant.setup import async_setup_component -from .conftest import MockESPHomeDevice +from .conftest import _ONE_SECOND, MockESPHomeDevice from tests.common import MockConfigEntry, async_capture_events, async_mock_service @@ -1181,3 +1186,102 @@ async def test_entry_missing_unique_id( await mock_esphome_device(mock_client=mock_client, mock_storage=True) await hass.async_block_till_done() assert entry.unique_id == "11:22:33:44:55:aa" + + +async def test_manager_voice_assistant_handlers_api( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], + caplog: pytest.LogCaptureFixture, + mock_voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, +) -> None: + """Test the handlers are correctly executed in manager.py.""" + + device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.API_AUDIO + }, + ) + + await hass.async_block_till_done() + + with ( + patch( + "homeassistant.components.esphome.manager.VoiceAssistantAPIPipeline", + new=mock_voice_assistant_api_pipeline, + ), + ): + port: int | None = await device.mock_voice_assistant_handle_start( + "", 0, None, None + ) + + assert port == 0 + + port: int | None = await device.mock_voice_assistant_handle_start( + "", 0, None, None + ) + + assert "Voice assistant UDP server was not stopped" in caplog.text + + await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND)) + + mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_called_with( + bytes(_ONE_SECOND) + ) + + mock_voice_assistant_api_pipeline.receive_audio_bytes.reset_mock() + + await device.mock_voice_assistant_handle_stop() + mock_voice_assistant_api_pipeline.handle_finished() + + await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND)) + + mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_not_called() + + +async def test_manager_voice_assistant_handlers_udp( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], + mock_voice_assistant_udp_pipeline: VoiceAssistantUDPPipeline, +) -> None: + """Test the handlers are correctly executed in manager.py.""" + + device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + }, + ) + + await hass.async_block_till_done() + + with ( + patch( + "homeassistant.components.esphome.manager.VoiceAssistantUDPPipeline", + new=mock_voice_assistant_udp_pipeline, + ), + ): + await device.mock_voice_assistant_handle_start("", 0, None, None) + + mock_voice_assistant_udp_pipeline.run_pipeline.assert_called() + + await device.mock_voice_assistant_handle_stop() + mock_voice_assistant_udp_pipeline.handle_finished() + + mock_voice_assistant_udp_pipeline.stop.assert_called() + mock_voice_assistant_udp_pipeline.close.assert_called() diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index c347c3dc7d3..eafc0243dc6 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -37,15 +37,13 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers import intent as intent_helper import homeassistant.helpers.device_registry as dr -from .conftest import MockESPHomeDevice +from .conftest import _ONE_SECOND, MockESPHomeDevice _TEST_INPUT_TEXT = "This is an input test" _TEST_OUTPUT_TEXT = "This is an output test" _TEST_OUTPUT_URL = "output.mp3" _TEST_MEDIA_ID = "12345" -_ONE_SECOND = 16000 * 2 # 16Khz 16-bit - @pytest.fixture def voice_assistant_udp_pipeline( @@ -813,6 +811,7 @@ async def test_wake_word_abort_exception( async def test_timer_events( hass: HomeAssistant, + device_registry: dr.DeviceRegistry, mock_client: APIClient, mock_esphome_device: Callable[ [APIClient, list[EntityInfo], list[UserService], list[EntityState]], @@ -831,8 +830,8 @@ async def test_timer_events( | VoiceAssistantFeature.TIMERS }, ) - dev_reg = dr.async_get(hass) - dev = dev_reg.async_get_device( + await hass.async_block_till_done() + dev = device_registry.async_get_device( connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} ) @@ -886,6 +885,7 @@ async def test_timer_events( async def test_unknown_timer_event( hass: HomeAssistant, + device_registry: dr.DeviceRegistry, mock_client: APIClient, mock_esphome_device: Callable[ [APIClient, list[EntityInfo], list[UserService], list[EntityState]], @@ -904,8 +904,8 @@ async def test_unknown_timer_event( | VoiceAssistantFeature.TIMERS }, ) - dev_reg = dr.async_get(hass) - dev = dev_reg.async_get_device( + await hass.async_block_till_done() + dev = device_registry.async_get_device( connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} )