Add cloud tts entity (#108293)

* Add cloud tts entity

* Test test_login_view_missing_entity

* Fix pipeline iteration for migration

* Update tests

* Make migration more strict

* Fix docstring
This commit is contained in:
Martin Hjelmare 2024-01-22 17:24:15 +01:00 committed by GitHub
parent d0da457a04
commit e086cd9fef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 428 additions and 102 deletions

View file

@ -65,7 +65,7 @@ from .subscription import async_subscription_info
DEFAULT_MODE = MODE_PROD
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT]
PLATFORMS = [Platform.BINARY_SENSOR, Platform.STT, Platform.TTS]
SERVICE_REMOTE_CONNECT = "remote_connect"
SERVICE_REMOTE_DISCONNECT = "remote_disconnect"
@ -288,9 +288,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
loaded = False
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
stt_tts_entities_added = asyncio.Event()
hass.data[DATA_PLATFORMS_SETUP] = {
Platform.STT: stt_platform_loaded,
Platform.TTS: tts_platform_loaded,
"stt_tts_entities_added": stt_tts_entities_added,
}
async def _on_start() -> None:
@ -330,6 +332,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
account_link.async_setup(hass)
# Load legacy tts platform for backwards compatibility.
hass.async_create_task(
async_load_platform(
hass,
@ -377,8 +380,10 @@ def _remote_handle_prefs_updated(cloud: Cloud[CloudClient]) -> None:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
stt_platform_loaded.set()
stt_tts_entities_added: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][
"stt_tts_entities_added"
]
stt_tts_entities_added.set()
return True

View file

@ -9,16 +9,23 @@ from homeassistant.components.assist_pipeline import (
)
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
from homeassistant.components.stt import DOMAIN as STT_DOMAIN
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
from .const import (
DATA_PLATFORMS_SETUP,
DOMAIN,
STT_ENTITY_UNIQUE_ID,
TTS_ENTITY_UNIQUE_ID,
)
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
"""Create a cloud assist pipeline."""
# Wait for stt and tts platforms to set up before creating the pipeline.
# Wait for stt and tts platforms to set up and entities to be added
# before creating the pipeline.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await asyncio.gather(*(event.wait() for event in platforms_setup.values()))
# Make sure the pipeline store is loaded, needed because assist_pipeline
@ -29,8 +36,11 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
new_stt_engine_id = entity_registry.async_get_entity_id(
STT_DOMAIN, DOMAIN, STT_ENTITY_UNIQUE_ID
)
if new_stt_engine_id is None:
# If there's no cloud stt entity, we can't create a cloud pipeline.
new_tts_engine_id = entity_registry.async_get_entity_id(
TTS_DOMAIN, DOMAIN, TTS_ENTITY_UNIQUE_ID
)
if new_stt_engine_id is None or new_tts_engine_id is None:
# If there's no cloud stt or tts entity, we can't create a cloud pipeline.
return None
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
@ -43,7 +53,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
if (
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
and pipeline.stt_engine in (DOMAIN, new_stt_engine_id)
and pipeline.tts_engine == DOMAIN
and pipeline.tts_engine in (DOMAIN, new_tts_engine_id)
):
return pipeline.id
return None
@ -52,7 +62,7 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
cloud_pipeline := await async_create_default_pipeline(
hass,
stt_engine_id=new_stt_engine_id,
tts_engine_id=DOMAIN,
tts_engine_id=new_tts_engine_id,
pipeline_name="Home Assistant Cloud",
)
) is None:
@ -61,25 +71,34 @@ async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
return cloud_pipeline.id
async def async_migrate_cloud_pipeline_stt_engine(
hass: HomeAssistant, stt_engine_id: str
async def async_migrate_cloud_pipeline_engine(
hass: HomeAssistant, platform: Platform, engine_id: str
) -> None:
"""Migrate the speech-to-text engine in the cloud assist pipeline."""
# Migrate existing pipelines with cloud stt to use new cloud stt engine id.
# Added in 2024.01.0. Can be removed in 2025.01.0.
"""Migrate the pipeline engines in the cloud assist pipeline."""
# Migrate existing pipelines with cloud stt or tts to use new cloud engine id.
# Added in 2024.02.0. Can be removed in 2025.02.0.
# We need to make sure that both stt and tts are loaded before this migration.
# Assist pipeline will call default engine when setting up the store.
# Wait for the stt or tts platform loaded event here.
if platform == Platform.STT:
wait_for_platform = Platform.TTS
pipeline_attribute = "stt_engine"
elif platform == Platform.TTS:
wait_for_platform = Platform.STT
pipeline_attribute = "tts_engine"
else:
raise ValueError(f"Invalid platform {platform}")
# We need to make sure that tts is loaded before this migration.
# Assist pipeline will call default engine of tts when setting up the store.
# Wait for the tts platform loaded event here.
platforms_setup: dict[str, asyncio.Event] = hass.data[DATA_PLATFORMS_SETUP]
await platforms_setup[Platform.TTS].wait()
await platforms_setup[wait_for_platform].wait()
# Make sure the pipeline store is loaded, needed because assist_pipeline
# is an after dependency of cloud
await async_setup_pipeline_store(hass)
kwargs: dict[str, str] = {pipeline_attribute: engine_id}
pipelines = async_get_pipelines(hass)
for pipeline in pipelines:
if pipeline.stt_engine != DOMAIN:
continue
await async_update_pipeline(hass, pipeline, stt_engine=stt_engine_id)
if getattr(pipeline, pipeline_attribute) == DOMAIN:
await async_update_pipeline(hass, pipeline, **kwargs)

View file

@ -73,3 +73,4 @@ MODE_PROD = "production"
DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"

View file

@ -104,10 +104,18 @@ class CloudPreferences:
@callback
def async_listen_updates(
self, listener: Callable[[CloudPreferences], Coroutine[Any, Any, None]]
) -> None:
) -> Callable[[], None]:
"""Listen for updates to the preferences."""
@callback
def unsubscribe() -> None:
"""Remove the listener."""
self._listeners.remove(listener)
self._listeners.append(listener)
return unsubscribe
async def async_update(
self,
*,

View file

@ -1,6 +1,7 @@
"""Support for the cloud for speech to text service."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable
import logging
@ -19,12 +20,13 @@ from homeassistant.components.stt import (
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .assist_pipeline import async_migrate_cloud_pipeline_stt_engine
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
from .const import DOMAIN, STT_ENTITY_UNIQUE_ID
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
_LOGGER = logging.getLogger(__name__)
@ -35,18 +37,20 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Home Assistant Cloud speech platform via config entry."""
stt_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.STT]
stt_platform_loaded.set()
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
async_add_entities([CloudProviderEntity(cloud)])
class CloudProviderEntity(SpeechToTextEntity):
"""NabuCasa speech API provider."""
"""Home Assistant Cloud speech API provider."""
_attr_name = "Home Assistant Cloud"
_attr_unique_id = STT_ENTITY_UNIQUE_ID
def __init__(self, cloud: Cloud[CloudClient]) -> None:
"""Home Assistant NabuCasa Speech to text."""
"""Initialize cloud Speech to text entity."""
self.cloud = cloud
@property
@ -81,7 +85,9 @@ class CloudProviderEntity(SpeechToTextEntity):
async def async_added_to_hass(self) -> None:
"""Run when entity is about to be added to hass."""
await async_migrate_cloud_pipeline_stt_engine(self.hass, self.entity_id)
await async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.STT, engine_id=self.entity_id
)
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]

View file

@ -1,6 +1,7 @@
"""Support for the cloud for text-to-speech service."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
@ -12,16 +13,21 @@ from homeassistant.components.tts import (
ATTR_AUDIO_OUTPUT,
ATTR_VOICE,
CONF_LANG,
PLATFORM_SCHEMA,
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
Provider,
TextToSpeechEntity,
TtsAudioType,
Voice,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
from .const import DOMAIN
from .const import DATA_PLATFORMS_SETUP, DOMAIN, TTS_ENTITY_UNIQUE_ID
from .prefs import CloudPreferences
ATTR_GENDER = "gender"
@ -48,7 +54,7 @@ def validate_lang(value: dict[str, Any]) -> dict[str, Any]:
PLATFORM_SCHEMA = vol.All(
PLATFORM_SCHEMA.extend(
TTS_PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_LANG): str,
vol.Optional(ATTR_GENDER): str,
@ -81,8 +87,95 @@ async def async_get_engine(
return cloud_provider
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Home Assistant Cloud text-to-speech platform."""
tts_platform_loaded: asyncio.Event = hass.data[DATA_PLATFORMS_SETUP][Platform.TTS]
tts_platform_loaded.set()
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
async_add_entities([CloudTTSEntity(cloud)])
class CloudTTSEntity(TextToSpeechEntity):
"""Home Assistant Cloud text-to-speech entity."""
_attr_name = "Home Assistant Cloud"
_attr_unique_id = TTS_ENTITY_UNIQUE_ID
def __init__(self, cloud: Cloud[CloudClient]) -> None:
"""Initialize cloud text-to-speech entity."""
self.cloud = cloud
self._language, self._gender = cloud.client.prefs.tts_default_voice
async def _sync_prefs(self, prefs: CloudPreferences) -> None:
"""Sync preferences."""
self._language, self._gender = prefs.tts_default_voice
@property
def default_language(self) -> str:
"""Return the default language."""
return self._language
@property
def default_options(self) -> dict[str, Any]:
"""Return a dict include default options."""
return {
ATTR_GENDER: self._gender,
ATTR_AUDIO_OUTPUT: AudioOutput.MP3,
}
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return SUPPORT_LANGUAGES
@property
def supported_options(self) -> list[str]:
"""Return list of supported options like voice, emotion."""
return [ATTR_GENDER, ATTR_VOICE, ATTR_AUDIO_OUTPUT]
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
await super().async_added_to_hass()
await async_migrate_cloud_pipeline_engine(
self.hass, platform=Platform.TTS, engine_id=self.entity_id
)
self.async_on_remove(
self.cloud.client.prefs.async_listen_updates(self._sync_prefs)
)
@callback
def async_get_supported_voices(self, language: str) -> list[Voice] | None:
"""Return a list of supported voices for a language."""
if not (voices := TTS_VOICES.get(language)):
return None
return [Voice(voice, voice) for voice in voices]
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS from Home Assistant Cloud."""
# Process TTS
try:
data = await self.cloud.voice.process_tts(
text=message,
language=language,
gender=options.get(ATTR_GENDER),
voice=options.get(ATTR_VOICE),
output=options[ATTR_AUDIO_OUTPUT],
)
except VoiceError as err:
_LOGGER.error("Voice error: %s", err)
return (None, None)
return (str(options[ATTR_AUDIO_OUTPUT].value), data)
class CloudProvider(Provider):
"""NabuCasa Cloud speech API provider."""
"""Home Assistant Cloud speech API provider."""
def __init__(
self, cloud: Cloud[CloudClient], language: str | None, gender: str | None
@ -136,7 +229,7 @@ class CloudProvider(Provider):
async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load TTS from NabuCasa Cloud."""
"""Load TTS from Home Assistant Cloud."""
# Process TTS
try:
data = await self.cloud.voice.process_tts(

View file

@ -7,6 +7,54 @@ from homeassistant.components import cloud
from homeassistant.components.cloud import const, prefs as cloud_prefs
from homeassistant.setup import async_setup_component
PIPELINE_DATA = {
"items": [
{
"conversation_engine": "conversation_engine_1",
"conversation_language": "language_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"language": "language_1",
"name": "Home Assistant Cloud",
"stt_engine": "cloud",
"stt_language": "language_1",
"tts_engine": "cloud",
"tts_language": "language_1",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_2",
"conversation_language": "language_2",
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
"language": "language_2",
"name": "name_2",
"stt_engine": "stt_engine_2",
"stt_language": "language_2",
"tts_engine": "tts_engine_2",
"tts_language": "language_2",
"tts_voice": "The Voice",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_3",
"conversation_language": "language_3",
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
"language": "language_3",
"name": "name_3",
"stt_engine": None,
"stt_language": None,
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
}
async def mock_cloud(hass, config=None):
"""Mock cloud."""

View file

@ -15,11 +15,22 @@ import jwt
import pytest
from homeassistant.components.cloud import CloudClient, const, prefs
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
from . import mock_cloud, mock_cloud_prefs
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
"""Load the homeassistant integration.
This is needed for the cloud integration to work.
"""
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture(name="cloud")
async def cloud_fixture() -> AsyncGenerator[MagicMock, None]:
"""Mock the cloud object.

View file

@ -0,0 +1,16 @@
"""Test the cloud assist pipeline."""
import pytest
from homeassistant.components.cloud.assist_pipeline import (
async_migrate_cloud_pipeline_engine,
)
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
async def test_migrate_pipeline_invalid_platform(hass: HomeAssistant) -> None:
"""Test migrate pipeline with invalid platform."""
with pytest.raises(ValueError):
await async_migrate_cloud_pipeline_engine(
hass, Platform.BINARY_SENSOR, "test-engine-id"
)

View file

@ -147,15 +147,19 @@ async def test_google_actions_sync_fails(
assert mock_request_sync.call_count == 1
async def test_login_view_missing_stt_entity(
@pytest.mark.parametrize(
"entity_id", ["stt.home_assistant_cloud", "tts.home_assistant_cloud"]
)
async def test_login_view_missing_entity(
hass: HomeAssistant,
setup_cloud: None,
entity_registry: er.EntityRegistry,
hass_client: ClientSessionGenerator,
entity_id: str,
) -> None:
"""Test logging in when the cloud stt entity is missing."""
# Make sure that the cloud stt entity does not exist.
entity_registry.async_remove("stt.home_assistant_cloud")
"""Test logging in when a cloud assist pipeline needed entity is missing."""
# Make sure that the cloud entity does not exist.
entity_registry.async_remove(entity_id)
await hass.async_block_till_done()
cloud_client = await hass_client()
@ -243,7 +247,7 @@ async def test_login_view_create_pipeline(
create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="stt.home_assistant_cloud",
tts_engine_id="cloud",
tts_engine_id="tts.home_assistant_cloud",
pipeline_name="Home Assistant Cloud",
)
@ -282,7 +286,7 @@ async def test_login_view_create_pipeline_fail(
create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="stt.home_assistant_cloud",
tts_engine_id="cloud",
tts_engine_id="tts.home_assistant_cloud",
pipeline_name="Home Assistant Cloud",
)

View file

@ -14,62 +14,10 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import PIPELINE_DATA
from tests.typing import ClientSessionGenerator
PIPELINE_DATA = {
"items": [
{
"conversation_engine": "conversation_engine_1",
"conversation_language": "language_1",
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
"language": "language_1",
"name": "Home Assistant Cloud",
"stt_engine": "cloud",
"stt_language": "language_1",
"tts_engine": "cloud",
"tts_language": "language_1",
"tts_voice": "Arnold Schwarzenegger",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_2",
"conversation_language": "language_2",
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
"language": "language_2",
"name": "name_2",
"stt_engine": "stt_engine_2",
"stt_language": "language_2",
"tts_engine": "tts_engine_2",
"tts_language": "language_2",
"tts_voice": "The Voice",
"wake_word_entity": None,
"wake_word_id": None,
},
{
"conversation_engine": "conversation_engine_3",
"conversation_language": "language_3",
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
"language": "language_3",
"name": "name_3",
"stt_engine": None,
"stt_language": None,
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
"wake_word_entity": None,
"wake_word_id": None,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
}
@pytest.fixture(autouse=True)
async def load_homeassistant(hass: HomeAssistant) -> None:
"""Load the homeassistant integration."""
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
@ -143,6 +91,7 @@ async def test_migrating_pipelines(
hass_storage: dict[str, Any],
) -> None:
"""Test migrating pipelines when cloud stt entity is added."""
entity_id = "stt.home_assistant_cloud"
cloud.voice.process_stt = AsyncMock(
return_value=STTResponse(True, "Turn the Kitchen Lights on")
)
@ -157,18 +106,18 @@ async def test_migrating_pipelines(
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
await cloud.login("test-user", "test-pass")
await hass.async_block_till_done()
state = hass.states.get("stt.home_assistant_cloud")
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
# The stt engine should be updated to the new cloud stt engine id.
# The stt/tts engines should have been updated to the new cloud engine ids.
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"] == entity_id
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
== "stt.home_assistant_cloud"
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"]
== "tts.home_assistant_cloud"
)
# The other items should stay the same.
@ -189,7 +138,6 @@ async def test_migrating_pipelines(
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == "cloud"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]

View file

@ -1,23 +1,36 @@
"""Tests for cloud tts."""
from collections.abc import Callable, Coroutine
from collections.abc import AsyncGenerator, Callable, Coroutine
from copy import deepcopy
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from hass_nabucasa.voice import MAP_VOICE, VoiceError, VoiceTokenError
import pytest
import voluptuous as vol
from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
from homeassistant.components.cloud import DOMAIN, const, tts
from homeassistant.components.tts import DOMAIN as TTS_DOMAIN
from homeassistant.components.tts.helper import get_engine_instance
from homeassistant.config import async_process_ha_core_config
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_registry import EntityRegistry
from homeassistant.setup import async_setup_component
from . import PIPELINE_DATA
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
async def delay_save_fixture() -> AsyncGenerator[None, None]:
"""Load the homeassistant integration."""
with patch("homeassistant.helpers.collection.SAVE_DELAY", new=0):
yield
@pytest.fixture(autouse=True)
async def internal_url_mock(hass: HomeAssistant) -> None:
"""Mock internal URL of the instance."""
@ -70,6 +83,10 @@ def test_schema() -> None:
"gender": "female",
},
),
(
"tts.home_assistant_cloud",
None,
),
],
)
async def test_prefs_default_voice(
@ -104,9 +121,17 @@ async def test_prefs_default_voice(
assert engine.default_options == {"gender": "male", "audio_output": "mp3"}
@pytest.mark.parametrize(
"engine_id",
[
DOMAIN,
"tts.home_assistant_cloud",
],
)
async def test_provider_properties(
hass: HomeAssistant,
cloud: MagicMock,
engine_id: str,
) -> None:
"""Test cloud provider."""
assert await async_setup_component(hass, "homeassistant", {})
@ -115,7 +140,7 @@ async def test_provider_properties(
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
engine = get_engine_instance(hass, DOMAIN)
engine = get_engine_instance(hass, engine_id)
assert engine is not None
assert engine.supported_options == ["gender", "voice", "audio_output"]
@ -132,6 +157,7 @@ async def test_provider_properties(
[
({"platform": DOMAIN}, DOMAIN),
({"engine_id": DOMAIN}, DOMAIN),
({"engine_id": "tts.home_assistant_cloud"}, "tts.home_assistant_cloud"),
],
)
@pytest.mark.parametrize(
@ -241,3 +267,144 @@ async def test_get_tts_audio_logged_out(
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] == "female"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
@pytest.mark.parametrize(
("mock_process_tts_return_value", "mock_process_tts_side_effect"),
[
(b"", None),
(None, VoiceError("Boom!")),
],
)
async def test_tts_entity(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
entity_registry: EntityRegistry,
cloud: MagicMock,
mock_process_tts_return_value: bytes | None,
mock_process_tts_side_effect: Exception | None,
) -> None:
"""Test text-to-speech entity."""
mock_process_tts = AsyncMock(
return_value=mock_process_tts_return_value,
side_effect=mock_process_tts_side_effect,
)
cloud.voice.process_tts = mock_process_tts
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
on_start_callback = cloud.register_on_start.call_args[0][0]
await on_start_callback()
client = await hass_client()
entity_id = "tts.home_assistant_cloud"
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
url = "/api/tts_get_url"
data = {
"engine_id": entity_id,
"message": "There is someone at the door.",
}
req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK
response = await req.json()
assert response == {
"url": (
"http://example.local:8123/api/tts_proxy/"
"42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_e09b5a0968_{entity_id}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_e09b5a0968_{entity_id}.mp3"
),
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None
assert mock_process_tts.call_args.kwargs["text"] == "There is someone at the door."
assert mock_process_tts.call_args.kwargs["language"] == "en-US"
assert mock_process_tts.call_args.kwargs["gender"] == "female"
assert mock_process_tts.call_args.kwargs["output"] == "mp3"
state = hass.states.get(entity_id)
assert state
assert state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
# Test removing the entity
entity_registry.async_remove(entity_id)
await hass.async_block_till_done()
state = hass.states.get(entity_id)
assert state is None
async def test_migrating_pipelines(
hass: HomeAssistant,
cloud: MagicMock,
hass_client: ClientSessionGenerator,
hass_storage: dict[str, Any],
) -> None:
"""Test migrating pipelines when cloud tts entity is added."""
entity_id = "tts.home_assistant_cloud"
mock_process_tts = AsyncMock(
return_value=b"",
)
cloud.voice.process_tts = mock_process_tts
hass_storage[STORAGE_KEY] = {
"version": 1,
"minor_version": 1,
"key": "assist_pipeline.pipelines",
"data": deepcopy(PIPELINE_DATA),
}
assert await async_setup_component(hass, "assist_pipeline", {})
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
await cloud.login("test-user", "test-pass")
await hass.async_block_till_done()
state = hass.states.get(entity_id)
assert state
assert state.state == STATE_UNKNOWN
# The stt/tts engines should have been updated to the new cloud engine ids.
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_engine"]
== "stt.home_assistant_cloud"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_engine"] == entity_id
# The other items should stay the same.
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_engine"]
== "conversation_engine_1"
)
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["conversation_language"]
== "language_1"
)
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["id"]
== "01GX8ZWBAQYWNB1XV3EXEZ75DY"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["name"] == "Home Assistant Cloud"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["stt_language"] == "language_1"
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_language"] == "language_1"
assert (
hass_storage[STORAGE_KEY]["data"]["items"][0]["tts_voice"]
== "Arnold Schwarzenegger"
)
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_entity"] is None
assert hass_storage[STORAGE_KEY]["data"]["items"][0]["wake_word_id"] is None
assert hass_storage[STORAGE_KEY]["data"]["items"][1] == PIPELINE_DATA["items"][1]
assert hass_storage[STORAGE_KEY]["data"]["items"][2] == PIPELINE_DATA["items"][2]