Improve typing of state event helpers (#120639)

This commit is contained in:
Erik Montnemery 2024-06-27 13:08:19 +02:00 committed by GitHub
parent 54a5a3e3fb
commit a165064e9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 10 deletions

View file

@ -158,26 +158,29 @@ class ConfigSource(enum.StrEnum):
YAML = "yaml"
class EventStateChangedData(TypedDict):
class EventStateEventData(TypedDict):
"""Base class for EVENT_STATE_CHANGED and EVENT_STATE_CHANGED data."""
entity_id: str
new_state: State | None
class EventStateChangedData(EventStateEventData):
"""EVENT_STATE_CHANGED data.
A state changed event is fired when on state write when the state is changed.
"""
entity_id: str
old_state: State | None
new_state: State | None
class EventStateReportedData(TypedDict):
class EventStateReportedData(EventStateEventData):
"""EVENT_STATE_REPORTED data.
A state reported event is fired when on state write when the state is unchanged.
"""
entity_id: str
old_last_reported: datetime.datetime
new_state: State | None
# SOURCE_* are deprecated as of Home Assistant 2022.2, use ConfigSource instead

View file

@ -27,6 +27,7 @@ from homeassistant.core import (
Event,
# Explicit reexport of 'EventStateChangedData' for backwards compatibility
EventStateChangedData as EventStateChangedData, # noqa: PLC0414
EventStateEventData,
EventStateReportedData,
HassJob,
HassJobType,
@ -89,6 +90,7 @@ RANDOM_MICROSECOND_MIN = 50000
RANDOM_MICROSECOND_MAX = 500000
_TypedDictT = TypeVar("_TypedDictT", bound=Mapping[str, Any])
_StateEventDataT = TypeVar("_StateEventDataT", bound=EventStateEventData)
@dataclass(slots=True, frozen=True)
@ -329,8 +331,8 @@ def async_track_state_change_event(
@callback
def _async_dispatch_entity_id_event(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
event: Event[_TypedDictT],
callbacks: dict[str, list[HassJob[[Event[_StateEventDataT]], Any]]],
event: Event[_StateEventDataT],
) -> None:
"""Dispatch to listeners."""
if not (callbacks_list := callbacks.get(event.data["entity_id"])):
@ -349,8 +351,8 @@ def _async_dispatch_entity_id_event(
@callback
def _async_state_filter(
hass: HomeAssistant,
callbacks: dict[str, list[HassJob[[Event[_TypedDictT]], Any]]],
event_data: _TypedDictT,
callbacks: dict[str, list[HassJob[[Event[_StateEventDataT]], Any]]],
event_data: _StateEventDataT,
) -> bool:
"""Filter state changes by entity_id."""
return event_data["entity_id"] in callbacks