1
0
mirror of https://github.com/home-assistant/core synced 2024-06-29 06:15:03 +00:00

Add script llm tool (#118936)

* Add script llm tool

* Add tests

* More tests

* more test

* more test

* Add area and floor resolving

* coverage

* coverage

* fix ColorTempSelector

* fix mypy

* fix mypy

* add script reload test

* Cache script tool parameters

* Make custom_serializer a part of api

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Denis Shulyaka 2024-06-25 18:43:26 +03:00 committed by GitHub
parent 77fea8a73e
commit 2386ed3830
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 639 additions and 55 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import codecs
from collections.abc import Callable
from typing import Any, Literal
from google.api_core.exceptions import GoogleAPICallError
@ -89,10 +90,14 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
return result
def _format_tool(tool: llm.Tool) -> dict[str, Any]:
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification."""
parameters = _format_schema(convert(tool.parameters))
parameters = _format_schema(
convert(tool.parameters, custom_serializer=custom_serializer)
)
return protos.Tool(
{
@ -193,7 +198,9 @@ class GoogleGenerativeAIConversationEntity(
f"Error preparing LLM API: {err}",
)
return result
tools = [_format_tool(tool) for tool in llm_api.tools]
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
try:
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)

View File

@ -9,5 +9,5 @@
"integration_type": "service",
"iot_class": "cloud_polling",
"quality_scale": "platinum",
"requirements": ["google-generativeai==0.6.0", "voluptuous-openapi==0.0.4"]
"requirements": ["google-generativeai==0.6.0"]
}

View File

@ -1,7 +1,8 @@
"""Conversation support for OpenAI."""
from collections.abc import Callable
import json
from typing import Literal
from typing import Any, Literal
import openai
from openai._types import NOT_GIVEN
@ -58,9 +59,14 @@ async def async_setup_entry(
async_add_entities([agent])
def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ChatCompletionToolParam:
"""Format tool specification."""
tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters))
tool_spec = FunctionDefinition(
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
)
if tool.description:
tool_spec["description"] = tool.description
return ChatCompletionToolParam(type="function", function=tool_spec)
@ -139,7 +145,9 @@ class OpenAIConversationEntity(
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.tools]
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]
if user_input.conversation_id is None:
conversation_id = ulid.ulid_now()

View File

@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"]
"requirements": ["openai==1.3.8"]
}

View File

@ -352,7 +352,7 @@ class MatchTargetsCandidate:
matched_name: str | None = None
def _find_areas(
def find_areas(
name: str, areas: area_registry.AreaRegistry
) -> Iterable[area_registry.AreaEntry]:
"""Find all areas matching a name (including aliases)."""
@ -372,7 +372,7 @@ def _find_areas(
break
def _find_floors(
def find_floors(
name: str, floors: floor_registry.FloorRegistry
) -> Iterable[floor_registry.FloorEntry]:
"""Find all floors matching a name (including aliases)."""
@ -530,7 +530,7 @@ def async_match_targets( # noqa: C901
if not states:
return MatchTargetsResult(False, MatchFailedReason.STATE)
# Exit early so we can to avoid registry lookups
# Exit early so we can avoid registry lookups
if not (
constraints.name
or constraints.features
@ -580,7 +580,7 @@ def async_match_targets( # noqa: C901
if constraints.floor_name:
# Filter by areas associated with floor
fr = floor_registry.async_get(hass)
targeted_floors = list(_find_floors(constraints.floor_name, fr))
targeted_floors = list(find_floors(constraints.floor_name, fr))
if not targeted_floors:
return MatchTargetsResult(
False,
@ -609,7 +609,7 @@ def async_match_targets( # noqa: C901
possible_area_ids = {area.id for area in ar.async_list_areas()}
if constraints.area_name:
targeted_areas = list(_find_areas(constraints.area_name, ar))
targeted_areas = list(find_areas(constraints.area_name, ar))
if not targeted_areas:
return MatchTargetsResult(
False,

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from decimal import Decimal
from enum import Enum
@ -11,6 +12,7 @@ from typing import Any
import slugify as unicode_slug
import voluptuous as vol
from voluptuous_openapi import UNSUPPORTED, convert
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import (
@ -20,22 +22,39 @@ from homeassistant.components.conversation.trace import (
from homeassistant.components.cover.intent import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_ENTITY_ID,
ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
SERVICE_TURN_ON,
)
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from . import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
floor_registry as fr,
intent,
selector,
service,
)
from .singleton import singleton
SCRIPT_PARAMETERS_CACHE: HassKey[dict[str, tuple[str | None, vol.Schema]]] = HassKey(
"llm_script_parameters_cache"
)
LLM_API_ASSIST = "assist"
BASE_PROMPT = (
@ -143,6 +162,7 @@ class APIInstance:
api_prompt: str
llm_context: LLMContext
tools: list[Tool]
custom_serializer: Callable[[Any], Any] | None = None
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
@ -284,6 +304,7 @@ class AssistAPI(API):
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
llm_context=llm_context,
tools=self._async_get_tools(llm_context, exposed_entities),
custom_serializer=_selector_serializer,
)
@callback
@ -372,7 +393,7 @@ class AssistAPI(API):
exposed_domains: set[str] | None = None
if exposed_entities is not None:
exposed_domains = {
entity_id.split(".")[0] for entity_id in exposed_entities
split_entity_id(entity_id)[0] for entity_id in exposed_entities
}
intent_handlers = [
intent_handler
@ -381,11 +402,22 @@ class AssistAPI(API):
or intent_handler.platforms & exposed_domains
]
return [
tools: list[Tool] = [
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers
]
if llm_context.assistant is not None:
for state in self.hass.states.async_all(SCRIPT_DOMAIN):
if not async_should_expose(
self.hass, llm_context.assistant, state.entity_id
):
continue
tools.append(ScriptTool(self.hass, state.entity_id))
return tools
def _get_exposed_entities(
hass: HomeAssistant, assistant: str
@ -413,13 +445,15 @@ def _get_exposed_entities(
entities = {}
for state in hass.states.async_all():
if state.domain == SCRIPT_DOMAIN:
continue
if not async_should_expose(hass, assistant, state.entity_id):
continue
entity_entry = entity_registry.async_get(state.entity_id)
names = [state.name]
area_names = []
description: str | None = None
if entity_entry is not None:
names.extend(entity_entry.aliases)
@ -439,25 +473,11 @@ def _get_exposed_entities(
area_names.append(area.name)
area_names.extend(area.aliases)
if (
state.domain == "script"
and entity_entry.unique_id
and (
service_desc := service.async_get_cached_service_description(
hass, "script", entity_entry.unique_id
)
)
):
description = service_desc.get("description")
info: dict[str, Any] = {
"names": ", ".join(names),
"state": state.state,
}
if description:
info["description"] = description
if area_names:
info["areas"] = ", ".join(area_names)
@ -473,3 +493,231 @@ def _get_exposed_entities(
entities[state.entity_id] = info
return entities
def _selector_serializer(schema: Any) -> Any: # noqa: C901
"""Convert selectors into OpenAPI schema."""
if not isinstance(schema, selector.Selector):
return UNSUPPORTED
if isinstance(schema, selector.BackupLocationSelector):
return {"type": "string", "pattern": "^(?:\\/backup|\\w+)$"}
if isinstance(schema, selector.BooleanSelector):
return {"type": "boolean"}
if isinstance(schema, selector.ColorRGBSelector):
return {
"type": "array",
"items": {"type": "number"},
"minItems": 3,
"maxItems": 3,
"format": "RGB",
}
if isinstance(schema, selector.ConditionSelector):
return convert(cv.CONDITIONS_SCHEMA)
if isinstance(schema, selector.ConstantSelector):
return {"enum": [schema.config["value"]]}
result: dict[str, Any]
if isinstance(schema, selector.ColorTempSelector):
result = {"type": "number"}
if "min" in schema.config:
result["minimum"] = schema.config["min"]
elif "min_mireds" in schema.config:
result["minimum"] = schema.config["min_mireds"]
if "max" in schema.config:
result["maximum"] = schema.config["max"]
elif "max_mireds" in schema.config:
result["maximum"] = schema.config["max_mireds"]
return result
if isinstance(schema, selector.CountrySelector):
if schema.config.get("countries"):
return {"type": "string", "enum": schema.config["countries"]}
return {"type": "string", "format": "ISO 3166-1 alpha-2"}
if isinstance(schema, selector.DateSelector):
return {"type": "string", "format": "date"}
if isinstance(schema, selector.DateTimeSelector):
return {"type": "string", "format": "date-time"}
if isinstance(schema, selector.DurationSelector):
return convert(cv.time_period_dict)
if isinstance(schema, selector.EntitySelector):
if schema.config.get("multiple"):
return {"type": "array", "items": {"type": "string", "format": "entity_id"}}
return {"type": "string", "format": "entity_id"}
if isinstance(schema, selector.LanguageSelector):
if schema.config.get("languages"):
return {"type": "string", "enum": schema.config["languages"]}
return {"type": "string", "format": "RFC 5646"}
if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)):
return convert(schema.DATA_SCHEMA)
if isinstance(schema, selector.NumberSelector):
result = {"type": "number"}
if "min" in schema.config:
result["minimum"] = schema.config["min"]
if "max" in schema.config:
result["maximum"] = schema.config["max"]
return result
if isinstance(schema, selector.ObjectSelector):
return {"type": "object"}
if isinstance(schema, selector.SelectSelector):
options = [
x["value"] if isinstance(x, dict) else x for x in schema.config["options"]
]
if schema.config.get("multiple"):
return {
"type": "array",
"items": {"type": "string", "enum": options},
"uniqueItems": True,
}
return {"type": "string", "enum": options}
if isinstance(schema, selector.TargetSelector):
return convert(cv.TARGET_SERVICE_FIELDS)
if isinstance(schema, selector.TemplateSelector):
return {"type": "string", "format": "jinja2"}
if isinstance(schema, selector.TimeSelector):
return {"type": "string", "format": "time"}
if isinstance(schema, selector.TriggerSelector):
return convert(cv.TRIGGER_SCHEMA)
if schema.config.get("multiple"):
return {"type": "array", "items": {"type": "string"}}
return {"type": "string"}
class ScriptTool(Tool):
"""LLM Tool representing a Script."""
def __init__(
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
entity_registry = er.async_get(hass)
self.name = split_entity_id(script_entity_id)[1]
self.parameters = vol.Schema({})
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
parameters_cache = hass.data.get(SCRIPT_PARAMETERS_CACHE)
if parameters_cache is None:
parameters_cache = hass.data[SCRIPT_PARAMETERS_CACHE] = {}
@callback
def clear_cache(event: Event) -> None:
"""Clear script parameter cache on script reload or delete."""
if (
event.data[ATTR_DOMAIN] == SCRIPT_DOMAIN
and event.data[ATTR_SERVICE] in parameters_cache
):
parameters_cache.pop(event.data[ATTR_SERVICE])
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close
)
if entity_entry.unique_id in parameters_cache:
self.description, self.parameters = parameters_cache[
entity_entry.unique_id
]
return
if service_desc := service.async_get_cached_service_description(
hass, SCRIPT_DOMAIN, entity_entry.unique_id
):
self.description = service_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = service_desc.get("fields", {})
for field, config in fields.items():
description = config.get("description")
if not description:
description = config.get("name")
if config.get("required"):
key = vol.Required(field, description=description)
else:
key = vol.Optional(field, description=description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string
self.parameters = vol.Schema(schema)
parameters_cache[entity_entry.unique_id] = (
self.description,
self.parameters,
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Run the script."""
for field, validator in self.parameters.schema.items():
if field not in tool_input.tool_args:
continue
if isinstance(validator, selector.AreaSelector):
area_reg = ar.async_get(hass)
if validator.config.get("multiple"):
areas: list[ar.AreaEntry] = []
for area in tool_input.tool_args[field]:
areas.extend(intent.find_areas(area, area_reg))
tool_input.tool_args[field] = list({area.id for area in areas})
else:
area = tool_input.tool_args[field]
area = list(intent.find_areas(area, area_reg))[0].id
tool_input.tool_args[field] = area
elif isinstance(validator, selector.FloorSelector):
floor_reg = fr.async_get(hass)
if validator.config.get("multiple"):
floors: list[fr.FloorEntry] = []
for floor in tool_input.tool_args[field]:
floors.extend(intent.find_floors(floor, floor_reg))
tool_input.tool_args[field] = list(
{floor.floor_id for floor in floors}
)
else:
floor = tool_input.tool_args[field]
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
tool_input.tool_args[field] = floor
await hass.services.async_call(
SCRIPT_DOMAIN,
SERVICE_TURN_ON,
{
ATTR_ENTITY_ID: SCRIPT_DOMAIN + "." + self.name,
ATTR_VARIABLES: tool_input.tool_args,
},
context=llm_context.context,
)
return {"success": True}

View File

@ -75,6 +75,13 @@ class Selector[_T: Mapping[str, Any]]:
self.config = self.CONFIG_SCHEMA(config)
def __eq__(self, other: object) -> bool:
"""Check equality."""
if not isinstance(other, Selector):
return NotImplemented
return self.selector_type == other.selector_type and self.config == other.config
def serialize(self) -> dict[str, dict[str, _T]]:
"""Serialize Selector for voluptuous_serialize."""
return {"selector": {self.selector_type: self.config}}
@ -278,7 +285,7 @@ class AssistPipelineSelector(Selector[AssistPipelineSelectorConfig]):
CONFIG_SCHEMA = vol.Schema({})
def __init__(self, config: AssistPipelineSelectorConfig) -> None:
def __init__(self, config: AssistPipelineSelectorConfig | None = None) -> None:
"""Instantiate a selector."""
super().__init__(config)
@ -430,10 +437,10 @@ class ColorTempSelector(Selector[ColorTempSelectorConfig]):
range_min = self.config.get("min")
range_max = self.config.get("max")
if not range_min:
if range_min is None:
range_min = self.config.get("min_mireds")
if not range_max:
if range_max is None:
range_max = self.config.get("max_mireds")
value: int = vol.All(
@ -517,7 +524,7 @@ class ConstantSelector(Selector[ConstantSelectorConfig]):
}
)
def __init__(self, config: ConstantSelectorConfig | None = None) -> None:
def __init__(self, config: ConstantSelectorConfig) -> None:
"""Instantiate a selector."""
super().__init__(config)
@ -560,7 +567,7 @@ class QrCodeSelector(Selector[QrCodeSelectorConfig]):
}
)
def __init__(self, config: QrCodeSelectorConfig | None = None) -> None:
def __init__(self, config: QrCodeSelectorConfig) -> None:
"""Instantiate a selector."""
super().__init__(config)
@ -588,7 +595,7 @@ class ConversationAgentSelector(Selector[ConversationAgentSelectorConfig]):
}
)
def __init__(self, config: ConversationAgentSelectorConfig) -> None:
def __init__(self, config: ConversationAgentSelectorConfig | None = None) -> None:
"""Instantiate a selector."""
super().__init__(config)
@ -820,7 +827,7 @@ class FloorSelectorConfig(TypedDict, total=False):
@SELECTORS.register("floor")
class FloorSelector(Selector[AreaSelectorConfig]):
class FloorSelector(Selector[FloorSelectorConfig]):
"""Selector of a single or list of floors."""
selector_type = "floor"
@ -934,7 +941,7 @@ class LanguageSelector(Selector[LanguageSelectorConfig]):
}
)
def __init__(self, config: LanguageSelectorConfig) -> None:
def __init__(self, config: LanguageSelectorConfig | None = None) -> None:
"""Instantiate a selector."""
super().__init__(config)
@ -1159,7 +1166,7 @@ class SelectSelector(Selector[SelectSelectorConfig]):
}
)
def __init__(self, config: SelectSelectorConfig | None = None) -> None:
def __init__(self, config: SelectSelectorConfig) -> None:
"""Instantiate a selector."""
super().__init__(config)
@ -1434,7 +1441,7 @@ class FileSelector(Selector[FileSelectorConfig]):
}
)
def __init__(self, config: FileSelectorConfig | None = None) -> None:
def __init__(self, config: FileSelectorConfig) -> None:
"""Instantiate a selector."""
super().__init__(config)

View File

@ -58,6 +58,7 @@ SQLAlchemy==2.0.31
typing-extensions>=4.12.2,<5.0
ulid-transform==0.9.0
urllib3>=1.26.5,<2
voluptuous-openapi==0.0.4
voluptuous-serialize==2.6.0
voluptuous==0.13.1
webrtc-noise-gain==1.2.3

View File

@ -69,6 +69,7 @@ dependencies = [
"urllib3>=1.26.5,<2",
"voluptuous==0.13.1",
"voluptuous-serialize==2.6.0",
"voluptuous-openapi==0.0.4",
"yarl==1.9.4",
]

View File

@ -41,4 +41,5 @@ ulid-transform==0.9.0
urllib3>=1.26.5,<2
voluptuous==0.13.1
voluptuous-serialize==2.6.0
voluptuous-openapi==0.0.4
yarl==1.9.4

View File

@ -2846,10 +2846,6 @@ voip-utils==0.1.0
# homeassistant.components.volkszaehler
volkszaehler==0.4.0
# homeassistant.components.google_generative_ai_conversation
# homeassistant.components.openai_conversation
voluptuous-openapi==0.0.4
# homeassistant.components.volvooncall
volvooncall==0.10.3

View File

@ -2217,10 +2217,6 @@ vilfo-api-client==0.5.0
# homeassistant.components.voip
voip-utils==0.1.0
# homeassistant.components.google_generative_ai_conversation
# homeassistant.components.openai_conversation
voluptuous-openapi==0.0.4
# homeassistant.components.volvooncall
volvooncall==0.10.3

View File

@ -8,6 +8,7 @@ import voluptuous as vol
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.script.config import ScriptConfig
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
@ -18,6 +19,7 @@ from homeassistant.helpers import (
floor_registry as fr,
intent,
llm,
selector,
)
from homeassistant.setup import async_setup_component
from homeassistant.util import yaml
@ -564,11 +566,6 @@ async def test_assist_api_prompt(
"names": "Unnamed Device",
"state": "unavailable",
},
"script.test_script": {
"description": "This is a test script",
"names": "test_script",
"state": "off",
},
}
exposed_entities_prompt = (
"An overview of the areas and the devices in this smart home:\n"
@ -634,3 +631,323 @@ async def test_assist_api_prompt(
{area_prompt}
{exposed_entities_prompt}"""
)
async def test_script_tool(
hass: HomeAssistant,
area_registry: ar.AreaRegistry,
floor_registry: fr.FloorRegistry,
) -> None:
"""Test ScriptTool for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
context=context,
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
# Create a script with a unique ID
assert await async_setup_component(
hass,
"script",
{
"script": {
"test_script": {
"description": "This is a test script",
"sequence": [],
"fields": {
"beer": {"description": "Number of beers", "required": True},
"wine": {"selector": {"number": {"min": 0, "max": 3}}},
"where": {"selector": {"area": {}}},
"area_list": {"selector": {"area": {"multiple": True}}},
"floor": {"selector": {"floor": {}}},
"floor_list": {"selector": {"floor": {"multiple": True}}},
"extra_field": {"selector": {"area": {}}},
},
},
"unexposed_script": {
"sequence": [],
},
}
},
)
async_expose_entity(hass, "conversation", "script.test_script", True)
area = area_registry.async_create("Living room")
floor = floor_registry.async_create("2")
assert llm.SCRIPT_PARAMETERS_CACHE not in hass.data
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_script"
assert tool.description == "This is a test script"
schema = {
vol.Required("beer", description="Number of beers"): cv.string,
vol.Optional("wine"): selector.NumberSelector({"min": 0, "max": 3}),
vol.Optional("where"): selector.AreaSelector(),
vol.Optional("area_list"): selector.AreaSelector({"multiple": True}),
vol.Optional("floor"): selector.FloorSelector(),
vol.Optional("floor_list"): selector.FloorSelector({"multiple": True}),
vol.Optional("extra_field"): selector.AreaSelector(),
}
assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
"test_script": ("This is a test script", vol.Schema(schema))
}
tool_input = llm.ToolInput(
tool_name="test_script",
tool_args={
"beer": "3",
"wine": 0,
"where": "Living room",
"area_list": ["Living room"],
"floor": "2",
"floor_list": ["2"],
},
)
with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
response = await api.async_call_tool(tool_input)
mock_service_call.assert_awaited_once_with(
"script",
"turn_on",
{
"entity_id": "script.test_script",
"variables": {
"beer": "3",
"wine": 0,
"where": area.id,
"area_list": [area.id],
"floor": floor.floor_id,
"floor_list": [floor.floor_id],
},
},
context=context,
)
assert response == {"success": True}
# Test reload script with new parameters
config = {
"script": {
"test_script": ScriptConfig(
{
"description": "This is a new test script",
"sequence": [],
"mode": "single",
"max": 2,
"max_exceeded": "WARNING",
"trace": {},
"fields": {
"beer": {"description": "Number of beers", "required": True},
},
}
)
}
}
with patch(
"homeassistant.helpers.entity_component.EntityComponent.async_prepare_reload",
return_value=config,
):
await hass.services.async_call("script", "reload", blocking=True)
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {}
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
assert tool.name == "test_script"
assert tool.description == "This is a new test script"
schema = {vol.Required("beer", description="Number of beers"): cv.string}
assert tool.parameters.schema == schema
assert hass.data[llm.SCRIPT_PARAMETERS_CACHE] == {
"test_script": ("This is a new test script", vol.Schema(schema))
}
async def test_selector_serializer(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test serialization of Selectors in Open API format."""
api = await llm.async_get_api(hass, "assist", llm_context)
selector_serializer = api.custom_serializer
assert selector_serializer(selector.ActionSelector()) == {"type": "string"}
assert selector_serializer(selector.AddonSelector()) == {"type": "string"}
assert selector_serializer(selector.AreaSelector()) == {"type": "string"}
assert selector_serializer(selector.AreaSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.AssistPipelineSelector()) == {"type": "string"}
assert selector_serializer(
selector.AttributeSelector({"entity_id": "sensor.test"})
) == {"type": "string"}
assert selector_serializer(selector.BackupLocationSelector()) == {
"type": "string",
"pattern": "^(?:\\/backup|\\w+)$",
}
assert selector_serializer(selector.BooleanSelector()) == {"type": "boolean"}
assert selector_serializer(selector.ColorRGBSelector()) == {
"type": "array",
"items": {"type": "number"},
"maxItems": 3,
"minItems": 3,
"format": "RGB",
}
assert selector_serializer(selector.ColorTempSelector()) == {"type": "number"}
assert selector_serializer(selector.ColorTempSelector({"min": 0, "max": 1000})) == {
"type": "number",
"minimum": 0,
"maximum": 1000,
}
assert selector_serializer(
selector.ColorTempSelector({"min_mireds": 100, "max_mireds": 1000})
) == {"type": "number", "minimum": 100, "maximum": 1000}
assert selector_serializer(selector.ConfigEntrySelector()) == {"type": "string"}
assert selector_serializer(selector.ConstantSelector({"value": "test"})) == {
"enum": ["test"]
}
assert selector_serializer(selector.ConstantSelector({"value": 1})) == {"enum": [1]}
assert selector_serializer(selector.ConstantSelector({"value": True})) == {
"enum": [True]
}
assert selector_serializer(selector.QrCodeSelector({"data": "test"})) == {
"type": "string"
}
assert selector_serializer(selector.ConversationAgentSelector()) == {
"type": "string"
}
assert selector_serializer(selector.CountrySelector()) == {
"type": "string",
"format": "ISO 3166-1 alpha-2",
}
assert selector_serializer(
selector.CountrySelector({"countries": ["GB", "FR"]})
) == {"type": "string", "enum": ["GB", "FR"]}
assert selector_serializer(selector.DateSelector()) == {
"type": "string",
"format": "date",
}
assert selector_serializer(selector.DateTimeSelector()) == {
"type": "string",
"format": "date-time",
}
assert selector_serializer(selector.DeviceSelector()) == {"type": "string"}
assert selector_serializer(selector.DeviceSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.EntitySelector()) == {
"type": "string",
"format": "entity_id",
}
assert selector_serializer(selector.EntitySelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string", "format": "entity_id"},
}
assert selector_serializer(selector.FloorSelector()) == {"type": "string"}
assert selector_serializer(selector.FloorSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.IconSelector()) == {"type": "string"}
assert selector_serializer(selector.LabelSelector()) == {"type": "string"}
assert selector_serializer(selector.LabelSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.LanguageSelector()) == {
"type": "string",
"format": "RFC 5646",
}
assert selector_serializer(
selector.LanguageSelector({"languages": ["en", "fr"]})
) == {"type": "string", "enum": ["en", "fr"]}
assert selector_serializer(selector.LocationSelector()) == {
"type": "object",
"properties": {
"latitude": {"type": "number"},
"longitude": {"type": "number"},
"radius": {"type": "number"},
},
"required": ["latitude", "longitude"],
}
assert selector_serializer(selector.MediaSelector()) == {
"type": "object",
"properties": {
"entity_id": {"type": "string"},
"media_content_id": {"type": "string"},
"media_content_type": {"type": "string"},
"metadata": {"type": "object", "additionalProperties": True},
},
"required": ["entity_id", "media_content_id", "media_content_type"],
}
assert selector_serializer(selector.NumberSelector({"mode": "box"})) == {
"type": "number"
}
assert selector_serializer(selector.NumberSelector({"min": 30, "max": 100})) == {
"type": "number",
"minimum": 30,
"maximum": 100,
}
assert selector_serializer(selector.ObjectSelector()) == {"type": "object"}
assert selector_serializer(
selector.SelectSelector(
{
"options": [
{"value": "A", "label": "Letter A"},
{"value": "B", "label": "Letter B"},
{"value": "C", "label": "Letter C"},
]
}
)
) == {"type": "string", "enum": ["A", "B", "C"]}
assert selector_serializer(
selector.SelectSelector({"options": ["A", "B", "C"], "multiple": True})
) == {
"type": "array",
"items": {"type": "string", "enum": ["A", "B", "C"]},
"uniqueItems": True,
}
assert selector_serializer(
selector.StateSelector({"entity_id": "sensor.test"})
) == {"type": "string"}
assert selector_serializer(selector.TemplateSelector()) == {
"type": "string",
"format": "jinja2",
}
assert selector_serializer(selector.TextSelector()) == {"type": "string"}
assert selector_serializer(selector.TextSelector({"multiple": True})) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.ThemeSelector()) == {"type": "string"}
assert selector_serializer(selector.TimeSelector()) == {
"type": "string",
"format": "time",
}
assert selector_serializer(selector.TriggerSelector()) == {
"type": "array",
"items": {"type": "string"},
}
assert selector_serializer(selector.FileSelector({"accept": ".txt"})) == {
"type": "string"
}

View File

@ -55,6 +55,8 @@ def _test_selector(
config = {selector_type: schema}
selector.validate_selector(config)
selector_instance = selector.selector(config)
assert selector_instance == selector.selector(config)
assert selector_instance != 5
# We do not allow enums in the config, as they cannot serialize
assert not any(isinstance(val, Enum) for val in selector_instance.config.values())