Fix saving subclassed datetime objects in storage (#97502)

This commit is contained in:
J. Nick Koston 2023-07-31 09:49:02 -07:00 committed by GitHub
parent c2e9fd85c2
commit 094f2cbad7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 2 deletions

View file

@ -53,6 +53,8 @@ def json_encoder_default(obj: Any) -> Any:
return obj.as_dict()
if isinstance(obj, Path):
return obj.as_posix()
if isinstance(obj, datetime.datetime):
return obj.isoformat()
raise TypeError

View file

@ -67,7 +67,7 @@ from homeassistant.helpers import (
storage,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder
from homeassistant.helpers.typing import ConfigType, StateType
from homeassistant.setup import setup_component
from homeassistant.util.async_ import run_callback_threadsafe
@ -1260,7 +1260,14 @@ def mock_storage(
# To ensure that the data can be serialized
_LOGGER.debug("Writing data to %s: %s", store.key, data_to_write)
raise_contains_mocks(data_to_write)
data[store.key] = json.loads(json.dumps(data_to_write, cls=store._encoder))
encoder = store._encoder
if encoder and encoder is not JSONEncoder:
# If they pass a custom encoder that is not the
# default JSONEncoder, we use the slow path of json.dumps
dump = ft.partial(json.dumps, cls=store._encoder)
else:
dump = _orjson_default_encoder
data[store.key] = json.loads(dump(data_to_write))
async def mock_remove(store: storage.Store) -> None:
"""Remove data."""

View file

@ -215,6 +215,20 @@ def test_custom_encoder(tmp_path: Path) -> None:
assert data == "9"
def test_saving_subclassed_datetime(tmp_path: Path) -> None:
"""Test saving subclassed datetime objects."""
class SubClassDateTime(datetime.datetime):
"""Subclass datetime."""
time = SubClassDateTime.fromtimestamp(0)
fname = tmp_path / "test6.json"
save_json(fname, {"time": time})
data = load_json(fname)
assert data == {"time": time.isoformat()}
def test_default_encoder_is_passed(tmp_path: Path) -> None:
"""Test we use orjson if they pass in the default encoder."""
fname = tmp_path / "test6.json"