Store capabilities and supported features in entity registry, restore registered entities on startup (#30094)

* Store capabilities and supported features in entity registry

* Restore states at startup

* Restore non-disabled entities on HA start

* Fix test

* Pass device class from entity platform

* Clean up restored entities from state machine

* Fix Z-Wave test?
This commit is contained in:
Paulus Schoutsen 2019-12-31 14:29:43 +01:00 committed by GitHub
parent 2c1a7a54cd
commit bb14a083f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 293 additions and 18 deletions

View file

@ -311,7 +311,9 @@ class Entity(ABC):
start = timer()
attr = self.capability_attributes or {}
attr = self.capability_attributes
attr = dict(attr) if attr else {}
if not self.available:
state = STATE_UNAVAILABLE
else:

View file

@ -347,6 +347,9 @@ class EntityPlatform:
device_id=device_id,
known_object_ids=self.entities.keys(),
disabled_by=disabled_by,
capabilities=entity.capability_attributes,
supported_features=entity.supported_features,
device_class=entity.device_class,
)
entity.registry_entry = entry
@ -387,10 +390,16 @@ class EntityPlatform:
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
if (
entity.entity_id in self.entities
or entity.entity_id in self.hass.states.async_entity_ids(self.domain)
):
already_exists = entity.entity_id in self.entities
if not already_exists:
existing = self.hass.states.get(entity.entity_id)
if existing and not existing.attributes.get("restored"):
already_exists = True
if already_exists:
msg = f"Entity id already exists: {entity.entity_id}"
if entity.unique_id is not None:
msg += ". Platform {} does not generate unique IDs".format(

View file

@ -15,6 +15,12 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast
import attr
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_SUPPORTED_FEATURES,
EVENT_HOMEASSISTANT_START,
STATE_UNAVAILABLE,
)
from homeassistant.core import Event, callback, split_entity_id, valid_entity_id
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from homeassistant.loader import bind_hass
@ -39,6 +45,8 @@ DISABLED_HASS = "hass"
DISABLED_USER = "user"
DISABLED_INTEGRATION = "integration"
ATTR_RESTORED = "restored"
STORAGE_VERSION = 1
STORAGE_KEY = "core.entity_registry"
@ -66,6 +74,9 @@ class RegistryEntry:
)
),
)
capabilities: Optional[Dict[str, Any]] = attr.ib(default=None)
supported_features: int = attr.ib(default=0)
device_class: Optional[str] = attr.ib(default=None)
domain = attr.ib(type=str, init=False, repr=False)
@domain.default
@ -142,11 +153,17 @@ class EntityRegistry:
platform: str,
unique_id: str,
*,
# To influence entity ID generation
suggested_object_id: Optional[str] = None,
known_object_ids: Optional[Iterable[str]] = None,
# To disable an entity if it gets created
disabled_by: Optional[str] = None,
# Data that we want entry to have
config_entry: Optional["ConfigEntry"] = None,
device_id: Optional[str] = None,
known_object_ids: Optional[Iterable[str]] = None,
disabled_by: Optional[str] = None,
capabilities: Optional[Dict[str, Any]] = None,
supported_features: Optional[int] = None,
device_class: Optional[str] = None,
) -> RegistryEntry:
"""Get entity. Create if it doesn't exist."""
config_entry_id = None
@ -160,6 +177,9 @@ class EntityRegistry:
entity_id,
config_entry_id=config_entry_id or _UNDEF,
device_id=device_id or _UNDEF,
capabilities=capabilities or _UNDEF,
supported_features=supported_features or _UNDEF,
device_class=device_class or _UNDEF,
# When we changed our slugify algorithm, we invalidated some
# stored entity IDs with either a __ or ending in _.
# Fix introduced in 0.86 (Jan 23, 2019). Next line can be
@ -187,6 +207,9 @@ class EntityRegistry:
unique_id=unique_id,
platform=platform,
disabled_by=disabled_by,
capabilities=capabilities,
supported_features=supported_features or 0,
device_class=device_class,
)
self.entities[entity_id] = entity
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)
@ -253,6 +276,9 @@ class EntityRegistry:
device_id=_UNDEF,
new_unique_id=_UNDEF,
disabled_by=_UNDEF,
capabilities=_UNDEF,
supported_features=_UNDEF,
device_class=_UNDEF,
):
"""Private facing update properties method."""
old = self.entities[entity_id]
@ -264,6 +290,9 @@ class EntityRegistry:
("config_entry_id", config_entry_id),
("device_id", device_id),
("disabled_by", disabled_by),
("capabilities", capabilities),
("supported_features", supported_features),
("device_class", device_class),
):
if value is not _UNDEF and value != getattr(old, attr_name):
changes[attr_name] = value
@ -318,6 +347,8 @@ class EntityRegistry:
async def async_load(self) -> None:
"""Load the entity registry."""
async_setup_entity_restore(self.hass, self)
data = await self.hass.helpers.storage.async_migrator(
self.hass.config.path(PATH_REGISTRY),
self._store,
@ -336,6 +367,9 @@ class EntityRegistry:
platform=entity["platform"],
name=entity.get("name"),
disabled_by=entity.get("disabled_by"),
capabilities=entity.get("capabilities") or {},
supported_features=entity.get("supported_features", 0),
device_class=entity.get("device_class"),
)
self.entities = entities
@ -359,6 +393,9 @@ class EntityRegistry:
"platform": entry.platform,
"name": entry.name,
"disabled_by": entry.disabled_by,
"capabilities": entry.capabilities,
"supported_features": entry.supported_features,
"device_class": entry.device_class,
}
for entry in self.entities.values()
]
@ -416,3 +453,53 @@ async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, A
{"entity_id": entity_id, **info} for entity_id, info in entities.items()
]
}
@callback
def async_setup_entity_restore(
hass: HomeAssistantType, registry: EntityRegistry
) -> None:
"""Set up the entity restore mechanism."""
@callback
def cleanup_restored_states(event: Event) -> None:
"""Clean up restored states."""
if event.data["action"] != "remove":
return
state = hass.states.get(event.data["entity_id"])
if state is None or not state.attributes.get(ATTR_RESTORED):
return
hass.states.async_remove(event.data["entity_id"])
hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states)
if hass.is_running:
return
@callback
def _write_unavailable_states(_: Event) -> None:
"""Make sure state machine contains entry for each registered entity."""
states = hass.states
existing = set(states.async_entity_ids())
for entry in registry.entities.values():
if entry.entity_id in existing or entry.disabled:
continue
attrs: Dict[str, Any] = {ATTR_RESTORED: True}
if entry.capabilities:
attrs.update(entry.capabilities)
if entry.supported_features:
attrs[ATTR_SUPPORTED_FEATURES] = entry.supported_features
if entry.device_class:
attrs[ATTR_DEVICE_CLASS] = entry.device_class
states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs)
hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states)

View file

@ -906,6 +906,11 @@ class MockEntity(entity.Entity):
"""Return the unique ID of the entity."""
return self._handle("unique_id")
@property
def state(self):
"""Return the state of the entity."""
return self._handle("state")
@property
def available(self):
"""Return True if entity is available."""
@ -916,6 +921,21 @@ class MockEntity(entity.Entity):
"""Info how it links to a device."""
return self._handle("device_info")
@property
def device_class(self):
"""Info how device should be classified."""
return self._handle("device_class")
@property
def capability_attributes(self):
"""Info about capabilities."""
return self._handle("capability_attributes")
@property
def supported_features(self):
"""Info about supported features."""
return self._handle("supported_features")
@property
def entity_registry_enabled_default(self):
"""Return if the entity should be enabled when first added to the entity registry."""

View file

@ -130,6 +130,7 @@ async def test_auto_heal_midnight(hass, mock_openzwave):
time = utc.localize(datetime(2017, 5, 6, 0, 0, 0))
async_fire_time_changed(hass, time)
await hass.async_block_till_done()
await hass.async_block_till_done()
assert network.heal.called
assert len(network.heal.mock_calls) == 1

View file

@ -8,7 +8,6 @@ from unittest.mock import Mock, patch
import asynctest
import pytest
from homeassistant.components import group
from homeassistant.const import ENTITY_MATCH_ALL
import homeassistant.core as ha
from homeassistant.exceptions import PlatformNotReady
@ -285,15 +284,13 @@ async def test_extract_from_service_filter_out_non_existing_entities(hass):
async def test_extract_from_service_no_group_expand(hass):
"""Test not expanding a group."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
test_group = await group.Group.async_create_group(
hass, "test_group", ["light.Ceiling", "light.Kitchen"]
)
await component.async_add_entities([test_group])
await component.async_add_entities([MockEntity(entity_id="group.test_group")])
call = ha.ServiceCall("test", "service", {"entity_id": ["group.test_group"]})
extracted = await component.async_extract_from_service(call, expand_group=False)
assert extracted == [test_group]
assert len(extracted) == 1
assert extracted[0].entity_id == "group.test_group"
async def test_setup_dependencies_platform(hass):

View file

@ -793,3 +793,44 @@ async def test_entity_disabled_by_integration(hass):
assert entry_default.disabled_by is None
entry_disabled = registry.async_get_or_create(DOMAIN, DOMAIN, "disabled")
assert entry_disabled.disabled_by == "integration"
async def test_entity_info_added_to_entity_registry(hass):
"""Test entity info is written to entity registry."""
component = EntityComponent(_LOGGER, DOMAIN, hass, timedelta(seconds=20))
entity_default = MockEntity(
unique_id="default",
capability_attributes={"max": 100},
supported_features=5,
device_class="mock-device-class",
)
await component.async_add_entities([entity_default])
registry = await hass.helpers.entity_registry.async_get_registry()
entry_default = registry.async_get_or_create(DOMAIN, DOMAIN, "default")
print(entry_default)
assert entry_default.capabilities == {"max": 100}
assert entry_default.supported_features == 5
assert entry_default.device_class == "mock-device-class"
async def test_override_restored_entities(hass):
"""Test that we allow overriding restored entities."""
registry = mock_registry(hass)
registry.async_get_or_create(
"test_domain", "test_domain", "1234", suggested_object_id="world"
)
hass.states.async_set("test_domain.world", "unavailable", {"restored": True})
component = EntityComponent(_LOGGER, DOMAIN, hass)
await component.async_add_entities(
[MockEntity(unique_id="1234", state="on", entity_id="test_domain.world")], True
)
state = hass.states.get("test_domain.world")
assert state.state == "on"

View file

@ -5,7 +5,8 @@ from unittest.mock import patch
import asynctest
import pytest
from homeassistant.core import callback, valid_entity_id
from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE
from homeassistant.core import CoreState, callback, valid_entity_id
from homeassistant.helpers import entity_registry
from tests.common import MockConfigEntry, flush_store, mock_registry
@ -57,6 +58,52 @@ def test_get_or_create_suggested_object_id(registry):
assert entry.entity_id == "light.beer"
def test_get_or_create_updates_data(registry):
"""Test that we update data in get_or_create."""
orig_config_entry = MockConfigEntry(domain="light")
orig_entry = registry.async_get_or_create(
"light",
"hue",
"5678",
config_entry=orig_config_entry,
device_id="mock-dev-id",
capabilities={"max": 100},
supported_features=5,
device_class="mock-device-class",
disabled_by=entity_registry.DISABLED_HASS,
)
assert orig_entry.config_entry_id == orig_config_entry.entry_id
assert orig_entry.device_id == "mock-dev-id"
assert orig_entry.capabilities == {"max": 100}
assert orig_entry.supported_features == 5
assert orig_entry.device_class == "mock-device-class"
assert orig_entry.disabled_by == entity_registry.DISABLED_HASS
new_config_entry = MockConfigEntry(domain="light")
new_entry = registry.async_get_or_create(
"light",
"hue",
"5678",
config_entry=new_config_entry,
device_id="new-mock-dev-id",
capabilities={"new-max": 100},
supported_features=10,
device_class="new-mock-device-class",
disabled_by=entity_registry.DISABLED_USER,
)
assert new_entry.config_entry_id == new_config_entry.entry_id
assert new_entry.device_id == "new-mock-dev-id"
assert new_entry.capabilities == {"new-max": 100}
assert new_entry.supported_features == 10
assert new_entry.device_class == "new-mock-device-class"
# Should not be updated
assert new_entry.disabled_by == entity_registry.DISABLED_HASS
def test_get_or_create_suggested_object_id_conflict_register(registry):
"""Test that we don't generate an entity id that is already registered."""
entry = registry.async_get_or_create(
@ -91,7 +138,15 @@ async def test_loading_saving_data(hass, registry):
orig_entry1 = registry.async_get_or_create("light", "hue", "1234")
orig_entry2 = registry.async_get_or_create(
"light", "hue", "5678", config_entry=mock_config
"light",
"hue",
"5678",
device_id="mock-dev-id",
config_entry=mock_config,
capabilities={"max": 100},
supported_features=5,
device_class="mock-device-class",
disabled_by=entity_registry.DISABLED_HASS,
)
assert len(registry.entities) == 2
@ -104,13 +159,17 @@ async def test_loading_saving_data(hass, registry):
# Ensure same order
assert list(registry.entities) == list(registry2.entities)
new_entry1 = registry.async_get_or_create("light", "hue", "1234")
new_entry2 = registry.async_get_or_create(
"light", "hue", "5678", config_entry=mock_config
)
new_entry2 = registry.async_get_or_create("light", "hue", "5678")
assert orig_entry1 == new_entry1
assert orig_entry2 == new_entry2
assert new_entry2.device_id == "mock-dev-id"
assert new_entry2.disabled_by == entity_registry.DISABLED_HASS
assert new_entry2.capabilities == {"max": 100}
assert new_entry2.supported_features == 5
assert new_entry2.device_class == "mock-device-class"
def test_generate_entity_considers_registered_entities(registry):
"""Test that we don't create entity id that are already registered."""
@ -417,3 +476,62 @@ async def test_disabled_by_system_options(registry):
"light", "hue", "BBBB", config_entry=mock_config, disabled_by="user"
)
assert entry2.disabled_by == "user"
async def test_restore_states(hass):
"""Test restoring states."""
hass.state = CoreState.not_running
registry = await entity_registry.async_get_registry(hass)
registry.async_get_or_create(
"light", "hue", "1234", suggested_object_id="simple",
)
# Should not be created
registry.async_get_or_create(
"light",
"hue",
"5678",
suggested_object_id="disabled",
disabled_by=entity_registry.DISABLED_HASS,
)
registry.async_get_or_create(
"light",
"hue",
"9012",
suggested_object_id="all_info_set",
capabilities={"max": 100},
supported_features=5,
device_class="mock-device-class",
)
hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {})
await hass.async_block_till_done()
simple = hass.states.get("light.simple")
assert simple is not None
assert simple.state == STATE_UNAVAILABLE
assert simple.attributes == {"restored": True}
disabled = hass.states.get("light.disabled")
assert disabled is None
all_info_set = hass.states.get("light.all_info_set")
assert all_info_set is not None
assert all_info_set.state == STATE_UNAVAILABLE
assert all_info_set.attributes == {
"max": 100,
"supported_features": 5,
"device_class": "mock-device-class",
"restored": True,
}
registry.async_remove("light.disabled")
registry.async_remove("light.simple")
registry.async_remove("light.all_info_set")
await hass.async_block_till_done()
assert hass.states.get("light.simple") is None
assert hass.states.get("light.disabled") is None
assert hass.states.get("light.all_info_set") is None