Initial orjson support (#72754)

J. Nick Koston 2022-05-31 09:18:11 -10:00 committed by GitHub
parent a3e1b285cf
commit d9d22a9556
18 changed files with 127 additions and 67 deletions

View file

@ -24,10 +24,10 @@ from homeassistant.components.recorder.statistics import (
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util

View file

@ -14,9 +14,9 @@ from homeassistant.components import websocket_api
from homeassistant.components.recorder import get_instance
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.json import JSON_DUMP
import homeassistant.util.dt as dt_util
from .helpers import (

View file

@ -1,12 +1,11 @@
"""Recorder constants."""
from functools import partial
import json
from typing import Final
from homeassistant.backports.enum import StrEnum
from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import
JSON_DUMP,
)
DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"
@ -27,7 +26,6 @@ MAX_ROWS_TO_PURGE = 998
DB_WORKER_PREFIX = "DbWorker"
JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":"))
ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES}
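
JSON_DUMP now lives in homeassistant.helpers.json; the noqa/pylint-disabled import above only re-exports it so existing recorder.const.JSON_DUMP imports keep resolving. A minimal sketch of that equivalence, assuming the re-export shown above:

from homeassistant.components.recorder.const import JSON_DUMP as RECORDER_JSON_DUMP
from homeassistant.helpers.json import JSON_DUMP

# Both names point at the same orjson-backed callable after this commit.
assert RECORDER_JSON_DUMP is JSON_DUMP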

View file

@ -744,11 +744,12 @@ class Recorder(threading.Thread):
return
try:
shared_data = EventData.shared_data_from_event(event)
shared_data_bytes = EventData.shared_data_bytes_from_event(event)
except (TypeError, ValueError) as ex:
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
return
shared_data = shared_data_bytes.decode("utf-8")
# Matching attributes found in the pending commit
if pending_event_data := self._pending_event_data.get(shared_data):
dbevent.event_data_rel = pending_event_data
@ -756,7 +757,7 @@ class Recorder(threading.Thread):
elif data_id := self._event_data_ids.get(shared_data):
dbevent.data_id = data_id
else:
data_hash = EventData.hash_shared_data(shared_data)
data_hash = EventData.hash_shared_data_bytes(shared_data_bytes)
# Matching attributes found in the database
if data_id := self._find_shared_data_in_db(data_hash, shared_data):
self._event_data_ids[shared_data] = dbevent.data_id = data_id
@ -775,7 +776,7 @@ class Recorder(threading.Thread):
assert self.event_session is not None
try:
dbstate = States.from_event(event)
shared_attrs = StateAttributes.shared_attrs_from_event(
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event, self._exclude_attributes_by_domain
)
except (TypeError, ValueError) as ex:
@ -786,6 +787,7 @@ class Recorder(threading.Thread):
)
return
shared_attrs = shared_attrs_bytes.decode("utf-8")
dbstate.attributes = None
# Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
@ -794,7 +796,7 @@ class Recorder(threading.Thread):
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
dbstate.attributes_id = attributes_id
else:
attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
# Matching attributes found in the database
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
dbstate.attributes_id = attributes_id
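
The recorder now serializes event data and state attributes once to bytes, hashes those bytes directly, and decodes to str only where the value is used as a cache key or stored in a text column. A standalone sketch of that flow; the payload is illustrative and only orjson and fnvhash are real dependencies here:

from fnvhash import fnv1a_32
import orjson

payload = {"entity_id": "light.kitchen", "brightness": 255}

shared_bytes = orjson.dumps(payload)    # serialize once, straight to bytes
data_hash = fnv1a_32(shared_bytes)      # hash the bytes, no extra encode step
shared = shared_bytes.decode("utf-8")   # decoded once for dict keys / DB text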

View file

@ -3,12 +3,12 @@ from __future__ import annotations
from collections.abc import Callable
from datetime import datetime, timedelta
import json
import logging
from typing import Any, TypedDict, cast, overload
import ciso8601
from fnvhash import fnv1a_32
import orjson
from sqlalchemy import (
JSON,
BigInteger,
@ -46,9 +46,10 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSON_DUMP, json_bytes
import homeassistant.util.dt as dt_util
from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from .const import ALL_DOMAIN_EXCLUDE_ATTRS
# SQLAlchemy Schema
# pylint: disable=invalid-name
@ -132,7 +133,7 @@ class JSONLiteral(JSON): # type: ignore[misc]
def process(value: Any) -> str:
"""Dump json."""
return json.dumps(value)
return JSON_DUMP(value)
return process
@ -199,7 +200,7 @@ class Events(Base): # type: ignore[misc,valid-type]
try:
return Event(
self.event_type,
json.loads(self.event_data) if self.event_data else {},
orjson.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
@ -207,7 +208,7 @@ class Events(Base): # type: ignore[misc,valid-type]
context=context,
)
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
@ -235,25 +236,26 @@ class EventData(Base): # type: ignore[misc,valid-type]
@staticmethod
def from_event(event: Event) -> EventData:
"""Create object from an event."""
shared_data = JSON_DUMP(event.data)
shared_data = json_bytes(event.data)
return EventData(
shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
shared_data=shared_data.decode("utf-8"),
hash=EventData.hash_shared_data_bytes(shared_data),
)
@staticmethod
def shared_data_from_event(event: Event) -> str:
"""Create shared_attrs from an event."""
return JSON_DUMP(event.data)
def shared_data_bytes_from_event(event: Event) -> bytes:
"""Create shared_data from an event."""
return json_bytes(event.data)
@staticmethod
def hash_shared_data(shared_data: str) -> int:
def hash_shared_data_bytes(shared_data_bytes: bytes) -> int:
"""Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data.encode("utf-8")))
return cast(int, fnv1a_32(shared_data_bytes))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_data))
return cast(dict[str, Any], orjson.loads(self.shared_data))
except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self)
return {}
@ -340,9 +342,9 @@ class States(Base): # type: ignore[misc,valid-type]
parent_id=self.context_parent_id,
)
try:
attrs = json.loads(self.attributes) if self.attributes else {}
attrs = orjson.loads(self.attributes) if self.attributes else {}
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
if self.last_changed is None or self.last_changed == self.last_updated:
@ -388,40 +390,39 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
attr_bytes = b"{}" if state is None else json_bytes(state.attributes)
dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8"))
dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes)
return dbstate
@staticmethod
def shared_attrs_from_event(
def shared_attrs_bytes_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
) -> bytes:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
return b"{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
return json_bytes(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)
@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int:
"""Return the hash of orjson encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs_bytes))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
return cast(dict[str, Any], orjson.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}
@ -835,7 +836,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT:
return {}
try:
attr_cache[source] = attributes = json.loads(source)
attr_cache[source] = attributes = orjson.loads(source)
except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {}
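
The existing except ValueError handlers keep working after the json.loads to orjson.loads swap because orjson raises orjson.JSONDecodeError, which per orjson's documentation subclasses both json.JSONDecodeError and ValueError. A quick check:

import json
import orjson

try:
    orjson.loads("not valid json")
except orjson.JSONDecodeError as err:
    # Subclasses json.JSONDecodeError and ValueError, so existing handlers still match.
    assert isinstance(err, json.JSONDecodeError)
    assert isinstance(err, ValueError)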

View file

@ -29,7 +29,7 @@ from homeassistant.helpers.event import (
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
@ -241,13 +241,13 @@ def handle_get_states(
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
@ -256,13 +256,13 @@ def handle_get_states(
serialized = []
for state in states:
try:
serialized.append(const.JSON_DUMP(state))
serialized.append(JSON_DUMP(state))
except (ValueError, TypeError):
# Error is already logged above
pass
# We now have partially serialized states. Craft some JSON.
response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)
@ -315,13 +315,13 @@ def handle_subscribe_entities(
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
@ -330,14 +330,14 @@ def handle_subscribe_entities(
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
try:
const.JSON_DUMP(state_dict)
JSON_DUMP(state_dict)
except (ValueError, TypeError):
cannot_serialize.append(entity_id)
for entity_id in cannot_serialize:
del add_entities[entity_id]
connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data)))
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))
@decorators.websocket_command({vol.Required("type"): "get_services"})

View file

@ -11,6 +11,7 @@ import voluptuous as vol
from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.json import JSON_DUMP
from . import const, messages
@ -56,7 +57,7 @@ class ActiveConnection:
async def send_big_result(self, msg_id: int, result: Any) -> None:
"""Send a result message that would be expensive to JSON serialize."""
content = await self.hass.async_add_executor_job(
const.JSON_DUMP, messages.result_message(msg_id, result)
JSON_DUMP, messages.result_message(msg_id, result)
)
self.send_message(content)

View file

@ -4,12 +4,9 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from concurrent import futures
from functools import partial
import json
from typing import TYPE_CHECKING, Any, Final
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
if TYPE_CHECKING:
from .connection import ActiveConnection # noqa: F401
@ -53,10 +50,6 @@ SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
# Data used to store the current connection list
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
JSON_DUMP: Final = partial(
json.dumps, cls=JSONEncoder, allow_nan=False, separators=(",", ":")
)
COMPRESSED_STATE_STATE = "s"
COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c"

View file

@ -9,6 +9,7 @@ import voluptuous as vol
from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.util.json import (
find_paths_unserializable_data,
format_unserializable_data,
@ -193,15 +194,15 @@ def compressed_state_dict_add(state: State) -> dict[str, Any]:
def message_to_json(message: dict[str, Any]) -> str:
"""Serialize a websocket message to json."""
try:
return const.JSON_DUMP(message)
return JSON_DUMP(message)
except (ValueError, TypeError):
_LOGGER.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(message, dump=const.JSON_DUMP)
find_paths_unserializable_data(message, dump=JSON_DUMP)
),
)
return const.JSON_DUMP(
return JSON_DUMP(
error_message(
message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response"
)

View file

@ -14,6 +14,7 @@ from aiohttp import web
from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
import async_timeout
import orjson
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
@ -97,6 +98,7 @@ def _async_create_clientsession(
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl),
json_serialize=lambda x: orjson.dumps(x).decode("utf-8"),
**kwargs,
)
# Prevent packages accidentally overriding our default headers
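
aiohttp's json_serialize hook must return str, while orjson.dumps returns bytes, hence the decode in the lambda above. A hedged standalone sketch; the URL and payload are illustrative:

import aiohttp
import orjson

async def post_with_orjson() -> None:
    # json_serialize must return str, so the orjson bytes are decoded once here.
    async with aiohttp.ClientSession(
        json_serialize=lambda data: orjson.dumps(data).decode("utf-8")
    ) as session:
        await session.post("http://localhost:8123/api/example", json={"on": True})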

View file

@ -1,7 +1,10 @@
"""Helpers to help with encoding Home Assistant objects in JSON."""
import datetime
import json
from typing import Any
from pathlib import Path
from typing import Any, Final
import orjson
class JSONEncoder(json.JSONEncoder):
@ -22,6 +25,20 @@ class JSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
def json_encoder_default(obj: Any) -> Any:
"""Convert Home Assistant objects.
Hand other objects to the original method.
"""
if isinstance(obj, set):
return list(obj)
if hasattr(obj, "as_dict"):
return obj.as_dict()
if isinstance(obj, Path):
return obj.as_posix()
raise TypeError
class ExtendedJSONEncoder(JSONEncoder):
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
@ -40,3 +57,31 @@ class ExtendedJSONEncoder(JSONEncoder):
return super().default(o)
except TypeError:
return {"__type": str(type(o)), "repr": repr(o)}
def json_bytes(data: Any) -> bytes:
"""Dump json bytes."""
return orjson.dumps(
data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default
)
def json_dumps(data: Any) -> str:
"""Dump json string.
orjson supports serializing dataclasses natively which
eliminates the need to implement as_dict in many places
when the data is already in a dataclass. This works
well as long as all the data in the dataclass can also
be serialized.
If it turns out to be a problem we can disable this
with option |= orjson.OPT_PASSTHROUGH_DATACLASS and it
will fallback to as_dict
"""
return orjson.dumps(
data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default
).decode("utf-8")
JSON_DUMP: Final = json_dumps
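
A brief usage sketch of the helpers added above, assuming the module as committed: OPT_NON_STR_KEYS covers the non-string dict key, json_encoder_default converts the set and Path, and dataclasses serialize natively as the json_dumps docstring notes (the dataclass below is a hypothetical stand-in):

from dataclasses import dataclass
from pathlib import Path

import orjson

from homeassistant.helpers.json import JSON_DUMP, json_bytes

payload = {1: {"alpha", "beta"}, "config": Path("/config/configuration.yaml")}
assert isinstance(json_bytes(payload), bytes)  # bytes for hashing / DB storage
assert isinstance(JSON_DUMP(payload), str)     # str for websocket frames

@dataclass
class IssueSketch:
    """Hypothetical dataclass; orjson serializes it without an as_dict() hook."""
    type: str

assert orjson.dumps(IssueSketch("entity_not_defined")) == b'{"type":"entity_not_defined"}'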

View file

@ -20,6 +20,7 @@ httpx==0.23.0
ifaddr==0.1.7
jinja2==3.1.2
lru-dict==1.1.7
orjson==3.6.8
paho-mqtt==1.6.1
pillow==9.1.1
pip>=21.0,<22.2

View file

@ -12,14 +12,13 @@ from timeit import default_timer as timer
from typing import TypeVar
from homeassistant import core
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.helpers.entityfilter import convert_include_exclude_filter
from homeassistant.helpers.event import (
async_track_state_change,
async_track_state_change_event,
)
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import JSON_DUMP, JSONEncoder
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# mypy: no-warn-return-any

View file

@ -7,6 +7,8 @@ import json
import logging
from typing import Any
import orjson
from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError
@ -30,7 +32,7 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
"""
try:
with open(filename, encoding="utf-8") as fdesc:
return json.loads(fdesc.read()) # type: ignore[no-any-return]
return orjson.loads(fdesc.read()) # type: ignore[no-any-return]
except FileNotFoundError:
# This is not a fatal error
_LOGGER.debug("JSON file not found: %s", filename)
@ -56,7 +58,10 @@ def save_json(
Returns True on success.
"""
try:
json_data = json.dumps(data, indent=4, cls=encoder)
if encoder:
json_data = json.dumps(data, indent=2, cls=encoder)
else:
json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8")
except TypeError as error:
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}"
_LOGGER.error(msg)
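
orjson has no equivalent of the stdlib cls= encoder hook, so save_json only takes the orjson path when no custom encoder is passed (and the indent drops from 4 to 2 to match orjson's OPT_INDENT_2). A hedged usage sketch:

from datetime import timedelta

from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.util.json import save_json

# No encoder: fast orjson path with OPT_INDENT_2.
save_json("/tmp/plain.json", {"enabled": True, "count": 3})

# Custom encoder: falls back to the stdlib json.dumps branch shown above,
# because orjson cannot accept a JSONEncoder subclass.
save_json("/tmp/extended.json", {"delay": timedelta(minutes=5)}, encoder=ExtendedJSONEncoder)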

View file

@ -41,6 +41,7 @@ dependencies = [
"PyJWT==2.4.0",
# PyJWT has loose dependency. We want the latest one.
"cryptography==36.0.2",
"orjson==3.6.8",
"pip>=21.0,<22.2",
"python-slugify==4.0.1",
"pyyaml==6.0",
@ -119,6 +120,7 @@ extension-pkg-allow-list = [
"av.audio.stream",
"av.stream",
"ciso8601",
"orjson",
"cv2",
]

View file

@ -15,6 +15,7 @@ ifaddr==0.1.7
jinja2==3.1.2
PyJWT==2.4.0
cryptography==36.0.2
orjson==3.6.8
pip>=21.0,<22.2
python-slugify==4.0.1
pyyaml==6.0

View file

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from homeassistant.components.energy import async_get_manager, validate
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.setup import async_setup_component
@ -408,7 +409,11 @@ async def test_validation_grid(
},
)
assert (await validate.async_validate(hass)).as_dict() == {
result = await validate.async_validate(hass)
# verify its also json serializable
JSON_DUMP(result)
assert result.as_dict() == {
"energy_sources": [
[
{

View file

@ -619,12 +619,15 @@ async def test_states_filters_visible(hass, hass_admin_user, websocket_client):
async def test_get_states_not_allows_nan(hass, websocket_client):
"""Test get_states command not allows NaN floats."""
"""Test get_states command converts NaN to None."""
hass.states.async_set("greeting.hello", "world")
hass.states.async_set("greeting.bad", "data", {"hello": float("NaN")})
hass.states.async_set("greeting.bye", "universe")
await websocket_client.send_json({"id": 5, "type": "get_states"})
bad = dict(hass.states.get("greeting.bad").as_dict())
bad["attributes"] = dict(bad["attributes"])
bad["attributes"]["hello"] = None
msg = await websocket_client.receive_json()
assert msg["id"] == 5
@ -632,6 +635,7 @@ async def test_get_states_not_allows_nan(hass, websocket_client):
assert msg["success"]
assert msg["result"] == [
hass.states.get("greeting.hello").as_dict(),
bad,
hass.states.get("greeting.bye").as_dict(),
]
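
The updated expectation reflects a behavioural difference: the old stdlib-based JSON_DUMP used allow_nan=False and raised on NaN, while orjson encodes NaN (and Infinity) as null, so the bad attribute now round-trips as None. A small check:

import json
import orjson

nan_payload = {"hello": float("NaN")}

try:
    json.dumps(nan_payload, allow_nan=False)   # the previous behaviour: reject NaN
except ValueError:
    pass

assert orjson.dumps(nan_payload) == b'{"hello":null}'              # orjson emits null
assert orjson.loads(orjson.dumps(nan_payload)) == {"hello": None}  # round-trips to None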