mirror of
https://github.com/home-assistant/core
synced 2024-10-05 15:22:20 +00:00
Fix saving subclassed datetime objects in storage (#97502)
This commit is contained in:
parent
c2e9fd85c2
commit
094f2cbad7
|
@ -53,6 +53,8 @@ def json_encoder_default(obj: Any) -> Any:
|
|||
return obj.as_dict()
|
||||
if isinstance(obj, Path):
|
||||
return obj.as_posix()
|
||||
if isinstance(obj, datetime.datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError
|
||||
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ from homeassistant.helpers import (
|
|||
storage,
|
||||
)
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers.json import JSONEncoder, _orjson_default_encoder
|
||||
from homeassistant.helpers.typing import ConfigType, StateType
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.util.async_ import run_callback_threadsafe
|
||||
|
@ -1260,7 +1260,14 @@ def mock_storage(
|
|||
# To ensure that the data can be serialized
|
||||
_LOGGER.debug("Writing data to %s: %s", store.key, data_to_write)
|
||||
raise_contains_mocks(data_to_write)
|
||||
data[store.key] = json.loads(json.dumps(data_to_write, cls=store._encoder))
|
||||
encoder = store._encoder
|
||||
if encoder and encoder is not JSONEncoder:
|
||||
# If they pass a custom encoder that is not the
|
||||
# default JSONEncoder, we use the slow path of json.dumps
|
||||
dump = ft.partial(json.dumps, cls=store._encoder)
|
||||
else:
|
||||
dump = _orjson_default_encoder
|
||||
data[store.key] = json.loads(dump(data_to_write))
|
||||
|
||||
async def mock_remove(store: storage.Store) -> None:
|
||||
"""Remove data."""
|
||||
|
|
|
@ -215,6 +215,20 @@ def test_custom_encoder(tmp_path: Path) -> None:
|
|||
assert data == "9"
|
||||
|
||||
|
||||
def test_saving_subclassed_datetime(tmp_path: Path) -> None:
|
||||
"""Test saving subclassed datetime objects."""
|
||||
|
||||
class SubClassDateTime(datetime.datetime):
|
||||
"""Subclass datetime."""
|
||||
|
||||
time = SubClassDateTime.fromtimestamp(0)
|
||||
|
||||
fname = tmp_path / "test6.json"
|
||||
save_json(fname, {"time": time})
|
||||
data = load_json(fname)
|
||||
assert data == {"time": time.isoformat()}
|
||||
|
||||
|
||||
def test_default_encoder_is_passed(tmp_path: Path) -> None:
|
||||
"""Test we use orjson if they pass in the default encoder."""
|
||||
fname = tmp_path / "test6.json"
|
||||
|
|
Loading…
Reference in a new issue