Add support for responses to call_service WS cmd (#98610)

* Add support for responses to call_service WS cmd

* Revert ServiceNotFound removal and add a parameter for return_response

* fix type

* fix tests

* remove exception handling that was added

* Revert unnecessary modifications

* Use kwargs
This commit is contained in:
Raman Gupta 2023-11-10 15:44:43 -05:00 committed by GitHub
parent 229944c21c
commit 618b666126
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 13 deletions

View file

@ -18,7 +18,14 @@ from homeassistant.const import (
MATCH_ALL,
SIGNAL_BOOTSTRAP_INTEGRATIONS,
)
from homeassistant.core import Context, Event, HomeAssistant, State, callback
from homeassistant.core import (
Context,
Event,
HomeAssistant,
ServiceResponse,
State,
callback,
)
from homeassistant.exceptions import (
HomeAssistantError,
ServiceNotFound,
@ -213,6 +220,7 @@ def handle_unsubscribe_events(
vol.Required("service"): str,
vol.Optional("target"): cv.ENTITY_SERVICE_FIELDS,
vol.Optional("service_data"): dict,
vol.Optional("return_response", default=False): bool,
}
)
@decorators.async_response
@ -220,7 +228,6 @@ async def handle_call_service(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle call service command."""
blocking = True
# We do not support templates.
target = msg.get("target")
if template.is_complex(target):
@ -228,15 +235,19 @@ async def handle_call_service(
try:
context = connection.context(msg)
await hass.services.async_call(
msg["domain"],
msg["service"],
msg.get("service_data"),
blocking,
context,
response = await hass.services.async_call(
domain=msg["domain"],
service=msg["service"],
service_data=msg.get("service_data"),
blocking=True,
context=context,
target=target,
return_response=msg["return_response"],
)
connection.send_result(msg["id"], {"context": context})
result: dict[str, Context | ServiceResponse] = {"context": context}
if msg["return_response"]:
result["response"] = response
connection.send_result(msg["id"], result)
except ServiceNotFound as err:
if err.domain == msg["domain"] and err.service == msg["service"]:
connection.send_error(

View file

@ -307,8 +307,11 @@ def async_mock_service(
calls.append(call)
return response
if supports_response is None and response is not None:
supports_response = SupportsResponse.OPTIONAL
if supports_response is None:
if response is not None:
supports_response = SupportsResponse.OPTIONAL
else:
supports_response = SupportsResponse.NONE
hass.services.async_register(
domain,

View file

@ -18,7 +18,7 @@ from homeassistant.components.websocket_api.auth import (
)
from homeassistant.components.websocket_api.const import FEATURE_COALESCE_MESSAGES, URL
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATIONS
from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse, callback
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.dispatcher import async_dispatcher_send
@ -183,14 +183,76 @@ async def test_call_service(
assert call.context.as_dict() == msg["result"]["context"]
async def test_return_response_error(hass: HomeAssistant, websocket_client) -> None:
"""Test return_response=True errors when service has no response."""
hass.services.async_register(
"domain_test", "test_service_with_no_response", lambda x: None
)
await websocket_client.send_json(
{
"id": 8,
"type": "call_service",
"domain": "domain_test",
"service": "test_service_with_no_response",
"service_data": {"hello": "world"},
"return_response": True,
},
)
msg = await websocket_client.receive_json()
assert msg["id"] == 8
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == "unknown_error"
@pytest.mark.parametrize("command", ("call_service", "call_service_action"))
async def test_call_service_blocking(
hass: HomeAssistant, websocket_client: MockHAClientWebSocket, command
) -> None:
"""Test call service commands block, except for homeassistant restart / stop."""
async_mock_service(
hass,
"domain_test",
"test_service",
response={"hello": "world"},
supports_response=SupportsResponse.OPTIONAL,
)
with patch(
"homeassistant.core.ServiceRegistry.async_call", autospec=True
) as mock_call:
mock_call.return_value = {"foo": "bar"}
await websocket_client.send_json(
{
"id": 4,
"type": "call_service",
"domain": "domain_test",
"service": "test_service",
"service_data": {"hello": "world"},
"return_response": True,
},
)
msg = await websocket_client.receive_json()
assert msg["id"] == 4
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"]["response"] == {"foo": "bar"}
mock_call.assert_called_once_with(
ANY,
"domain_test",
"test_service",
{"hello": "world"},
blocking=True,
context=ANY,
target=ANY,
return_response=True,
)
with patch(
"homeassistant.core.ServiceRegistry.async_call", autospec=True
) as mock_call:
mock_call.return_value = None
await websocket_client.send_json(
{
"id": 5,
@ -213,11 +275,14 @@ async def test_call_service_blocking(
blocking=True,
context=ANY,
target=ANY,
return_response=False,
)
async_mock_service(hass, "homeassistant", "test_service")
with patch(
"homeassistant.core.ServiceRegistry.async_call", autospec=True
) as mock_call:
mock_call.return_value = None
await websocket_client.send_json(
{
"id": 6,
@ -239,11 +304,14 @@ async def test_call_service_blocking(
blocking=True,
context=ANY,
target=ANY,
return_response=False,
)
async_mock_service(hass, "homeassistant", "restart")
with patch(
"homeassistant.core.ServiceRegistry.async_call", autospec=True
) as mock_call:
mock_call.return_value = None
await websocket_client.send_json(
{
"id": 7,
@ -258,7 +326,14 @@ async def test_call_service_blocking(
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
mock_call.assert_called_once_with(
ANY, "homeassistant", "restart", ANY, blocking=True, context=ANY, target=ANY
ANY,
"homeassistant",
"restart",
ANY,
blocking=True,
context=ANY,
target=ANY,
return_response=False,
)