diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index 2cfbc09ed081..fb7f5c3b21c4 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -3,6 +3,7 @@ from __future__ import annotations import codecs +from collections.abc import Callable from typing import Any, Literal from google.api_core.exceptions import GoogleAPICallError @@ -89,10 +90,14 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]: return result -def _format_tool(tool: llm.Tool) -> dict[str, Any]: +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> dict[str, Any]: """Format tool specification.""" - parameters = _format_schema(convert(tool.parameters)) + parameters = _format_schema( + convert(tool.parameters, custom_serializer=custom_serializer) + ) return protos.Tool( { @@ -193,7 +198,9 @@ class GoogleGenerativeAIConversationEntity( f"Error preparing LLM API: {err}", ) return result - tools = [_format_tool(tool) for tool in llm_api.tools] + tools = [ + _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools + ] try: prompt = await self._async_render_prompt(user_input, llm_api, llm_context) diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json index 168fee105a0a..9e0dc1ddeab7 100644 --- a/homeassistant/components/google_generative_ai_conversation/manifest.json +++ b/homeassistant/components/google_generative_ai_conversation/manifest.json @@ -9,5 +9,5 @@ "integration_type": "service", "iot_class": "cloud_polling", "quality_scale": "platinum", - "requirements": ["google-generativeai==0.6.0", "voluptuous-openapi==0.0.4"] + "requirements": ["google-generativeai==0.6.0"] } diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 40242f5c6cc7..46be803bcadb 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -1,7 +1,8 @@ """Conversation support for OpenAI.""" +from collections.abc import Callable import json -from typing import Literal +from typing import Any, Literal import openai from openai._types import NOT_GIVEN @@ -58,9 +59,14 @@ async def async_setup_entry( async_add_entities([agent]) -def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam: +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> ChatCompletionToolParam: """Format tool specification.""" - tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters)) + tool_spec = FunctionDefinition( + name=tool.name, + parameters=convert(tool.parameters, custom_serializer=custom_serializer), + ) if tool.description: tool_spec["description"] = tool.description return ChatCompletionToolParam(type="function", function=tool_spec) @@ -139,7 +145,9 @@ class OpenAIConversationEntity( return conversation.ConversationResult( response=intent_response, conversation_id=user_input.conversation_id ) - tools = [_format_tool(tool) for tool in llm_api.tools] + tools = [ + _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools + ] if user_input.conversation_id is None: conversation_id = ulid.ulid_now() diff --git a/homeassistant/components/openai_conversation/manifest.json b/homeassistant/components/openai_conversation/manifest.json index 480712574c4e..0c06a3d4cd83 100644 --- a/homeassistant/components/openai_conversation/manifest.json +++ b/homeassistant/components/openai_conversation/manifest.json @@ -8,5 +8,5 @@ "documentation": "https://www.home-assistant.io/integrations/openai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"] + "requirements": ["openai==1.3.8"] } diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index b1ddf5eacc77..502b20eaf8f7 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -352,7 +352,7 @@ class MatchTargetsCandidate: matched_name: str | None = None -def _find_areas( +def find_areas( name: str, areas: area_registry.AreaRegistry ) -> Iterable[area_registry.AreaEntry]: """Find all areas matching a name (including aliases).""" @@ -372,7 +372,7 @@ def _find_areas( break -def _find_floors( +def find_floors( name: str, floors: floor_registry.FloorRegistry ) -> Iterable[floor_registry.FloorEntry]: """Find all floors matching a name (including aliases).""" @@ -530,7 +530,7 @@ def async_match_targets( # noqa: C901 if not states: return MatchTargetsResult(False, MatchFailedReason.STATE) - # Exit early so we can to avoid registry lookups + # Exit early so we can avoid registry lookups if not ( constraints.name or constraints.features @@ -580,7 +580,7 @@ def async_match_targets( # noqa: C901 if constraints.floor_name: # Filter by areas associated with floor fr = floor_registry.async_get(hass) - targeted_floors = list(_find_floors(constraints.floor_name, fr)) + targeted_floors = list(find_floors(constraints.floor_name, fr)) if not targeted_floors: return MatchTargetsResult( False, @@ -609,7 +609,7 @@ def async_match_targets( # noqa: C901 possible_area_ids = {area.id for area in ar.async_list_areas()} if constraints.area_name: - targeted_areas = list(_find_areas(constraints.area_name, ar)) + targeted_areas = list(find_areas(constraints.area_name, ar)) if not targeted_areas: return MatchTargetsResult( False, diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index a4e18fdb2c0d..480b9cb52373 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass from decimal import Decimal from enum import Enum @@ -11,6 +12,7 @@ from typing import Any import slugify as unicode_slug import voluptuous as vol +from voluptuous_openapi import UNSUPPORTED, convert from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE from homeassistant.components.conversation.trace import ( @@ -20,22 +22,39 @@ from homeassistant.components.conversation.trace import ( from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.intent import async_device_supports_timers +from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN from homeassistant.components.weather.intent import INTENT_GET_WEATHER -from homeassistant.core import Context, HomeAssistant, callback +from homeassistant.const import ( + ATTR_DOMAIN, + ATTR_ENTITY_ID, + ATTR_SERVICE, + EVENT_HOMEASSISTANT_CLOSE, + EVENT_SERVICE_REMOVED, + SERVICE_TURN_ON, +) +from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id from homeassistant.exceptions import HomeAssistantError from homeassistant.util import yaml +from homeassistant.util.hass_dict import HassKey from homeassistant.util.json import JsonObjectType from . import ( area_registry as ar, + config_validation as cv, device_registry as dr, entity_registry as er, floor_registry as fr, intent, + selector, service, ) from .singleton import singleton +SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey( + "llm_script_parameters_cache" +) + + LLM_API_ASSIST = "assist" BASE_PROMPT = ( @@ -143,6 +162,7 @@ class APIInstance: api_prompt: str llm_context: LLMContext tools: list[Tool] + custom_serializer: Callable[[Any], Any] | None = None async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: """Call a LLM tool, validate args and return the response.""" @@ -284,6 +304,7 @@ class AssistAPI(API): api_prompt=self._async_get_api_prompt(llm_context, exposed_entities), llm_context=llm_context, tools=self._async_get_tools(llm_context, exposed_entities), + custom_serializer=_selector_serializer, ) @callback @@ -372,7 +393,7 @@ class AssistAPI(API): exposed_domains: set[str] | None = None if exposed_entities is not None: exposed_domains = { - entity_id.split(".")[0] for entity_id in exposed_entities + split_entity_id(entity_id)[0] for entity_id in exposed_entities } intent_handlers = [ intent_handler @@ -381,11 +402,22 @@ class AssistAPI(API): or intent_handler.platforms & exposed_domains ] - return [ + tools: list[Tool] = [ IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler) for intent_handler in intent_handlers ] + if llm_context.assistant is not None: + for state in self.hass.states.async_all(SCRIPT_DOMAIN): + if not async_should_expose( + self.hass, llm_context.assistant, state.entity_id + ): + continue + + tools.append(ScriptTool(self.hass, state.entity_id)) + + return tools + def _get_exposed_entities( hass: HomeAssistant, assistant: str @@ -413,13 +445,15 @@ def _get_exposed_entities( entities = {} for state in hass.states.async_all(): + if state.domain == SCRIPT_DOMAIN: + continue + if not async_should_expose(hass, assistant, state.entity_id): continue entity_entry = entity_registry.async_get(state.entity_id) names = [state.name] area_names = [] - description: str | None = None if entity_entry is not None: names.extend(entity_entry.aliases) @@ -439,25 +473,11 @@ def _get_exposed_entities( area_names.append(area.name) area_names.extend(area.aliases) - if ( - state.domain == "script" - and entity_entry.unique_id - and ( - service_desc := service.async_get_cached_service_description( - hass, "script", entity_entry.unique_id - ) - ) - ): - description = service_desc.get("description") - info: dict[str, Any] = { "names": ", ".join(names), "state": state.state, } - if description: - info["description"] = description - if area_names: info["areas"] = ", ".join(area_names) @@ -473,3 +493,231 @@ def _get_exposed_entities( entities[state.entity_id] = info return entities + + +def _selector_serializer(schema: Any) -> Any: # noqa: C901 + """Convert selectors into OpenAPI schema.""" + if not isinstance(schema, selector.Selector): + return UNSUPPORTED + + if isinstance(schema, selector.BackupLocationSelector): + return {"type": "string", "pattern": "^(?:\\/backup|\\w+)$"} + + if isinstance(schema, selector.BooleanSelector): + return {"type": "boolean"} + + if isinstance(schema, selector.ColorRGBSelector): + return { + "type": "array", + "items": {"type": "number"}, + "minItems": 3, + "maxItems": 3, + "format": "RGB", + } + + if isinstance(schema, selector.ConditionSelector): + return convert(cv.CONDITIONS_SCHEMA) + + if isinstance(schema, selector.ConstantSelector): + return {"enum": [schema.config["value"]]} + + result: dict[str, Any] + if isinstance(schema, selector.ColorTempSelector): + result = {"type": "number"} + if "min" in schema.config: + result["minimum"] = schema.config["min"] + elif "min_mireds" in schema.config: + result["minimum"] = schema.config["min_mireds"] + if "max" in schema.config: + result["maximum"] = schema.config["max"] + elif "max_mireds" in schema.config: + result["maximum"] = schema.config["max_mireds"] + return result + + if isinstance(schema, selector.CountrySelector): + if schema.config.get("countries"): + return {"type": "string", "enum": schema.config["countries"]} + return {"type": "string", "format": "ISO 3166-1 alpha-2"} + + if isinstance(schema, selector.DateSelector): + return {"type": "string", "format": "date"} + + if isinstance(schema, selector.DateTimeSelector): + return {"type": "string", "format": "date-time"} + + if isinstance(schema, selector.DurationSelector): + return convert(cv.time_period_dict) + + if isinstance(schema, selector.EntitySelector): + if schema.config.get("multiple"): + return {"type": "array", "items": {"type": "string", "format": "entity_id"}} + + return {"type": "string", "format": "entity_id"} + + if isinstance(schema, selector.LanguageSelector): + if schema.config.get("languages"): + return {"type": "string", "enum": schema.config["languages"]} + return {"type": "string", "format": "RFC 5646"} + + if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)): + return convert(schema.DATA_SCHEMA) + + if isinstance(schema, selector.NumberSelector): + result = {"type": "number"} + if "min" in schema.config: + result["minimum"] = schema.config["min"] + if "max" in schema.config: + result["maximum"] = schema.config["max"] + return result + + if isinstance(schema, selector.ObjectSelector): + return {"type": "object"} + + if isinstance(schema, selector.SelectSelector): + options = [ + x["value"] if isinstance(x, dict) else x for x in schema.config["options"] + ] + if schema.config.get("multiple"): + return { + "type": "array", + "items": {"type": "string", "enum": options}, + "uniqueItems": True, + } + return {"type": "string", "enum": options} + + if isinstance(schema, selector.TargetSelector): + return convert(cv.TARGET_SERVICE_FIELDS) + + if isinstance(schema, selector.TemplateSelector): + return {"type": "string", "format": "jinja2"} + + if isinstance(schema, selector.TimeSelector): + return {"type": "string", "format": "time"} + + if isinstance(schema, selector.TriggerSelector): + return convert(cv.TRIGGER_SCHEMA) + + if schema.config.get("multiple"): + return {"type": "array", "items": {"type": "string"}} + + return {"type": "string"} + + +class ScriptTool(Tool): + """LLM Tool representing a Script.""" + + def __init__( + self, + hass: HomeAssistant, + script_entity_id: str, + ) -> None: + """Init the class.""" + entity_registry = er.async_get(hass) + + self.name = split_entity_id(script_entity_id)[1] + self.parameters = vol.Schema({}) + entity_entry = entity_registry.async_get(script_entity_id) + if entity_entry and entity_entry.unique_id: + parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE) + + if parameters_cache is None: + parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {} + + @callback + def clear_cache(event: Event) -> None: + """Clear script parameter cache on script reload or delete.""" + if ( + event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN + and event.data[ATTR_SERVICE] in parameters_cache + ): + parameters_cache.pop(event.data[ATTR_SERVICE]) + + cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache) + + @callback + def on_homeassistant_close(event: Event) -> None: + """Cleanup.""" + cancel() + + hass.bus.async_listen_once( + EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close + ) + + if entity_entry.unique_id in parameters_cache: + self.description, self.parameters = parameters_cache[ + entity_entry.unique_id + ] + return + + if service_desc := service.async_get_cached_service_description( + hass, SCRIPT_DOMAIN, entity_entry.unique_id + ): + self.description = service_desc.get("description") + schema: dict[vol.Marker, Any] = {} + fields = service_desc.get("fields", {}) + + for field, config in fields.items(): + description = config.get("description") + if not description: + description = config.get("name") + if config.get("required"): + key = vol.Required(field, description=description) + else: + key = vol.Optional(field, description=description) + if "selector" in config: + schema[key] = selector.selector(config["selector"]) + else: + schema[key] = cv.string + + self.parameters = vol.Schema(schema) + + parameters_cache[entity_entry.unique_id] = ( + self.description, + self.parameters, + ) + + async def async_call( + self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext + ) -> JsonObjectType: + """Run the script.""" + + for field, validator in self.parameters.schema.items(): + if field not in tool_input.tool_args: + continue + if isinstance(validator, selector.AreaSelector): + area_reg = ar.async_get(hass) + if validator.config.get("multiple"): + areas: list[ar.AreaEntry] = [] + for area in tool_input.tool_args[field]: + areas.extend(intent.find_areas(area, area_reg)) + tool_input.tool_args[field] = list({area.id for area in areas}) + else: + area = tool_input.tool_args[field] + area = list(intent.find_areas(area, area_reg))[0].id + tool_input.tool_args[field] = area + + elif isinstance(validator, selector.FloorSelector): + floor_reg = fr.async_get(hass) + if validator.config.get("multiple"): + floors: list[fr.FloorEntry] = [] + for floor in tool_input.tool_args[field]: + floors.extend(intent.find_floors(floor, floor_reg)) + tool_input.tool_args[field] = list( + {floor.floor_id for floor in floors} + ) + else: + floor = tool_input.tool_args[field] + floor = list(intent.find_floors(floor, floor_reg))[0].floor_id + tool_input.tool_args[field] = floor + + await hass.services.async_call( + SCRIPT_DOMAIN, + SERVICE_TURN_ON, + { + ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name, + ATTR_VARIABLES: tool_input.tool_args, + }, + context=llm_context.context, + ) + + return {"success": True} diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 1db4dd9f80bd..16aaa40db867 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -75,6 +75,13 @@ class Selector[_T: Mapping[str, Any]]: self.config = self.CONFIG_SCHEMA(config) + def __eq__(self, other: object) -> bool: + """Check equality.""" + if not isinstance(other, Selector): + return NotImplemented + + return self.selector_type == other.selector_type and self.config == other.config + def serialize(self) -> dict[str, dict[str, _T]]: """Serialize Selector for voluptuous_serialize.""" return {"selector": {self.selector_type: self.config}} @@ -278,7 +285,7 @@ class AssistPipelineSelector(Selector[AssistPipelineSelectorConfig]): CONFIG_SCHEMA = vol.Schema({}) - def __init__(self, config: AssistPipelineSelectorConfig) -> None: + def __init__(self, config: AssistPipelineSelectorConfig | None = None) -> None: """Instantiate a selector.""" super().__init__(config) @@ -430,10 +437,10 @@ class ColorTempSelector(Selector[ColorTempSelectorConfig]): range_min = self.config.get("min") range_max = self.config.get("max") - if not range_min: + if range_min is None: range_min = self.config.get("min_mireds") - if not range_max: + if range_max is None: range_max = self.config.get("max_mireds") value: int = vol.All( @@ -517,7 +524,7 @@ class ConstantSelector(Selector[ConstantSelectorConfig]): } ) - def __init__(self, config: ConstantSelectorConfig | None = None) -> None: + def __init__(self, config: ConstantSelectorConfig) -> None: """Instantiate a selector.""" super().__init__(config) @@ -560,7 +567,7 @@ class QrCodeSelector(Selector[QrCodeSelectorConfig]): } ) - def __init__(self, config: QrCodeSelectorConfig | None = None) -> None: + def __init__(self, config: QrCodeSelectorConfig) -> None: """Instantiate a selector.""" super().__init__(config) @@ -588,7 +595,7 @@ class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]): } ) - def __init__(self, config: ConversationAgentSelectorConfig) -> None: + def __init__(self, config: ConversationAgentSelectorConfig | None = None) -> None: """Instantiate a selector.""" super().__init__(config) @@ -820,7 +827,7 @@ class FloorSelectorConfig(TypedDict, total=False): @SELECTORS.register("floor") -class FloorSelector(Selector[AreaSelectorConfig]): +class FloorSelector(Selector[FloorSelectorConfig]): """Selector of a single or list of floors.""" selector_type = "floor" @@ -934,7 +941,7 @@ class LanguageSelector(Selector[LanguageSelectorConfig]): } ) - def __init__(self, config: LanguageSelectorConfig) -> None: + def __init__(self, config: LanguageSelectorConfig | None = None) -> None: """Instantiate a selector.""" super().__init__(config) @@ -1159,7 +1166,7 @@ class SelectSelector(Selector[SelectSelectorConfig]): } ) - def __init__(self, config: SelectSelectorConfig | None = None) -> None: + def __init__(self, config: SelectSelectorConfig) -> None: """Instantiate a selector.""" super().__init__(config) @@ -1434,7 +1441,7 @@ class FileSelector(Selector[FileSelectorConfig]): } ) - def __init__(self, config: FileSelectorConfig | None = None) -> None: + def __init__(self, config: FileSelectorConfig) -> None: """Instantiate a selector.""" super().__init__(config) diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index f3be7c5515e7..25d108742395 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -58,6 +58,7 @@ SQLAlchemy==2.0.31 typing-extensions>=4.12.2,<5.0 ulid-transform==0.9.0 urllib3>=1.26.5,<2 +voluptuous-openapi==0.0.4 voluptuous-serialize==2.6.0 voluptuous==0.13.1 webrtc-noise-gain==1.2.3 diff --git a/pyproject.toml b/pyproject.toml index d7fbe67edba7..6ecbb8b51d13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ dependencies = [ "urllib3>=1.26.5,<2", "voluptuous==0.13.1", "voluptuous-serialize==2.6.0", + "voluptuous-openapi==0.0.4", "yarl==1.9.4", ] diff --git a/requirements.txt b/requirements.txt index cff85c2478f7..5b1c57c7e1c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,4 +41,5 @@ ulid-transform==0.9.0 urllib3>=1.26.5,<2 voluptuous==0.13.1 voluptuous-serialize==2.6.0 +voluptuous-openapi==0.0.4 yarl==1.9.4 diff --git a/requirements_all.txt b/requirements_all.txt index 48eae313cf48..14c4ed00a0a9 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2846,10 +2846,6 @@ voip-utils==0.1.0 # homeassistant.components.volkszaehler volkszaehler==0.4.0 -# homeassistant.components.google_generative_ai_conversation -# homeassistant.components.openai_conversation -voluptuous-openapi==0.0.4 - # homeassistant.components.volvooncall volvooncall==0.10.3 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 3b3adcf409ae..0e70472b67b9 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2217,10 +2217,6 @@ vilfo-api-client==0.5.0 # homeassistant.components.voip voip-utils==0.1.0 -# homeassistant.components.google_generative_ai_conversation -# homeassistant.components.openai_conversation -voluptuous-openapi==0.0.4 - # homeassistant.components.volvooncall volvooncall==0.10.3 diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 5389490b401a..872297b09ec5 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -8,6 +8,7 @@ import voluptuous as vol from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.intent import async_register_timer_handler +from homeassistant.components.script.config import ScriptConfig from homeassistant.core import Context, HomeAssistant, State from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import ( @@ -18,6 +19,7 @@ from homeassistant.helpers import ( floor_registry as fr, intent, llm, + selector, ) from homeassistant.setup import async_setup_component from homeassistant.util import yaml @@ -564,11 +566,6 @@ async def test_assist_api_prompt( "names": "Unnamed Device", "state": "unavailable", }, - "script.test_script": { - "description": "This is a test script", - "names": "test_script", - "state": "off", - }, } exposed_entities_prompt = ( "An overview of the areas and the devices in this smart home:\n" @@ -634,3 +631,323 @@ async def test_assist_api_prompt( {area_prompt} {exposed_entities_prompt}""" ) + + +async def test_script_tool( + hass: HomeAssistant, + area_registry: ar.AreaRegistry, + floor_registry: fr.FloorRegistry, +) -> None: + """Test ScriptTool for the assist API.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "intent", {}) + context = Context() + llm_context = llm.LLMContext( + platform="test_platform", + context=context, + user_prompt="test_text", + language="*", + assistant="conversation", + device_id=None, + ) + + # Create a script with a unique ID + assert await async_setup_component( + hass, + "script", + { + "script": { + "test_script": { + "description": "This is a test script", + "sequence": [], + "fields": { + "beer": {"description": "Number of beers", "required": True}, + "wine": {"selector": {"number": {"min": 0, "max": 3}}}, + "where": {"selector": {"area": {}}}, + "area_list": {"selector": {"area": {"multiple": True}}}, + "floor": {"selector": {"floor": {}}}, + "floor_list": {"selector": {"floor": {"multiple": True}}}, + "extra_field": {"selector": {"area": {}}}, + }, + }, + "unexposed_script": { + "sequence": [], + }, + } + }, + ) + async_expose_entity(hass, "conversation", "script.test_script", True) + + area = area_registry.async_create("Living room") + floor = floor_registry.async_create("2") + + assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data + + api = await llm.async_get_api(hass, "assist", llm_context) + + tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)] + assert len(tools) == 1 + + tool = tools[0] + assert tool.name == "test_script" + assert tool.description == "This is a test script" + schema = { + vol.Required("beer", description="Number of beers"): cv.string, + vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}), + vol.Optional("where"): selector.AreaSelector(), + vol.Optional("area_list"): selector.AreaSelector({"multiple": True}), + vol.Optional("floor"): selector.FloorSelector(), + vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}), + vol.Optional("extra_field"): selector.AreaSelector(), + } + assert tool.parameters.schema == schema + + assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == { + "test_script": ("This is a test script", vol.Schema(schema)) + } + + tool_input = llm.ToolInput( + tool_name="test_script", + tool_args={ + "beer": "3", + "wine": 0, + "where": "Living room", + "area_list": ["Living room"], + "floor": "2", + "floor_list": ["2"], + }, + ) + + with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call: + response = await api.async_call_tool(tool_input) + + mock_service_call.assert_awaited_once_with( + "script", + "turn_on", + { + "entity_id": "script.test_script", + "variables": { + "beer": "3", + "wine": 0, + "where": area.id, + "area_list": [area.id], + "floor": floor.floor_id, + "floor_list": [floor.floor_id], + }, + }, + context=context, + ) + assert response == {"success": True} + + # Test reload script with new parameters + config = { + "script": { + "test_script": ScriptConfig( + { + "description": "This is a new test script", + "sequence": [], + "mode": "single", + "max": 2, + "max_exceeded": "WARNING", + "trace": {}, + "fields": { + "beer": {"description": "Number of beers", "required": True}, + }, + } + ) + } + } + + with patch( + "homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload", + return_value=config, + ): + await hass.services.async_call("script", "reload", blocking=True) + + assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {} + + api = await llm.async_get_api(hass, "assist", llm_context) + + tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)] + assert len(tools) == 1 + + tool = tools[0] + assert tool.name == "test_script" + assert tool.description == "This is a new test script" + schema = {vol.Required("beer", description="Number of beers"): cv.string} + assert tool.parameters.schema == schema + + assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == { + "test_script": ("This is a new test script", vol.Schema(schema)) + } + + +async def test_selector_serializer( + hass: HomeAssistant, llm_context: llm.LLMContext +) -> None: + """Test serialization of Selectors in Open API format.""" + api = await llm.async_get_api(hass, "assist", llm_context) + selector_serializer = api.custom_serializer + + assert selector_serializer(selector.ActionSelector()) == {"type": "string"} + assert selector_serializer(selector.AddonSelector()) == {"type": "string"} + assert selector_serializer(selector.AreaSelector()) == {"type": "string"} + assert selector_serializer(selector.AreaSelector({"multiple": True})) == { + "type": "array", + "items": {"type": "string"}, + } + assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"} + assert selector_serializer( + selector.AttributeSelector({"entity_id": "sensor.test"}) + ) == {"type": "string"} + assert selector_serializer(selector.BackupLocationSelector()) == { + "type": "string", + "pattern": "^(?:\\/backup|\\w+)$", + } + assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"} + assert selector_serializer(selector.ColorRGBSelector()) == { + "type": "array", + "items": {"type": "number"}, + "maxItems": 3, + "minItems": 3, + "format": "RGB", + } + assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"} + assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == { + "type": "number", + "minimum": 0, + "maximum": 1000, + } + assert selector_serializer( + selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000}) + ) == {"type": "number", "minimum": 100, "maximum": 1000} + assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"} + assert selector_serializer(selector.ConstantSelector({"value": "test"})) == { + "enum": ["test"] + } + assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]} + assert selector_serializer(selector.ConstantSelector({"value": True})) == { + "enum": [True] + } + assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == { + "type": "string" + } + assert selector_serializer(selector.ConversationAgentSelector()) == { + "type": "string" + } + assert selector_serializer(selector.CountrySelector()) == { + "type": "string", + "format": "ISO 3166-1 alpha-2", + } + assert selector_serializer( + selector.CountrySelector({"countries": ["GB", "FR"]}) + ) == {"type": "string", "enum": ["GB", "FR"]} + assert selector_serializer(selector.DateSelector()) == { + "type": "string", + "format": "date", + } + assert selector_serializer(selector.DateTimeSelector()) == { + "type": "string", + "format": "date-time", + } + assert selector_serializer(selector.DeviceSelector()) == {"type": "string"} + assert selector_serializer(selector.DeviceSelector({"multiple": True})) == { + "type": "array", + "items": {"type": "string"}, + } + assert selector_serializer(selector.EntitySelector()) == { + "type": "string", + "format": "entity_id", + } + assert selector_serializer(selector.EntitySelector({"multiple": True})) == { + "type": "array", + "items": {"type": "string", "format": "entity_id"}, + } + assert selector_serializer(selector.FloorSelector()) == {"type": "string"} + assert selector_serializer(selector.FloorSelector({"multiple": True})) == { + "type": "array", + "items": {"type": "string"}, + } + assert selector_serializer(selector.IconSelector()) == {"type": "string"} + assert selector_serializer(selector.LabelSelector()) == {"type": "string"} + assert selector_serializer(selector.LabelSelector({"multiple": True})) == { + "type": "array", + "items": {"type": "string"}, + } + assert selector_serializer(selector.LanguageSelector()) == { + "type": "string", + "format": "RFC 5646", + } + assert selector_serializer( + selector.LanguageSelector({"languages": ["en", "fr"]}) + ) == {"type": "string", "enum": ["en", "fr"]} + assert selector_serializer(selector.LocationSelector()) == { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + "radius": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + } + assert selector_serializer(selector.MediaSelector()) == { + "type": "object", + "properties": { + "entity_id": {"type": "string"}, + "media_content_id": {"type": "string"}, + "media_content_type": {"type": "string"}, + "metadata": {"type": "object", "additionalProperties": True}, + }, + "required": ["entity_id", "media_content_id", "media_content_type"], + } + assert selector_serializer(selector.NumberSelector({"mode": "box"})) == { + "type": "number" + } + assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == { + "type": "number", + "minimum": 30, + "maximum": 100, + } + assert selector_serializer(selector.ObjectSelector()) == {"type": "object"} + assert selector_serializer( + selector.SelectSelector( + { + "options": [ + {"value": "A", "label": "Letter A"}, + {"value": "B", "label": "Letter B"}, + {"value": "C", "label": "Letter C"}, + ] + } + ) + ) == {"type": "string", "enum": ["A", "B", "C"]} + assert selector_serializer( + selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True}) + ) == { + "type": "array", + "items": {"type": "string", "enum": ["A", "B", "C"]}, + "uniqueItems": True, + } + assert selector_serializer( + selector.StateSelector({"entity_id": "sensor.test"}) + ) == {"type": "string"} + assert selector_serializer(selector.TemplateSelector()) == { + "type": "string", + "format": "jinja2", + } + assert selector_serializer(selector.TextSelector()) == {"type": "string"} + assert selector_serializer(selector.TextSelector({"multiple": True})) == { + "type": "array", + "items": {"type": "string"}, + } + assert selector_serializer(selector.ThemeSelector()) == {"type": "string"} + assert selector_serializer(selector.TimeSelector()) == { + "type": "string", + "format": "time", + } + assert selector_serializer(selector.TriggerSelector()) == { + "type": "array", + "items": {"type": "string"}, + } + assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == { + "type": "string" + } diff --git a/tests/helpers/test_selector.py b/tests/helpers/test_selector.py index 6db313baa246..e93ec3b8c223 100644 --- a/tests/helpers/test_selector.py +++ b/tests/helpers/test_selector.py @@ -55,6 +55,8 @@ def _test_selector( config = {selector_type: schema} selector.validate_selector(config) selector_instance = selector.selector(config) + assert selector_instance == selector.selector(config) + assert selector_instance != 5 # We do not allow enums in the config, as they cannot serialize assert not any(isinstance(val, Enum) for val in selector_instance.config.values())