Allow stopping a script with a response value (#95284)

This commit is contained in:
Paulus Schoutsen 2023-06-27 02:24:22 -04:00 committed by GitHub
parent 51aa2ba835
commit 5f14cdf69d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 140 additions and 28 deletions

View file

@ -28,7 +28,14 @@ from homeassistant.const import (
SERVICE_TURN_ON,
STATE_ON,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.core import (
Context,
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.helpers import entity_registry as er
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.config_validation import make_entity_service_schema
@ -436,6 +443,12 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
variables = kwargs.get("variables")
context = kwargs.get("context")
wait = kwargs.get("wait", True)
await self._async_start_run(variables, context, wait)
async def _async_start_run(
self, variables: dict, context: Context, wait: bool
) -> ServiceResponse:
"""Start the run of a script."""
self.async_set_context(context)
self.hass.bus.async_fire(
EVENT_SCRIPT_STARTED,
@ -444,8 +457,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
)
coro = self._async_run(variables, context)
if wait:
await coro
return
return await coro
# Caller does not want to wait for called script to finish so let script run in
# separate Task. Make a new empty script stack; scripts are allowed to
@ -457,6 +469,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
# Wait for first state change so we can guarantee that
# it is written to the State Machine before we return.
await self._changed.wait()
return None
async def _async_run(self, variables, context):
with trace_script(
@ -483,16 +496,25 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
"""
await self.script.async_stop()
async def _service_handler(self, service: ServiceCall) -> None:
async def _service_handler(self, service: ServiceCall) -> ServiceResponse:
"""Execute a service call to script.<script name>."""
await self.async_turn_on(variables=service.data, context=service.context)
response = await self._async_start_run(
variables=service.data, context=service.context, wait=True
)
if service.return_response:
return response
return None
async def async_added_to_hass(self) -> None:
"""Restore last triggered on startup and register service."""
unique_id = cast(str, self.unique_id)
self.hass.services.async_register(
DOMAIN, unique_id, self._service_handler, schema=SCRIPT_SERVICE_SCHEMA
DOMAIN,
unique_id,
self._service_handler,
schema=SCRIPT_SERVICE_SCHEMA,
supports_response=SupportsResponse.OPTIONAL,
)
# Register the service description

View file

@ -675,8 +675,14 @@ async def handle_execute_script(
context = connection.context(msg)
script_obj = Script(hass, msg["sequence"], f"{const.DOMAIN} script", const.DOMAIN)
await script_obj.async_run(msg.get("variables"), context=context)
connection.send_result(msg["id"], {"context": context})
response = await script_obj.async_run(msg.get("variables"), context=context)
connection.send_result(
msg["id"],
{
"context": context,
"response": response,
},
)
@callback

View file

@ -221,9 +221,10 @@ CONF_RECIPIENT: Final = "recipient"
CONF_REGION: Final = "region"
CONF_REPEAT: Final = "repeat"
CONF_RESOURCE: Final = "resource"
CONF_RESOURCES: Final = "resources"
CONF_RESOURCE_TEMPLATE: Final = "resource_template"
CONF_RESOURCES: Final = "resources"
CONF_RESPONSE_VARIABLE: Final = "response_variable"
CONF_RESPONSE: Final = "response"
CONF_RGB: Final = "rgb"
CONF_ROOM: Final = "room"
CONF_SCAN_INTERVAL: Final = "scan_interval"

View file

@ -59,6 +59,7 @@ from homeassistant.const import (
CONF_PARALLEL,
CONF_PLATFORM,
CONF_REPEAT,
CONF_RESPONSE,
CONF_RESPONSE_VARIABLE,
CONF_SCAN_INTERVAL,
CONF_SCENE,
@ -1689,7 +1690,11 @@ _SCRIPT_STOP_SCHEMA = vol.Schema(
{
**SCRIPT_ACTION_BASE_SCHEMA,
vol.Required(CONF_STOP): vol.Any(None, string),
vol.Optional(CONF_ERROR, default=False): boolean,
vol.Exclusive(CONF_ERROR, "error_or_response"): boolean,
vol.Exclusive(CONF_RESPONSE, "error_or_response"): vol.Any(
vol.All(dict, template_complex),
vol.All(str, template),
),
}
)

View file

@ -46,6 +46,7 @@ from homeassistant.const import (
CONF_MODE,
CONF_PARALLEL,
CONF_REPEAT,
CONF_RESPONSE,
CONF_RESPONSE_VARIABLE,
CONF_SCENE,
CONF_SEQUENCE,
@ -69,6 +70,7 @@ from homeassistant.core import (
Event,
HassJob,
HomeAssistant,
ServiceResponse,
SupportsResponse,
callback,
)
@ -352,6 +354,11 @@ class _ConditionFail(_HaltScript):
class _StopScript(_HaltScript):
"""Throw if script needs to stop."""
def __init__(self, message: str, response: Any) -> None:
"""Initialize a halt exception."""
super().__init__(message)
self.response = response
class _ScriptRun:
"""Manage Script sequence run."""
@ -396,13 +403,14 @@ class _ScriptRun:
)
self._log("Executing step %s%s", self._script.last_action, _timeout)
async def async_run(self) -> None:
async def async_run(self) -> ServiceResponse:
"""Run script."""
# Push the script to the script execution stack
if (script_stack := script_stack_cv.get()) is None:
script_stack = []
script_stack_cv.set(script_stack)
script_stack.append(id(self._script))
response = None
try:
self._log("Running %s", self._script.running_description)
@ -420,11 +428,15 @@ class _ScriptRun:
raise
except _ConditionFail:
script_execution_set("aborted")
except _StopScript:
script_execution_set("finished")
except _StopScript as err:
script_execution_set("finished", err.response)
response = err.response
# Let the _StopScript bubble up if this is a sub-script
if not self._script.top_level:
raise
# We already consumed the response, do not pass it on
err.response = None
raise err
except Exception:
script_execution_set("error")
raise
@ -433,6 +445,8 @@ class _ScriptRun:
script_stack.pop()
self._finish()
return response
async def _async_step(self, log_exceptions):
continue_on_error = self._action.get(CONF_CONTINUE_ON_ERROR, False)
@ -1010,13 +1024,20 @@ class _ScriptRun:
async def _async_stop_step(self):
"""Stop script execution."""
stop = self._action[CONF_STOP]
error = self._action[CONF_ERROR]
error = self._action.get(CONF_ERROR, False)
trace_set_result(stop=stop, error=error)
if error:
self._log("Error script sequence: %s", stop)
raise _AbortScript(stop)
self._log("Stop script sequence: %s", stop)
raise _StopScript(stop)
if CONF_RESPONSE in self._action:
response = template.render_complex(
self._action[CONF_RESPONSE], self._variables
)
else:
response = None
raise _StopScript(stop, response)
@async_trace_path("parallel")
async def _async_parallel_step(self) -> None:
@ -1455,7 +1476,7 @@ class Script:
run_variables: _VarsType | None = None,
context: Context | None = None,
started_action: Callable[..., Any] | None = None,
) -> None:
) -> ServiceResponse:
"""Run script."""
if context is None:
self._log(
@ -1466,7 +1487,7 @@ class Script:
# Prevent spawning new script runs when Home Assistant is shutting down
if DATA_NEW_SCRIPT_RUNS_NOT_ALLOWED in self._hass.data:
self._log("Home Assistant is shutting down, starting script blocked")
return
return None
# Prevent spawning new script runs if not allowed by script mode
if self.is_running:
@ -1474,7 +1495,7 @@ class Script:
if self._max_exceeded != "SILENT":
self._log("Already running", level=LOGSEVERITY[self._max_exceeded])
script_execution_set("failed_single")
return
return None
if self.script_mode != SCRIPT_MODE_RESTART and self.runs == self.max_runs:
if self._max_exceeded != "SILENT":
self._log(
@ -1482,7 +1503,7 @@ class Script:
level=LOGSEVERITY[self._max_exceeded],
)
script_execution_set("failed_max_runs")
return
return None
# If this is a top level Script then make a copy of the variables in case they
# are read-only, but more importantly, so as not to leak any variables created
@ -1519,7 +1540,7 @@ class Script:
):
script_execution_set("disallowed_recursion_detected")
self._log("Disallowed recursion detected", level=logging.WARNING)
return
return None
if self.script_mode != SCRIPT_MODE_QUEUED:
cls = _ScriptRun
@ -1543,7 +1564,7 @@ class Script:
self._changed()
try:
await asyncio.shield(run.async_run())
return await asyncio.shield(run.async_run())
except asyncio.CancelledError:
await run.async_stop()
self._changed()

View file

@ -441,7 +441,7 @@ class TemplateEntity(Entity):
"""Run an action script."""
if run_variables is None:
run_variables = {}
return await script.async_run(
await script.async_run(
run_variables={
"this": TemplateStateFromEntityId(self.hass, self.entity_id),
**run_variables,

View file

@ -8,6 +8,7 @@ from contextvars import ContextVar
from functools import wraps
from typing import Any, cast
from homeassistant.core import ServiceResponse
import homeassistant.util.dt as dt_util
from .typing import TemplateVarsType
@ -207,13 +208,15 @@ class StopReason:
"""Mutable container class for script_execution."""
script_execution: str | None = None
response: ServiceResponse = None
def script_execution_set(reason: str) -> None:
def script_execution_set(reason: str, response: ServiceResponse = None) -> None:
"""Set stop reason."""
if (data := script_execution_cv.get()) is None:
return
data.script_execution = reason
data.response = response
def script_execution_get() -> str | None:

View file

@ -48,7 +48,9 @@ from homeassistant.core import (
Event,
HomeAssistant,
ServiceCall,
ServiceResponse,
State,
SupportsResponse,
callback,
)
from homeassistant.helpers import (
@ -285,7 +287,12 @@ async def async_test_home_assistant(event_loop, load_registries=True):
def async_mock_service(
hass: HomeAssistant, domain: str, service: str, schema: vol.Schema | None = None
hass: HomeAssistant,
domain: str,
service: str,
schema: vol.Schema | None = None,
response: ServiceResponse = None,
supports_response: SupportsResponse | None = None,
) -> list[ServiceCall]:
"""Set up a fake service & return a calls log list to this service."""
calls = []
@ -294,8 +301,18 @@ def async_mock_service(
def mock_service_log(call): # pylint: disable=unnecessary-lambda
"""Mock service call."""
calls.append(call)
return response
hass.services.async_register(domain, service, mock_service_log, schema=schema)
if supports_response is None and response is not None:
supports_response = SupportsResponse.OPTIONAL
hass.services.async_register(
domain,
service,
mock_service_log,
schema=schema,
supports_response=supports_response,
)
return calls

View file

@ -1,6 +1,7 @@
"""The tests for the Script component."""
import asyncio
from datetime import timedelta
from typing import Any
from unittest.mock import Mock, patch
import pytest
@ -1502,3 +1503,34 @@ async def test_blueprint_script_fails_substitution(
"{'service_to_call': 'test.automation'}: No substitution found for input blah"
in caplog.text
)
@pytest.mark.parametrize("response", ({"value": 5}, '{"value": 5}'))
async def test_responses(hass: HomeAssistant, response: Any) -> None:
"""Test we can get responses."""
mock_restore_cache(hass, ())
assert await async_setup_component(
hass,
"script",
{
"script": {
"test": {
"sequence": {
"stop": "done",
"response": response,
}
}
}
},
)
assert await hass.services.async_call(
DOMAIN, "test", {"greeting": "world"}, blocking=True, return_response=True
) == {"value": 5}
# Validate we can also call it without return_response
assert (
await hass.services.async_call(
DOMAIN, "test", {"greeting": "world"}, blocking=True, return_response=False
)
is None
)

View file

@ -1672,7 +1672,9 @@ async def test_test_condition(hass: HomeAssistant, websocket_client) -> None:
async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
"""Test testing a condition."""
calls = async_mock_service(hass, "domain_test", "test_service")
calls = async_mock_service(
hass, "domain_test", "test_service", response={"hello": "world"}
)
await websocket_client.send_json(
{
@ -1682,7 +1684,9 @@ async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
{
"service": "domain_test.test_service",
"data": {"hello": "world"},
}
"response_variable": "service_result",
},
{"stop": "done", "response": "{{ service_result }}"},
],
}
)
@ -1691,6 +1695,7 @@ async def test_execute_script(hass: HomeAssistant, websocket_client) -> None:
assert msg_no_var["id"] == 5
assert msg_no_var["type"] == const.TYPE_RESULT
assert msg_no_var["success"]
assert msg_no_var["result"]["response"] == {"hello": "world"}
await websocket_client.send_json(
{