Initial orjson support take 3 (#73849)

* Initial orjson support take 2

Still need to work out problem building wheels

--

Redux of #72754 / #32153 Now possible since the following is solved:
ijl/orjson#220 (comment)

This implements orjson where we use our default encoder.  This does not implement orjson where `ExtendedJSONEncoder` is used as these areas tend to be called far less frequently.  If its desired, this could be done in a followup, but it seemed like a case of diminishing returns (except maybe for large diagnostics files, or traces, but those are not expected to be downloaded frequently).

Areas where this makes a perceptible difference:
- Anything that subscribes to entities (Initial subscribe_entities payload)
- Initial download of registries on first connection / restore
- History queries
- Saving states to the database
- Large logbook queries
- Anything that subscribes to events (appdaemon)

Cavets:
orjson supports serializing dataclasses natively (and much faster) which
eliminates the need to implement `as_dict` in many places
when the data is already in a dataclass. This works
well as long as all the data in the dataclass can also
be serialized. I audited all places where we have an `as_dict`
for a dataclass and found only backups needs to be adjusted (support for `Path` needed to be added for backups).  I was a little bit worried about `SensorExtraStoredData` with `Decimal` but it all seems to work out from since it converts it before it gets to the json encoding cc @dgomes

If it turns out to be a problem we can disable this
with option |= [orjson.OPT_PASSTHROUGH_DATACLASS](https://github.com/ijl/orjson#opt_passthrough_dataclass) and it
will fallback to `as_dict`

Its quite impressive for history queries
<img width="1271" alt="Screen_Shot_2022-05-30_at_23_46_30" src="https://user-images.githubusercontent.com/663432/171145699-661ad9db-d91d-4b2d-9c1a-9d7866c03a73.png">

* use for views as well

* handle UnicodeEncodeError

* tweak

* DRY

* DRY

* not needed

* fix tests

* Update tests/components/http/test_view.py

* Update tests/components/http/test_view.py

* black

* templates
This commit is contained in:
J. Nick Koston 2022-06-22 14:59:51 -05:00 committed by GitHub
parent 9ac28d2076
commit 8b067e83f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 149 additions and 80 deletions

View file

@ -24,10 +24,10 @@ from homeassistant.components.recorder.statistics import (
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.helpers.typing import ConfigType
import homeassistant.util.dt as dt_util

View file

@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from http import HTTPStatus
import json
import logging
from typing import Any
@ -21,7 +20,7 @@ import voluptuous as vol
from homeassistant import exceptions
from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant.core import Context, is_callback
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import JSON_ENCODE_EXCEPTIONS, json_bytes
from .const import KEY_AUTHENTICATED, KEY_HASS
@ -53,8 +52,8 @@ class HomeAssistantView:
) -> web.Response:
"""Return a JSON response."""
try:
msg = json.dumps(result, cls=JSONEncoder, allow_nan=False).encode("UTF-8")
except (ValueError, TypeError) as err:
msg = json_bytes(result)
except JSON_ENCODE_EXCEPTIONS as err:
_LOGGER.error("Unable to serialize to JSON: %s\n%s", err, result)
raise HTTPInternalServerError from err
response = web.Response(

View file

@ -14,10 +14,10 @@ from homeassistant.components import websocket_api
from homeassistant.components.recorder import get_instance
from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.entityfilter import EntityFilter
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.json import JSON_DUMP
import homeassistant.util.dt as dt_util
from .const import LOGBOOK_ENTITIES_FILTER

View file

@ -1,12 +1,10 @@
"""Recorder constants."""
from functools import partial
import json
from typing import Final
from homeassistant.backports.enum import StrEnum
from homeassistant.const import ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import ( # noqa: F401 pylint: disable=unused-import
JSON_DUMP,
)
DATA_INSTANCE = "recorder_instance"
SQLITE_URL_PREFIX = "sqlite://"
@ -27,8 +25,6 @@ MAX_ROWS_TO_PURGE = 998
DB_WORKER_PREFIX = "DbWorker"
JSON_DUMP: Final = partial(json.dumps, cls=JSONEncoder, separators=(",", ":"))
ALL_DOMAIN_EXCLUDE_ATTRS = {ATTR_ATTRIBUTION, ATTR_RESTORED, ATTR_SUPPORTED_FEATURES}
ATTR_KEEP_DAYS = "keep_days"

View file

@ -36,6 +36,7 @@ from homeassistant.helpers.event import (
async_track_time_interval,
async_track_utc_time_change,
)
from homeassistant.helpers.json import JSON_ENCODE_EXCEPTIONS
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
@ -754,11 +755,12 @@ class Recorder(threading.Thread):
return
try:
shared_data = EventData.shared_data_from_event(event)
except (TypeError, ValueError) as ex:
shared_data_bytes = EventData.shared_data_bytes_from_event(event)
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning("Event is not JSON serializable: %s: %s", event, ex)
return
shared_data = shared_data_bytes.decode("utf-8")
# Matching attributes found in the pending commit
if pending_event_data := self._pending_event_data.get(shared_data):
dbevent.event_data_rel = pending_event_data
@ -766,7 +768,7 @@ class Recorder(threading.Thread):
elif data_id := self._event_data_ids.get(shared_data):
dbevent.data_id = data_id
else:
data_hash = EventData.hash_shared_data(shared_data)
data_hash = EventData.hash_shared_data_bytes(shared_data_bytes)
# Matching attributes found in the database
if data_id := self._find_shared_data_in_db(data_hash, shared_data):
self._event_data_ids[shared_data] = dbevent.data_id = data_id
@ -785,10 +787,10 @@ class Recorder(threading.Thread):
assert self.event_session is not None
try:
dbstate = States.from_event(event)
shared_attrs = StateAttributes.shared_attrs_from_event(
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event, self._exclude_attributes_by_domain
)
except (TypeError, ValueError) as ex:
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning(
"State is not JSON serializable: %s: %s",
event.data.get("new_state"),
@ -796,6 +798,7 @@ class Recorder(threading.Thread):
)
return
shared_attrs = shared_attrs_bytes.decode("utf-8")
dbstate.attributes = None
# Matching attributes found in the pending commit
if pending_attributes := self._pending_state_attributes.get(shared_attrs):
@ -804,7 +807,7 @@ class Recorder(threading.Thread):
elif attributes_id := self._state_attributes_ids.get(shared_attrs):
dbstate.attributes_id = attributes_id
else:
attr_hash = StateAttributes.hash_shared_attrs(shared_attrs)
attr_hash = StateAttributes.hash_shared_attrs_bytes(shared_attrs_bytes)
# Matching attributes found in the database
if attributes_id := self._find_shared_attr_in_db(attr_hash, shared_attrs):
dbstate.attributes_id = attributes_id

View file

@ -3,12 +3,12 @@ from __future__ import annotations
from collections.abc import Callable
from datetime import datetime, timedelta
import json
import logging
from typing import Any, cast
import ciso8601
from fnvhash import fnv1a_32
import orjson
from sqlalchemy import (
JSON,
BigInteger,
@ -39,9 +39,10 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.json import JSON_DUMP, json_bytes
import homeassistant.util.dt as dt_util
from .const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from .const import ALL_DOMAIN_EXCLUDE_ATTRS
from .models import StatisticData, StatisticMetaData, process_timestamp
# SQLAlchemy Schema
@ -124,7 +125,7 @@ class JSONLiteral(JSON): # type: ignore[misc]
def process(value: Any) -> str:
"""Dump json."""
return json.dumps(value)
return JSON_DUMP(value)
return process
@ -187,7 +188,7 @@ class Events(Base): # type: ignore[misc,valid-type]
try:
return Event(
self.event_type,
json.loads(self.event_data) if self.event_data else {},
orjson.loads(self.event_data) if self.event_data else {},
EventOrigin(self.origin)
if self.origin
else EVENT_ORIGIN_ORDER[self.origin_idx],
@ -195,7 +196,7 @@ class Events(Base): # type: ignore[misc,valid-type]
context=context,
)
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
@ -223,25 +224,26 @@ class EventData(Base): # type: ignore[misc,valid-type]
@staticmethod
def from_event(event: Event) -> EventData:
"""Create object from an event."""
shared_data = JSON_DUMP(event.data)
shared_data = json_bytes(event.data)
return EventData(
shared_data=shared_data, hash=EventData.hash_shared_data(shared_data)
shared_data=shared_data.decode("utf-8"),
hash=EventData.hash_shared_data_bytes(shared_data),
)
@staticmethod
def shared_data_from_event(event: Event) -> str:
"""Create shared_attrs from an event."""
return JSON_DUMP(event.data)
def shared_data_bytes_from_event(event: Event) -> bytes:
"""Create shared_data from an event."""
return json_bytes(event.data)
@staticmethod
def hash_shared_data(shared_data: str) -> int:
def hash_shared_data_bytes(shared_data_bytes: bytes) -> int:
"""Return the hash of json encoded shared data."""
return cast(int, fnv1a_32(shared_data.encode("utf-8")))
return cast(int, fnv1a_32(shared_data_bytes))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_data))
return cast(dict[str, Any], orjson.loads(self.shared_data))
except ValueError:
_LOGGER.exception("Error converting row to event data: %s", self)
return {}
@ -328,9 +330,9 @@ class States(Base): # type: ignore[misc,valid-type]
parent_id=self.context_parent_id,
)
try:
attrs = json.loads(self.attributes) if self.attributes else {}
attrs = orjson.loads(self.attributes) if self.attributes else {}
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
if self.last_changed is None or self.last_changed == self.last_updated:
@ -376,40 +378,39 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
attr_bytes = b"{}" if state is None else json_bytes(state.attributes)
dbstate = StateAttributes(shared_attrs=attr_bytes.decode("utf-8"))
dbstate.hash = StateAttributes.hash_shared_attrs_bytes(attr_bytes)
return dbstate
@staticmethod
def shared_attrs_from_event(
def shared_attrs_bytes_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
) -> bytes:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
return b"{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
return json_bytes(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)
@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int:
"""Return the hash of orjson encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs_bytes))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
return cast(dict[str, Any], orjson.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
# When orjson.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}

View file

@ -2,10 +2,10 @@
from __future__ import annotations
from datetime import datetime
import json
import logging
from typing import Any, TypedDict, overload
import orjson
from sqlalchemy.engine.row import Row
from homeassistant.components.websocket_api.const import (
@ -253,7 +253,7 @@ def decode_attributes_from_row(
if not source or source == EMPTY_JSON_OBJECT:
return {}
try:
attr_cache[source] = attributes = json.loads(source)
attr_cache[source] = attributes = orjson.loads(source)
except ValueError:
_LOGGER.exception("Error converting row to state attributes: %s", source)
attr_cache[source] = attributes = {}

View file

@ -29,7 +29,7 @@ from homeassistant.helpers.event import (
TrackTemplateResult,
async_track_template_result,
)
from homeassistant.helpers.json import ExtendedJSONEncoder
from homeassistant.helpers.json import JSON_DUMP, ExtendedJSONEncoder
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
from homeassistant.setup import DATA_SETUP_TIME, async_get_loaded_integrations
@ -241,13 +241,13 @@ def handle_get_states(
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
@ -256,13 +256,13 @@ def handle_get_states(
serialized = []
for state in states:
try:
serialized.append(const.JSON_DUMP(state))
serialized.append(JSON_DUMP(state))
except (ValueError, TypeError):
# Error is already logged above
pass
# We now have partially serialized states. Craft some JSON.
response2 = const.JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)
@ -315,13 +315,13 @@ def handle_subscribe_entities(
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(const.JSON_DUMP(response))
connection.send_message(JSON_DUMP(response))
return
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=const.JSON_DUMP)
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
@ -330,14 +330,14 @@ def handle_subscribe_entities(
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
try:
const.JSON_DUMP(state_dict)
JSON_DUMP(state_dict)
except (ValueError, TypeError):
cannot_serialize.append(entity_id)
for entity_id in cannot_serialize:
del add_entities[entity_id]
connection.send_message(const.JSON_DUMP(messages.event_message(msg["id"], data)))
connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))
@decorators.websocket_command({vol.Required("type"): "get_services"})

View file

@ -11,6 +11,7 @@ import voluptuous as vol
from homeassistant.auth.models import RefreshToken, User
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.json import JSON_DUMP
from . import const, messages
@ -56,7 +57,7 @@ class ActiveConnection:
async def send_big_result(self, msg_id: int, result: Any) -> None:
"""Send a result message that would be expensive to JSON serialize."""
content = await self.hass.async_add_executor_job(
const.JSON_DUMP, messages.result_message(msg_id, result)
JSON_DUMP, messages.result_message(msg_id, result)
)
self.send_message(content)

View file

@ -4,12 +4,9 @@ from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from concurrent import futures
from functools import partial
import json
from typing import TYPE_CHECKING, Any, Final
from homeassistant.core import HomeAssistant
from homeassistant.helpers.json import JSONEncoder
if TYPE_CHECKING:
from .connection import ActiveConnection # noqa: F401
@ -53,10 +50,6 @@ SIGNAL_WEBSOCKET_DISCONNECTED: Final = "websocket_disconnected"
# Data used to store the current connection list
DATA_CONNECTIONS: Final = f"{DOMAIN}.connections"
JSON_DUMP: Final = partial(
json.dumps, cls=JSONEncoder, allow_nan=False, separators=(",", ":")
)
COMPRESSED_STATE_STATE = "s"
COMPRESSED_STATE_ATTRIBUTES = "a"
COMPRESSED_STATE_CONTEXT = "c"

View file

@ -9,6 +9,7 @@ import voluptuous as vol
from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.util.json import (
find_paths_unserializable_data,
format_unserializable_data,
@ -193,15 +194,15 @@ def compressed_state_dict_add(state: State) -> dict[str, Any]:
def message_to_json(message: dict[str, Any]) -> str:
"""Serialize a websocket message to json."""
try:
return const.JSON_DUMP(message)
return JSON_DUMP(message)
except (ValueError, TypeError):
_LOGGER.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(message, dump=const.JSON_DUMP)
find_paths_unserializable_data(message, dump=JSON_DUMP)
),
)
return const.JSON_DUMP(
return JSON_DUMP(
error_message(
message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response"
)

View file

@ -14,6 +14,7 @@ from aiohttp import web
from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
import async_timeout
import orjson
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
@ -97,6 +98,7 @@ def _async_create_clientsession(
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl),
json_serialize=lambda x: orjson.dumps(x).decode("utf-8"),
**kwargs,
)
# Prevent packages accidentally overriding our default headers

View file

@ -1,7 +1,12 @@
"""Helpers to help with encoding Home Assistant objects in JSON."""
import datetime
import json
from typing import Any
from pathlib import Path
from typing import Any, Final
import orjson
JSON_ENCODE_EXCEPTIONS = (TypeError, ValueError)
class JSONEncoder(json.JSONEncoder):
@ -22,6 +27,20 @@ class JSONEncoder(json.JSONEncoder):
return json.JSONEncoder.default(self, o)
def json_encoder_default(obj: Any) -> Any:
"""Convert Home Assistant objects.
Hand other objects to the original method.
"""
if isinstance(obj, set):
return list(obj)
if hasattr(obj, "as_dict"):
return obj.as_dict()
if isinstance(obj, Path):
return obj.as_posix()
raise TypeError
class ExtendedJSONEncoder(JSONEncoder):
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
@ -40,3 +59,31 @@ class ExtendedJSONEncoder(JSONEncoder):
return super().default(o)
except TypeError:
return {"__type": str(type(o)), "repr": repr(o)}
def json_bytes(data: Any) -> bytes:
"""Dump json bytes."""
return orjson.dumps(
data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default
)
def json_dumps(data: Any) -> str:
"""Dump json string.
orjson supports serializing dataclasses natively which
eliminates the need to implement as_dict in many places
when the data is already in a dataclass. This works
well as long as all the data in the dataclass can also
be serialized.
If it turns out to be a problem we can disable this
with option |= orjson.OPT_PASSTHROUGH_DATACLASS and it
will fallback to as_dict
"""
return orjson.dumps(
data, option=orjson.OPT_NON_STR_KEYS, default=json_encoder_default
).decode("utf-8")
JSON_DUMP: Final = json_dumps

View file

@ -27,6 +27,7 @@ import jinja2
from jinja2 import pass_context, pass_environment
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2.utils import Namespace
import orjson
import voluptuous as vol
from homeassistant.const import (
@ -566,7 +567,7 @@ class Template:
variables["value"] = value
with suppress(ValueError, TypeError):
variables["value_json"] = json.loads(value)
variables["value_json"] = orjson.loads(value)
try:
return _render_with_context(
@ -1743,7 +1744,7 @@ def ordinal(value):
def from_json(value):
"""Convert a JSON string to an object."""
return json.loads(value)
return orjson.loads(value)
def to_json(value, ensure_ascii=True):

View file

@ -20,6 +20,7 @@ httpx==0.23.0
ifaddr==0.1.7
jinja2==3.1.2
lru-dict==1.1.7
orjson==3.7.2
paho-mqtt==1.6.1
pillow==9.1.1
pip>=21.0,<22.2

View file

@ -12,14 +12,13 @@ from timeit import default_timer as timer
from typing import TypeVar
from homeassistant import core
from homeassistant.components.websocket_api.const import JSON_DUMP
from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.helpers.entityfilter import convert_include_exclude_filter
from homeassistant.helpers.event import (
async_track_state_change,
async_track_state_change_event,
)
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import JSON_DUMP, JSONEncoder
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
# mypy: no-warn-return-any

View file

@ -7,6 +7,8 @@ import json
import logging
from typing import Any
import orjson
from homeassistant.core import Event, State
from homeassistant.exceptions import HomeAssistantError
@ -30,7 +32,7 @@ def load_json(filename: str, default: list | dict | None = None) -> list | dict:
"""
try:
with open(filename, encoding="utf-8") as fdesc:
return json.loads(fdesc.read()) # type: ignore[no-any-return]
return orjson.loads(fdesc.read()) # type: ignore[no-any-return]
except FileNotFoundError:
# This is not a fatal error
_LOGGER.debug("JSON file not found: %s", filename)
@ -56,7 +58,10 @@ def save_json(
Returns True on success.
"""
try:
json_data = json.dumps(data, indent=4, cls=encoder)
if encoder:
json_data = json.dumps(data, indent=2, cls=encoder)
else:
json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8")
except TypeError as error:
msg = f"Failed to serialize to JSON: {filename}. Bad data at {format_unserializable_data(find_paths_unserializable_data(data))}"
_LOGGER.error(msg)

View file

@ -41,6 +41,7 @@ dependencies = [
"PyJWT==2.4.0",
# PyJWT has loose dependency. We want the latest one.
"cryptography==36.0.2",
"orjson==3.7.2",
"pip>=21.0,<22.2",
"python-slugify==4.0.1",
"pyyaml==6.0",
@ -119,6 +120,7 @@ extension-pkg-allow-list = [
"av.audio.stream",
"av.stream",
"ciso8601",
"orjson",
"cv2",
]

View file

@ -15,6 +15,7 @@ ifaddr==0.1.7
jinja2==3.1.2
PyJWT==2.4.0
cryptography==36.0.2
orjson==3.7.2
pip>=21.0,<22.2
python-slugify==4.0.1
pyyaml==6.0

View file

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from homeassistant.components.energy import async_get_manager, validate
from homeassistant.helpers.json import JSON_DUMP
from homeassistant.setup import async_setup_component
@ -408,7 +409,11 @@ async def test_validation_grid(
},
)
assert (await validate.async_validate(hass)).as_dict() == {
result = await validate.async_validate(hass)
# verify its also json serializable
JSON_DUMP(result)
assert result.as_dict() == {
"energy_sources": [
[
{

View file

@ -1,5 +1,6 @@
"""Tests for Home Assistant View."""
from http import HTTPStatus
import json
from unittest.mock import AsyncMock, Mock
from aiohttp.web_exceptions import (
@ -34,9 +35,16 @@ async def test_invalid_json(caplog):
view = HomeAssistantView()
with pytest.raises(HTTPInternalServerError):
view.json(float("NaN"))
view.json(rb"\ud800")
assert str(float("NaN")) in caplog.text
assert "Unable to serialize to JSON" in caplog.text
async def test_nan_serialized_to_null(caplog):
"""Test nan serialized to null JSON."""
view = HomeAssistantView()
response = view.json(float("NaN"))
assert json.loads(response.body.decode("utf-8")) is None
async def test_handling_unauthorized(mock_request):

View file

@ -619,12 +619,15 @@ async def test_states_filters_visible(hass, hass_admin_user, websocket_client):
async def test_get_states_not_allows_nan(hass, websocket_client):
"""Test get_states command not allows NaN floats."""
"""Test get_states command converts NaN to None."""
hass.states.async_set("greeting.hello", "world")
hass.states.async_set("greeting.bad", "data", {"hello": float("NaN")})
hass.states.async_set("greeting.bye", "universe")
await websocket_client.send_json({"id": 5, "type": "get_states"})
bad = dict(hass.states.get("greeting.bad").as_dict())
bad["attributes"] = dict(bad["attributes"])
bad["attributes"]["hello"] = None
msg = await websocket_client.receive_json()
assert msg["id"] == 5
@ -632,6 +635,7 @@ async def test_get_states_not_allows_nan(hass, websocket_client):
assert msg["success"]
assert msg["result"] == [
hass.states.get("greeting.hello").as_dict(),
bad,
hass.states.get("greeting.bye").as_dict(),
]