Run pipeline from audio stream function (#90748)

* Run pipeline from audio stream function

* Fix tests

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2023-04-04 00:06:51 -04:00 committed by GitHub
parent 4f1574b859
commit 6e4c78686e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 383 additions and 158 deletions

View file

@@ -1,12 +1,33 @@
"""The Voice Assistant integration."""
from __future__ import annotations
from homeassistant.core import HomeAssistant
from collections.abc import AsyncIterable
from homeassistant.components import stt
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .error import PipelineNotFound
from .pipeline import (
PipelineEvent,
PipelineEventCallback,
PipelineEventType,
PipelineInput,
PipelineRun,
PipelineStage,
async_get_pipeline,
)
from .websocket_api import async_register_websocket_api
__all__ = (
"DOMAIN",
"async_setup",
"async_pipeline_from_audio_stream",
"PipelineEvent",
"PipelineEventType",
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Voice Assistant integration."""
@ -14,3 +35,55 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_register_websocket_api(hass)
return True
async def async_pipeline_from_audio_stream(
    hass: HomeAssistant,
    event_callback: PipelineEventCallback,
    stt_metadata: stt.SpeechMetadata,
    stt_stream: AsyncIterable[bytes],
    language: str | None = None,
    pipeline_id: str | None = None,
    conversation_id: str | None = None,
    context: Context | None = None,
) -> None:
    """Run the voice-assistant pipeline over an incoming audio stream.

    Events produced by the run are delivered through event_callback.
    Raises PipelineNotFound when no pipeline matches pipeline_id/language.
    """
    # Default to the Home Assistant instance language when none was given.
    if language is None:
        language = hass.config.language

    # Temporary workaround for language codes
    if language == "en":
        language = "en-US"

    # Fill in the STT language only when the caller left it blank.
    if stt_metadata.language == "":
        stt_metadata.language = language

    run_context = context if context is not None else Context()

    pipeline = async_get_pipeline(
        hass,
        pipeline_id=pipeline_id,
        language=language,
    )
    if pipeline is None:
        raise PipelineNotFound(
            "pipeline_not_found", f"Pipeline {pipeline_id} not found"
        )

    # Run the full pipeline: speech-to-text through text-to-speech.
    run = PipelineRun(
        hass,
        context=run_context,
        pipeline=pipeline,
        start_stage=PipelineStage.STT,
        end_stage=PipelineStage.TTS,
        event_callback=event_callback,
    )
    pipeline_input = PipelineInput(
        conversation_id=conversation_id,
        stt_metadata=stt_metadata,
        stt_stream=stt_stream,
        run=run,
    )
    await pipeline_input.validate()
    await pipeline_input.execute()

View file

@@ -0,0 +1,30 @@
"""Voice Assistant errors."""
from homeassistant.exceptions import HomeAssistantError
class PipelineError(HomeAssistantError):
    """Base class for pipeline errors."""

    def __init__(self, code: str, message: str) -> None:
        """Record the machine-readable code and human-readable message."""
        self.code = code
        self.message = message
        details = f"Pipeline error code={code}, message={message}"
        super().__init__(details)
# Raised when no pipeline matches the requested id/language
# (see the None check after async_get_pipeline in __init__.py).
class PipelineNotFound(PipelineError):
    """Unspecified pipeline picked."""
# Imported by pipeline.py for failures in the STT stage.
class SpeechToTextError(PipelineError):
    """Error in speech to text portion of pipeline."""
# Imported by pipeline.py for failures in the intent stage.
class IntentRecognitionError(PipelineError):
    """Error in intent recognition portion of pipeline."""
# Imported by pipeline.py for failures in the TTS stage.
class TextToSpeechError(PipelineError):
    """Error in text to speech portion of pipeline."""

View file

@ -16,6 +16,12 @@ from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.util.dt import utcnow
from .const import DOMAIN
from .error import (
IntentRecognitionError,
PipelineError,
SpeechToTextError,
TextToSpeechError,
)
_LOGGER = logging.getLogger(__name__)
@ -39,29 +45,6 @@ def async_get_pipeline(
)
class PipelineError(Exception):
    """Base class for pipeline errors."""

    def __init__(self, code: str, message: str) -> None:
        """Store the code and message, and format the exception text."""
        self.code = code
        self.message = message
        text = f"Pipeline error code={code}, message={message}"
        super().__init__(text)
# Raised for failures in the STT stage of a pipeline run.
class SpeechToTextError(PipelineError):
    """Error in speech to text portion of pipeline."""
# Raised for failures in the intent-recognition stage of a pipeline run.
class IntentRecognitionError(PipelineError):
    """Error in intent recognition portion of pipeline."""
# Raised for failures in the TTS stage of a pipeline run.
class TextToSpeechError(PipelineError):
    """Error in text to speech portion of pipeline."""
class PipelineEventType(StrEnum):
"""Event types emitted during a pipeline run."""
@ -93,6 +76,9 @@ class PipelineEvent:
}
# Signature of callbacks invoked with each PipelineEvent emitted during a run.
PipelineEventCallback = Callable[[PipelineEvent], None]
@dataclass
class Pipeline:
"""A voice assistant pipeline."""
@ -146,7 +132,7 @@ class PipelineRun:
pipeline: Pipeline
start_stage: PipelineStage
end_stage: PipelineStage
event_callback: Callable[[PipelineEvent], None]
event_callback: PipelineEventCallback
language: str = None # type: ignore[assignment]
runner_data: Any | None = None
stt_provider: stt.Provider | None = None

View file

@ -1268,7 +1268,7 @@ def mock_integration(
def mock_import_platform(platform_name: str) -> NoReturn:
    """Always raise ImportError, simulating a platform that cannot be imported.

    The fully-qualified platform path is used both in the message and as the
    ImportError ``name`` so callers can identify which import "failed".
    """
    # Only one positional message argument: the duplicated plain-name message
    # line was stale diff residue and has been dropped.
    raise ImportError(
        f"Mocked unable to import platform '{integration.pkg_path}.{platform_name}'",
        name=f"{integration.pkg_path}.{platform_name}",
    )

View file

@@ -0,0 +1,139 @@
"""Test fixtures for voice assistant."""
from collections.abc import AsyncIterable
from typing import Any
from unittest.mock import AsyncMock, Mock
import pytest
from homeassistant.components import stt, tts
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockModule, mock_integration, mock_platform
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
_TRANSCRIPT = "test transcript"
class MockSttProvider(stt.Provider):
    """Mock STT provider that records audio and returns a canned transcript."""

    def __init__(self, hass: HomeAssistant, text: str) -> None:
        """Init test provider."""
        self.hass = hass
        self.text = text
        # Audio chunks seen by async_process_audio_stream, for test assertions.
        self.received: list[bytes] = []

    @property
    def supported_languages(self) -> list[str]:
        """Languages this provider accepts."""
        return ["en-US"]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Audio container formats this provider accepts."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Audio codecs this provider accepts."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Bit rates this provider accepts."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Sample rates this provider accepts."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Channel counts this provider accepts."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Record chunks until the empty terminator, then return the canned text."""
        async for chunk in stream:
            if not chunk:
                break
            self.received.append(chunk)
        return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockTTSProvider(tts.Provider):
    """Mock TTS provider returning empty MP3 audio for any input."""

    name = "Test"

    @property
    def default_language(self) -> str:
        """Language used when the caller requests none."""
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        """Languages this provider can synthesize."""
        return ["en-US"]

    @property
    def supported_options(self) -> list[str]:
        """Provider options (voice, age) callers may pass."""
        return ["voice", "age"]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> tts.TtsAudioType:
        """Return an empty MP3 payload regardless of the message."""
        return ("mp3", b"")
class MockTTS:
    """A mock TTS platform."""

    PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> tts.Provider:
        """Return a fresh MockTTSProvider; config and discovery_info are unused."""
        return MockTTSProvider()
@pytest.fixture
async def mock_stt_provider(hass) -> MockSttProvider:
    """Provide a MockSttProvider wired to the canned test transcript."""
    return MockSttProvider(hass, _TRANSCRIPT)
@pytest.fixture(autouse=True)
async def init_components(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_get_cache_files,  # noqa: F811
    mock_init_cache_dir,  # noqa: F811
):
    """Initialize relevant components with empty configs."""
    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS())
    # The STT platform hands back the shared fixture provider instance.
    mock_platform(
        hass,
        "test.stt",
        Mock(async_get_engine=AsyncMock(return_value=mock_stt_provider)),
    )
    # Set up each required component, failing fast if any refuses to load.
    for domain, config in (
        (tts.DOMAIN, {"tts": {"platform": "test"}}),
        (stt.DOMAIN, {"stt": {"platform": "test"}}),
        ("media_source", {}),
        ("voice_assistant", {}),
    ):
        assert await async_setup_component(hass, domain, config)

View file

@ -0,0 +1,85 @@
# serializer version: 1
# name: test_pipeline_from_audio_stream
list([
dict({
'data': dict({
'language': 'en-US',
'pipeline': 'en-US',
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'engine': 'test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'language': 'en-US',
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
}),
'type': <PipelineEventType.STT_START: 'stt-start'>,
}),
dict({
'data': dict({
'stt_output': dict({
'text': 'test transcript',
}),
}),
'type': <PipelineEventType.STT_END: 'stt-end'>,
}),
dict({
'data': dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': dict({
'engine': 'test',
'tts_input': "Sorry, I couldn't understand that",
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,
}),
dict({
'data': dict({
}),
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---

View file

@ -0,0 +1,42 @@
"""Test Voice Assistant init."""
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import stt, voice_assistant
from homeassistant.core import HomeAssistant
async def test_pipeline_from_audio_stream(
    hass: HomeAssistant, mock_stt_provider, snapshot: SnapshotAssertion
) -> None:
    """Test creating a pipeline from an audio stream."""
    captured_events = []

    async def audio_data():
        # Two audio chunks followed by the empty terminator chunk.
        yield b"part1"
        yield b"part2"
        yield b""

    await voice_assistant.async_pipeline_from_audio_stream(
        hass,
        captured_events.append,
        stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        audio_data(),
    )

    # Strip timestamps before comparing against the stored snapshot.
    processed = [
        {key: value for key, value in event.as_dict().items() if key != "timestamp"}
        for event in captured_events
    ]
    assert processed == snapshot

    # The mock provider saw both non-empty chunks (terminator excluded).
    assert mock_stt_provider.received == [b"part1", b"part2"]

View file

@ -1,143 +1,13 @@
"""Websocket tests for Voice Assistant integration."""
import asyncio
from collections.abc import AsyncIterable
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import stt, tts
from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import MockModule, mock_integration, mock_platform
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
mock_get_cache_files,
mock_init_cache_dir,
)
from tests.typing import WebSocketGenerator
_TRANSCRIPT = "test transcript"
class MockSttProvider(stt.Provider):
    """Mock STT provider."""

    def __init__(self, hass: HomeAssistant, text: str) -> None:
        """Init test provider."""
        self.hass = hass
        # Canned transcript returned for every processed audio stream.
        self.text = text

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return ["en-US"]

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return a list of supported formats."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return a list of supported codecs."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return a list of supported bitrates."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return a list of supported channels."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Process an audio stream."""
        # Ignores the audio entirely; immediately reports self.text as a success.
        return stt.SpeechResult(self.text, stt.SpeechResultState.SUCCESS)
class MockSTT:
    """A mock STT platform."""

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> stt.Provider:
        """Return a MockSttProvider preloaded with the canned transcript."""
        return MockSttProvider(hass, _TRANSCRIPT)
class MockTTSProvider(tts.Provider):
    """Mock TTS provider."""

    name = "Test"

    @property
    def default_language(self) -> str:
        """Return the default language."""
        return "en"

    @property
    def supported_languages(self) -> list[str]:
        """Return list of supported languages."""
        return ["en-US"]

    @property
    def supported_options(self) -> list[str]:
        """Return list of supported options like voice, emotions."""
        return ["voice", "age"]

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any] | None = None
    ) -> tts.TtsAudioType:
        """Load TTS data."""
        # Returns an empty MP3 payload regardless of the message.
        return ("mp3", b"")
class MockTTS:
    """A mock TTS platform."""

    PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> tts.Provider:
        """Return a fresh MockTTSProvider; config and discovery_info are unused."""
        return MockTTSProvider()
@pytest.fixture(autouse=True)
async def init_components(
    hass: HomeAssistant,
    mock_get_cache_files,  # noqa: F811
    mock_init_cache_dir,  # noqa: F811
):
    """Initialize relevant components with empty configs."""
    mock_integration(hass, MockModule(domain="test"))
    mock_platform(hass, "test.tts", MockTTS())
    mock_platform(hass, "test.stt", MockSTT())
    # Set up each required component, failing fast if any refuses to load.
    for domain, config in (
        (tts.DOMAIN, {"tts": {"platform": "test"}}),
        (stt.DOMAIN, {"stt": {"platform": "test"}}),
        ("media_source", {}),
        ("voice_assistant", {}),
    ):
        assert await async_setup_component(hass, domain, config)
async def test_text_only_pipeline(
hass: HomeAssistant,
@ -211,7 +81,7 @@ async def test_audio_pipeline(
assert msg["event"]["data"] == snapshot
# End of audio stream (handler id + empty payload)
await client.send_bytes(b"1")
await client.send_bytes(bytes([1]))
msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end"
@ -438,7 +308,7 @@ async def test_stt_stream_failed(
) -> None:
"""Test events from a pipeline run with a non-existent STT provider."""
with patch(
"tests.components.voice_assistant.test_websocket.MockSttProvider.async_process_audio_stream",
"tests.components.voice_assistant.conftest.MockSttProvider.async_process_audio_stream",
new=MagicMock(side_effect=RuntimeError),
):
client = await hass_ws_client(hass)