Add strict typing to core.py (2) - State (#63240)

This commit is contained in:
Marc Mueller 2022-01-04 18:33:56 +01:00 committed by GitHub
parent 5f5adffd5b
commit 3a32fe9a34
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 9 deletions

View file

@ -880,6 +880,9 @@ class EventBus:
)
_StateT = TypeVar("_StateT", bound="State")
class State:
"""Object to represent a state within the state machine.
@ -946,7 +949,7 @@ class State:
"_", " "
)
def as_dict(self) -> dict:
def as_dict(self) -> dict[str, Collection[Any]]:
"""Return a dict representation of the State.
Async friendly.
@ -971,7 +974,7 @@ class State:
return self._as_dict
@classmethod
def from_dict(cls, json_dict: dict) -> Any:
def from_dict(cls: type[_StateT], json_dict: dict[str, Any]) -> _StateT | None:
"""Initialize a state from a dict.
Async friendly.
@ -1042,7 +1045,7 @@ class StateMachine:
@callback
def async_entity_ids(
self, domain_filter: str | Iterable | None = None
self, domain_filter: str | Iterable[str] | None = None
) -> list[str]:
"""List of entity ids that are being tracked.
@ -1062,7 +1065,7 @@ class StateMachine:
@callback
def async_entity_ids_count(
self, domain_filter: str | Iterable | None = None
self, domain_filter: str | Iterable[str] | None = None
) -> int:
"""Count the entity ids that are being tracked.
@ -1078,14 +1081,16 @@ class StateMachine:
[None for state in self._states.values() if state.domain in domain_filter]
)
def all(self, domain_filter: str | Iterable | None = None) -> list[State]:
def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]:
"""Create a list of all states."""
return run_callback_threadsafe(
self._loop, self.async_all, domain_filter
).result()
@callback
def async_all(self, domain_filter: str | Iterable | None = None) -> list[State]:
def async_all(
self, domain_filter: str | Iterable[str] | None = None
) -> list[State]:
"""Create a list of all states matching the filter.
This method must be run in the event loop.

View file

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
from datetime import datetime, timedelta
import logging
from typing import Any, cast
from typing import Any, TypeVar, cast
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant, State, callback, valid_entity_id
@ -31,6 +31,8 @@ STATE_DUMP_INTERVAL = timedelta(minutes=15)
# How long should a saved state be preserved if the entity no longer exists
STATE_EXPIRATION = timedelta(days=7)
_StoredStateT = TypeVar("_StoredStateT", bound="StoredState")
class StoredState:
"""Object to represent a stored state."""
@ -45,14 +47,14 @@ class StoredState:
return {"state": self.state.as_dict(), "last_seen": self.last_seen}
@classmethod
def from_dict(cls, json_dict: dict) -> StoredState:
def from_dict(cls: type[_StoredStateT], json_dict: dict) -> _StoredStateT:
"""Initialize a stored state from a dict."""
last_seen = json_dict["last_seen"]
if isinstance(last_seen, str):
last_seen = dt_util.parse_datetime(last_seen)
return cls(State.from_dict(json_dict["state"]), last_seen)
return cls(cast(State, State.from_dict(json_dict["state"])), last_seen)
class RestoreStateData: