Make EntityComponent generic (#78473)

This commit is contained in:
epenet 2022-09-14 20:16:23 +02:00 committed by GitHub
parent fd05d949cc
commit 996bcbdac6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 53 additions and 50 deletions

View file

@ -154,12 +154,12 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component: EntityComponent = hass.data[DOMAIN]
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if entity_id in cast(AutomationEntity, automation_entity).referenced_entities
if entity_id in automation_entity.referenced_entities
]
@ -169,12 +169,12 @@ def entities_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component: EntityComponent = hass.data[DOMAIN]
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None:
return []
return list(cast(AutomationEntity, automation_entity).referenced_entities)
return list(automation_entity.referenced_entities)
@callback
@ -183,12 +183,12 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component: EntityComponent = hass.data[DOMAIN]
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if device_id in cast(AutomationEntity, automation_entity).referenced_devices
if device_id in automation_entity.referenced_devices
]
@ -198,12 +198,12 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component: EntityComponent = hass.data[DOMAIN]
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None:
return []
return list(cast(AutomationEntity, automation_entity).referenced_devices)
return list(automation_entity.referenced_devices)
@callback
@ -212,12 +212,12 @@ def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component: EntityComponent = hass.data[DOMAIN]
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if area_id in cast(AutomationEntity, automation_entity).referenced_areas
if area_id in automation_entity.referenced_areas
]
@ -227,12 +227,12 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
if DOMAIN not in hass.data:
return []
component: EntityComponent = hass.data[DOMAIN]
component: EntityComponent[AutomationEntity] = hass.data[DOMAIN]
if (automation_entity := component.get_entity(entity_id)) is None:
return []
return list(cast(AutomationEntity, automation_entity).referenced_areas)
return list(automation_entity.referenced_areas)
@callback
@ -252,7 +252,9 @@ def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up all automations."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
hass.data[DOMAIN] = component = EntityComponent[AutomationEntity](
LOGGER, DOMAIN, hass
)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -93,7 +93,7 @@ CONFIG_SCHEMA = vol.Schema(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the counters."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[Counter](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection(

View file

@ -67,9 +67,9 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up is called when Home Assistant is loading our component."""
dominos = Dominos(hass, config)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[DominosOrder](_LOGGER, DOMAIN, hass)
hass.data[DOMAIN] = {}
entities = []
entities: list[DominosOrder] = []
conf = config[DOMAIN]
hass.services.register(

View file

@ -85,7 +85,9 @@ class FaceInformation(TypedDict, total=False):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the image processing."""
component = EntityComponent(_LOGGER, DOMAIN, hass, SCAN_INTERVAL)
component = EntityComponent[ImageProcessingEntity](
_LOGGER, DOMAIN, hass, SCAN_INTERVAL
)
await component.async_setup(config)

View file

@ -92,7 +92,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input boolean."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[InputBoolean](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -77,7 +77,7 @@ class InputButtonStorageCollection(collection.StorageCollection):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input button."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[InputButton](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -130,7 +130,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input datetime."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[InputDatetime](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -107,7 +107,7 @@ STORAGE_VERSION = 1
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input slider."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[InputNumber](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -132,7 +132,7 @@ class InputSelectStore(Store):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input select."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[InputSelect](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -107,7 +107,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input text."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[InputText](_LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -83,7 +83,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
mailboxes.append(mailbox)
mailbox_entity = MailboxEntity(mailbox)
component = EntityComponent(
component = EntityComponent[MailboxEntity](
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL
)
await component.async_add_entities([mailbox_entity])

View file

@ -326,7 +326,7 @@ The following persons point at invalid users:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the person component."""
entity_component = EntityComponent(_LOGGER, DOMAIN, hass)
entity_component = EntityComponent[Person](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection(
logging.getLogger(f"{__name__}.yaml_collection"), id_manager

View file

@ -111,7 +111,7 @@ CONFIG_SCHEMA = vol.Schema({DOMAIN: {cv.string: PLANT_SCHEMA}}, extra=vol.ALLOW_
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Plant component."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[Plant](_LOGGER, DOMAIN, hass)
entities = []
for plant_name, plant_config in config[DOMAIN].items():

View file

@ -52,7 +52,7 @@ SERVICE_SCHEMA_COMPLETE_TASK = vol.Schema({vol.Required(CONF_ID): cv.string})
def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Remember the milk component."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[RememberTheMilk](_LOGGER, DOMAIN, hass)
stored_rtm_config = RememberTheMilkConfiguration(hass)
for rtm_config in config[DOMAIN]:

View file

@ -27,6 +27,7 @@ from homeassistant.const import (
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import discovery, template
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import (
DEFAULT_SCAN_INTERVAL,
EntityComponent,
@ -53,7 +54,7 @@ COORDINATOR_AWARE_PLATFORMS = [SENSOR_DOMAIN, BINARY_SENSOR_DOMAIN]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the rest platforms."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[Entity](_LOGGER, DOMAIN, hass)
_async_setup_shared_data(hass)
async def reload_service_handler(service: ServiceCall) -> None:

View file

@ -154,7 +154,7 @@ ENTITY_SCHEMA = vol.Schema(
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input select."""
component = EntityComponent(LOGGER, DOMAIN, hass)
component = EntityComponent[Schedule](LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED

View file

@ -183,7 +183,7 @@ def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Load the scripts from the configuration."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
hass.data[DOMAIN] = component = EntityComponent[ScriptEntity](LOGGER, DOMAIN, hass)
# Process integration platforms right away since
# we will create entities before firing EVENT_COMPONENT_LOADED
@ -205,9 +205,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def turn_on_service(service: ServiceCall) -> None:
"""Call a service to turn script on."""
variables = service.data.get(ATTR_VARIABLES)
script_entities: list[ScriptEntity] = cast(
list[ScriptEntity], await component.async_extract_from_service(service)
)
script_entities = await component.async_extract_from_service(service)
for script_entity in script_entities:
await script_entity.async_turn_on(
variables=variables, context=service.context, wait=False
@ -216,9 +214,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def turn_off_service(service: ServiceCall) -> None:
"""Cancel a script."""
# Stopping a script is ok to be done in parallel
script_entities: list[ScriptEntity] = cast(
list[ScriptEntity], await component.async_extract_from_service(service)
)
script_entities = await component.async_extract_from_service(service)
if not script_entities:
return
@ -232,9 +228,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def toggle_service(service: ServiceCall) -> None:
"""Toggle a script."""
script_entities: list[ScriptEntity] = cast(
list[ScriptEntity], await component.async_extract_from_service(service)
)
script_entities = await component.async_extract_from_service(service)
for script_entity in script_entities:
await script_entity.async_toggle(context=service.context, wait=False)

View file

@ -106,7 +106,7 @@ RELOAD_SERVICE_SCHEMA = vol.Schema({})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up an input select."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component = EntityComponent[Timer](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager()
yaml_collection = collection.YamlCollection(

View file

@ -185,7 +185,7 @@ class ZoneStorageCollection(collection.StorageCollection):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up configured zones as well as Home Assistant zone if necessary."""
component = entity_component.EntityComponent(_LOGGER, DOMAIN, hass)
component = entity_component.EntityComponent[Zone](_LOGGER, DOMAIN, hass)
id_manager = collection.IDManager()
yaml_collection = collection.IDLessCollection(

View file

@ -7,7 +7,7 @@ from datetime import timedelta
from itertools import chain
import logging
from types import ModuleType
from typing import Any
from typing import Any, Generic, TypeVar
import voluptuous as vol
@ -30,6 +30,8 @@ from .typing import ConfigType, DiscoveryInfoType
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
DATA_INSTANCES = "entity_components"
_EntityT = TypeVar("_EntityT", bound=entity.Entity)
@bind_hass
async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
@ -52,7 +54,7 @@ async def async_update_entity(hass: HomeAssistant, entity_id: str) -> None:
await entity_obj.async_update_ha_state(True)
class EntityComponent:
class EntityComponent(Generic[_EntityT]):
"""The EntityComponent manages platforms that manages entities.
This class has the following responsibilities:
@ -86,18 +88,19 @@ class EntityComponent:
hass.data.setdefault(DATA_INSTANCES, {})[domain] = self
@property
def entities(self) -> Iterable[entity.Entity]:
def entities(self) -> Iterable[_EntityT]:
"""Return an iterable that returns all entities."""
return chain.from_iterable(
platform.entities.values() for platform in self._platforms.values()
platform.entities.values() # type: ignore[misc]
for platform in self._platforms.values()
)
def get_entity(self, entity_id: str) -> entity.Entity | None:
def get_entity(self, entity_id: str) -> _EntityT | None:
"""Get an entity."""
for platform in self._platforms.values():
entity_obj = platform.entities.get(entity_id)
if entity_obj is not None:
return entity_obj
return entity_obj # type: ignore[return-value]
return None
def setup(self, config: ConfigType) -> None:
@ -176,14 +179,14 @@ class EntityComponent:
async def async_extract_from_service(
self, service_call: ServiceCall, expand_group: bool = True
) -> list[entity.Entity]:
) -> list[_EntityT]:
"""Extract all known and available entities from a service call.
Will return an empty list if entities specified but unknown.
This method must be run in the event loop.
"""
return await service.async_extract_entities(
return await service.async_extract_entities( # type: ignore[return-value]
self.hass, self.entities, service_call, expand_group
)

View file

@ -14,6 +14,7 @@ from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component
from . import config_per_platform
from .entity import Entity
from .entity_component import EntityComponent
from .entity_platform import EntityPlatform, async_get_platforms
from .service import async_register_admin_service
@ -120,7 +121,7 @@ async def _async_setup_platform(
)
return
entity_component: EntityComponent = hass.data[integration_platform]
entity_component: EntityComponent[Entity] = hass.data[integration_platform]
tasks = [
entity_component.async_setup_platform(integration_name, p_config)
for p_config in platform_configs