Speed up reconnects by caching state serialize (#93050)

This commit is contained in:
J. Nick Koston 2023-05-16 02:33:12 -05:00 committed by GitHub
parent 9c039a17ea
commit 99265a983a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 152 additions and 53 deletions

View file

@ -2,7 +2,6 @@
from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress
import datetime as dt
from functools import lru_cache
import json
@ -50,6 +49,17 @@ from . import const, decorators, messages
from .connection import ActiveConnection
from .const import ERR_NOT_FOUND
_STATES_TEMPLATE = "__STATES__"
_STATES_JSON_TEMPLATE = '"__STATES__"'
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
messages.event_message(
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
)
)
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
)
@callback
def async_register_commands(
@ -242,33 +252,43 @@ def handle_get_states(
"""Handle get states command."""
states = _async_get_allowed_states(hass, connection)
# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(JSON_DUMP(response))
return
serialized_states = [state.as_dict_json() for state in states]
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
pass
else:
_send_handle_get_states_response(connection, msg["id"], serialized_states)
return
# If we can't serialize, we'll filter out unserializable states
serialized = []
serialized_states = []
for state in states:
# Error is already logged above
with suppress(ValueError, TypeError):
serialized.append(JSON_DUMP(state))
try:
serialized_states.append(state.as_dict_json())
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(state, dump=JSON_DUMP)
),
)
# We now have partially serialized states. Craft some JSON.
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)
_send_handle_get_states_response(connection, msg["id"], serialized_states)
def _send_handle_get_states_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle get states response."""
connection.send_message(
_HANDLE_GET_STATES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"[" + ",".join(serialized_states) + "]",
1,
)
)
@callback
@ -304,42 +324,50 @@ def handle_subscribe_entities(
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
)
connection.send_result(msg["id"])
data: dict[str, dict[str, dict]] = {
messages.ENTITY_EVENT_ADD: {
state.entity_id: state.as_compressed_state()
for state in states
if not entity_ids or state.entity_id in entity_ids
}
}
# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(JSON_DUMP(response))
return
serialized_states = [
state.as_compressed_state_json()
for state in states
if not entity_ids or state.entity_id in entity_ids
]
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
pass
else:
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
return
add_entities = data[messages.ENTITY_EVENT_ADD]
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
serialized_states = []
for state in states:
try:
JSON_DUMP(state_dict)
serialized_states.append(state.as_compressed_state_json())
except (ValueError, TypeError):
cannot_serialize.append(entity_id)
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(state, dump=JSON_DUMP)
),
)
for entity_id in cannot_serialize:
del add_entities[entity_id]
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))
def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle entities init response."""
connection.send_message(
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"{" + ",".join(serialized_states) + "}",
1,
)
)
@decorators.websocket_command({vol.Required("type"): "get_services"})

View file

@ -44,7 +44,7 @@ ENTITY_EVENT_REMOVE = "r"
ENTITY_EVENT_CHANGE = "c"
def result_message(iden: int, result: Any = None) -> dict[str, Any]:
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
"""Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}

View file

@ -80,6 +80,7 @@ from .exceptions import (
Unauthorized,
)
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
from .helpers.json import json_dumps
from .util import dt as dt_util, location, ulid as ulid_util
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
from .util.read_only_dict import ReadOnlyDict
@ -1224,6 +1225,8 @@ class State:
"object_id",
"_as_dict",
"_as_compressed_state",
"_as_dict_json",
"_as_compressed_state_json",
)
def __init__(
@ -1260,6 +1263,8 @@ class State:
self.domain, self.object_id = split_entity_id(self.entity_id)
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
self._as_compressed_state: dict[str, Any] | None = None
self._as_dict_json: str | None = None
self._as_compressed_state_json: str | None = None
@property
def name(self) -> str:
@ -1294,6 +1299,12 @@ class State:
)
return self._as_dict
def as_dict_json(self) -> str:
"""Return a JSON string of the State."""
if not self._as_dict_json:
self._as_dict_json = json_dumps(self.as_dict())
return self._as_dict_json
def as_compressed_state(self) -> dict[str, Any]:
"""Build a compressed dict of a state for adds.
@ -1321,6 +1332,19 @@ class State:
self._as_compressed_state = compressed_state
return compressed_state
def as_compressed_state_json(self) -> str:
"""Build a compressed JSON key value pair of a state for adds.
The JSON string is a key value pair of the entity_id and the compressed state.
It is used for sending multiple states in a single message.
"""
if not self._as_compressed_state_json:
self._as_compressed_state_json = json_dumps(
{self.entity_id: self.as_compressed_state()}
)[1:-1]
return self._as_compressed_state_json
@classmethod
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:
"""Initialize a state from a dict.

View file

@ -9,7 +9,6 @@ from typing import Any, Final
import orjson
from homeassistant.core import Event, State
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401
JSON_DECODE_EXCEPTIONS,
@ -189,6 +188,11 @@ def find_paths_unserializable_data(
This method is slow! Only use for error handling.
"""
from homeassistant.core import ( # pylint: disable=import-outside-toplevel
Event,
State,
)
to_process = deque([(bad_data, "$")])
invalid = {}

View file

@ -188,10 +188,9 @@ async def test_non_json_message(
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == []
assert (
f"Unable to serialize to JSON. Bad data found at $.result[0](State: test_domain.entity).attributes.bad={bad_data}(<class 'object'>"
in caplog.text
)
assert "Unable to serialize to JSON. Bad data found" in caplog.text
assert "State: test_domain.entity" in caplog.text
assert "bad=<object" in caplog.text
async def test_prepare_fail(

View file

@ -466,6 +466,29 @@ def test_state_as_dict() -> None:
assert state.as_dict() is as_dict_1
def test_state_as_dict_json() -> None:
"""Test a State as JSON."""
last_time = datetime(1984, 12, 8, 12, 0, 0)
state = ha.State(
"happy.happy",
"on",
{"pig": "dog"},
last_updated=last_time,
last_changed=last_time,
context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"),
)
expected = (
'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},'
'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",'
'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}'
)
as_dict_json_1 = state.as_dict_json()
assert as_dict_json_1 == expected
# 2nd time to verify cache
assert state.as_dict_json() == expected
assert state.as_dict_json() is as_dict_json_1
def test_state_as_compressed_state() -> None:
"""Test a State as compressed state."""
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
@ -518,6 +541,27 @@ def test_state_as_compressed_state_unique_last_updated() -> None:
assert state.as_compressed_state() is as_compressed_state
def test_state_as_compressed_state_json() -> None:
"""Test a State as a JSON compressed state."""
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
state = ha.State(
"happy.happy",
"on",
{"pig": "dog"},
last_updated=last_time,
last_changed=last_time,
context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"),
)
expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}'
as_compressed_state = state.as_compressed_state_json()
# We are not too concerned about these being ReadOnlyDict
# since we don't expect them to be called by external callers
assert as_compressed_state == expected
# 2nd time to verify cache
assert state.as_compressed_state_json() == expected
assert state.as_compressed_state_json() is as_compressed_state
async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None:
"""Test remove_listener method."""
old_count = len(hass.bus.async_listeners())