diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 531444b9d1eb..b1786130b582 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -311,7 +311,9 @@ class Entity(ABC): start = timer() - attr = self.capability_attributes or {} + attr = self.capability_attributes + attr = dict(attr) if attr else {} + if not self.available: state = STATE_UNAVAILABLE else: diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index e171a4cade87..5fd88729f080 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -347,6 +347,9 @@ class EntityPlatform: device_id=device_id, known_object_ids=self.entities.keys(), disabled_by=disabled_by, + capabilities=entity.capability_attributes, + supported_features=entity.supported_features, + device_class=entity.device_class, ) entity.registry_entry = entry @@ -387,10 +390,16 @@ class EntityPlatform: # Make sure it is valid in case an entity set the value themselves if not valid_entity_id(entity.entity_id): raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}") - if ( - entity.entity_id in self.entities - or entity.entity_id in self.hass.states.async_entity_ids(self.domain) - ): + + already_exists = entity.entity_id in self.entities + + if not already_exists: + existing = self.hass.states.get(entity.entity_id) + + if existing and not existing.attributes.get("restored"): + already_exists = True + + if already_exists: msg = f"Entity id already exists: {entity.entity_id}" if entity.unique_id is not None: msg += ". Platform {} does not generate unique IDs".format( diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 5eb799658802..77d8ccc00e08 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -15,6 +15,12 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, cast import attr +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + ATTR_SUPPORTED_FEATURES, + EVENT_HOMEASSISTANT_START, + STATE_UNAVAILABLE, +) from homeassistant.core import Event, callback, split_entity_id, valid_entity_id from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.loader import bind_hass @@ -39,6 +45,8 @@ DISABLED_HASS = "hass" DISABLED_USER = "user" DISABLED_INTEGRATION = "integration" +ATTR_RESTORED = "restored" + STORAGE_VERSION = 1 STORAGE_KEY = "core.entity_registry" @@ -66,6 +74,9 @@ class RegistryEntry: ) ), ) + capabilities: Optional[Dict[str, Any]] = attr.ib(default=None) + supported_features: int = attr.ib(default=0) + device_class: Optional[str] = attr.ib(default=None) domain = attr.ib(type=str, init=False, repr=False) @domain.default @@ -142,11 +153,17 @@ class EntityRegistry: platform: str, unique_id: str, *, + # To influence entity ID generation suggested_object_id: Optional[str] = None, + known_object_ids: Optional[Iterable[str]] = None, + # To disable an entity if it gets created + disabled_by: Optional[str] = None, + # Data that we want entry to have config_entry: Optional["ConfigEntry"] = None, device_id: Optional[str] = None, - known_object_ids: Optional[Iterable[str]] = None, - disabled_by: Optional[str] = None, + capabilities: Optional[Dict[str, Any]] = None, + supported_features: Optional[int] = None, + device_class: Optional[str] = None, ) -> RegistryEntry: """Get entity. Create if it doesn't exist.""" config_entry_id = None @@ -160,6 +177,9 @@ class EntityRegistry: entity_id, config_entry_id=config_entry_id or _UNDEF, device_id=device_id or _UNDEF, + capabilities=capabilities or _UNDEF, + supported_features=supported_features or _UNDEF, + device_class=device_class or _UNDEF, # When we changed our slugify algorithm, we invalidated some # stored entity IDs with either a __ or ending in _. # Fix introduced in 0.86 (Jan 23, 2019). Next line can be @@ -187,6 +207,9 @@ class EntityRegistry: unique_id=unique_id, platform=platform, disabled_by=disabled_by, + capabilities=capabilities, + supported_features=supported_features or 0, + device_class=device_class, ) self.entities[entity_id] = entity _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) @@ -253,6 +276,9 @@ class EntityRegistry: device_id=_UNDEF, new_unique_id=_UNDEF, disabled_by=_UNDEF, + capabilities=_UNDEF, + supported_features=_UNDEF, + device_class=_UNDEF, ): """Private facing update properties method.""" old = self.entities[entity_id] @@ -264,6 +290,9 @@ class EntityRegistry: ("config_entry_id", config_entry_id), ("device_id", device_id), ("disabled_by", disabled_by), + ("capabilities", capabilities), + ("supported_features", supported_features), + ("device_class", device_class), ): if value is not _UNDEF and value != getattr(old, attr_name): changes[attr_name] = value @@ -318,6 +347,8 @@ class EntityRegistry: async def async_load(self) -> None: """Load the entity registry.""" + async_setup_entity_restore(self.hass, self) + data = await self.hass.helpers.storage.async_migrator( self.hass.config.path(PATH_REGISTRY), self._store, @@ -336,6 +367,9 @@ class EntityRegistry: platform=entity["platform"], name=entity.get("name"), disabled_by=entity.get("disabled_by"), + capabilities=entity.get("capabilities") or {}, + supported_features=entity.get("supported_features", 0), + device_class=entity.get("device_class"), ) self.entities = entities @@ -359,6 +393,9 @@ class EntityRegistry: "platform": entry.platform, "name": entry.name, "disabled_by": entry.disabled_by, + "capabilities": entry.capabilities, + "supported_features": entry.supported_features, + "device_class": entry.device_class, } for entry in self.entities.values() ] @@ -416,3 +453,53 @@ async def _async_migrate(entities: Dict[str, Any]) -> Dict[str, List[Dict[str, A {"entity_id": entity_id, **info} for entity_id, info in entities.items() ] } + + +@callback +def async_setup_entity_restore( + hass: HomeAssistantType, registry: EntityRegistry +) -> None: + """Set up the entity restore mechanism.""" + + @callback + def cleanup_restored_states(event: Event) -> None: + """Clean up restored states.""" + if event.data["action"] != "remove": + return + + state = hass.states.get(event.data["entity_id"]) + + if state is None or not state.attributes.get(ATTR_RESTORED): + return + + hass.states.async_remove(event.data["entity_id"]) + + hass.bus.async_listen(EVENT_ENTITY_REGISTRY_UPDATED, cleanup_restored_states) + + if hass.is_running: + return + + @callback + def _write_unavailable_states(_: Event) -> None: + """Make sure state machine contains entry for each registered entity.""" + states = hass.states + existing = set(states.async_entity_ids()) + + for entry in registry.entities.values(): + if entry.entity_id in existing or entry.disabled: + continue + + attrs: Dict[str, Any] = {ATTR_RESTORED: True} + + if entry.capabilities: + attrs.update(entry.capabilities) + + if entry.supported_features: + attrs[ATTR_SUPPORTED_FEATURES] = entry.supported_features + + if entry.device_class: + attrs[ATTR_DEVICE_CLASS] = entry.device_class + + states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs) + + hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states) diff --git a/tests/common.py b/tests/common.py index 5d13da74e880..e57710c46cc5 100644 --- a/tests/common.py +++ b/tests/common.py @@ -906,6 +906,11 @@ class MockEntity(entity.Entity): """Return the unique ID of the entity.""" return self._handle("unique_id") + @property + def state(self): + """Return the state of the entity.""" + return self._handle("state") + @property def available(self): """Return True if entity is available.""" @@ -916,6 +921,21 @@ class MockEntity(entity.Entity): """Info how it links to a device.""" return self._handle("device_info") + @property + def device_class(self): + """Info how device should be classified.""" + return self._handle("device_class") + + @property + def capability_attributes(self): + """Info about capabilities.""" + return self._handle("capability_attributes") + + @property + def supported_features(self): + """Info about supported features.""" + return self._handle("supported_features") + @property def entity_registry_enabled_default(self): """Return if the entity should be enabled when first added to the entity registry.""" diff --git a/tests/components/zwave/test_init.py b/tests/components/zwave/test_init.py index 36c918232205..b7ffaba7e420 100644 --- a/tests/components/zwave/test_init.py +++ b/tests/components/zwave/test_init.py @@ -130,6 +130,7 @@ async def test_auto_heal_midnight(hass, mock_openzwave): time = utc.localize(datetime(2017, 5, 6, 0, 0, 0)) async_fire_time_changed(hass, time) await hass.async_block_till_done() + await hass.async_block_till_done() assert network.heal.called assert len(network.heal.mock_calls) == 1 diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 81fbe2d65200..a069c050cf4f 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -8,7 +8,6 @@ from unittest.mock import Mock, patch import asynctest import pytest -from homeassistant.components import group from homeassistant.const import ENTITY_MATCH_ALL import homeassistant.core as ha from homeassistant.exceptions import PlatformNotReady @@ -285,15 +284,13 @@ async def test_extract_from_service_filter_out_non_existing_entities(hass): async def test_extract_from_service_no_group_expand(hass): """Test not expanding a group.""" component = EntityComponent(_LOGGER, DOMAIN, hass) - test_group = await group.Group.async_create_group( - hass, "test_group", ["light.Ceiling", "light.Kitchen"] - ) - await component.async_add_entities([test_group]) + await component.async_add_entities([MockEntity(entity_id="group.test_group")]) call = ha.ServiceCall("test", "service", {"entity_id": ["group.test_group"]}) extracted = await component.async_extract_from_service(call, expand_group=False) - assert extracted == [test_group] + assert len(extracted) == 1 + assert extracted[0].entity_id == "group.test_group" async def test_setup_dependencies_platform(hass): diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 5909dfaf3aaa..0f73699c8961 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -793,3 +793,44 @@ async def test_entity_disabled_by_integration(hass): assert entry_default.disabled_by is None entry_disabled = registry.async_get_or_create(DOMAIN, DOMAIN, "disabled") assert entry_disabled.disabled_by == "integration" + + +async def test_entity_info_added_to_entity_registry(hass): + """Test entity info is written to entity registry.""" + component = EntityComponent(_LOGGER, DOMAIN, hass, timedelta(seconds=20)) + + entity_default = MockEntity( + unique_id="default", + capability_attributes={"max": 100}, + supported_features=5, + device_class="mock-device-class", + ) + + await component.async_add_entities([entity_default]) + + registry = await hass.helpers.entity_registry.async_get_registry() + + entry_default = registry.async_get_or_create(DOMAIN, DOMAIN, "default") + print(entry_default) + assert entry_default.capabilities == {"max": 100} + assert entry_default.supported_features == 5 + assert entry_default.device_class == "mock-device-class" + + +async def test_override_restored_entities(hass): + """Test that we allow overriding restored entities.""" + registry = mock_registry(hass) + registry.async_get_or_create( + "test_domain", "test_domain", "1234", suggested_object_id="world" + ) + + hass.states.async_set("test_domain.world", "unavailable", {"restored": True}) + + component = EntityComponent(_LOGGER, DOMAIN, hass) + + await component.async_add_entities( + [MockEntity(unique_id="1234", state="on", entity_id="test_domain.world")], True + ) + + state = hass.states.get("test_domain.world") + assert state.state == "on" diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index b07c5237116b..7f45ff0d174d 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -5,7 +5,8 @@ from unittest.mock import patch import asynctest import pytest -from homeassistant.core import callback, valid_entity_id +from homeassistant.const import EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE +from homeassistant.core import CoreState, callback, valid_entity_id from homeassistant.helpers import entity_registry from tests.common import MockConfigEntry, flush_store, mock_registry @@ -57,6 +58,52 @@ def test_get_or_create_suggested_object_id(registry): assert entry.entity_id == "light.beer" +def test_get_or_create_updates_data(registry): + """Test that we update data in get_or_create.""" + orig_config_entry = MockConfigEntry(domain="light") + + orig_entry = registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=orig_config_entry, + device_id="mock-dev-id", + capabilities={"max": 100}, + supported_features=5, + device_class="mock-device-class", + disabled_by=entity_registry.DISABLED_HASS, + ) + + assert orig_entry.config_entry_id == orig_config_entry.entry_id + assert orig_entry.device_id == "mock-dev-id" + assert orig_entry.capabilities == {"max": 100} + assert orig_entry.supported_features == 5 + assert orig_entry.device_class == "mock-device-class" + assert orig_entry.disabled_by == entity_registry.DISABLED_HASS + + new_config_entry = MockConfigEntry(domain="light") + + new_entry = registry.async_get_or_create( + "light", + "hue", + "5678", + config_entry=new_config_entry, + device_id="new-mock-dev-id", + capabilities={"new-max": 100}, + supported_features=10, + device_class="new-mock-device-class", + disabled_by=entity_registry.DISABLED_USER, + ) + + assert new_entry.config_entry_id == new_config_entry.entry_id + assert new_entry.device_id == "new-mock-dev-id" + assert new_entry.capabilities == {"new-max": 100} + assert new_entry.supported_features == 10 + assert new_entry.device_class == "new-mock-device-class" + # Should not be updated + assert new_entry.disabled_by == entity_registry.DISABLED_HASS + + def test_get_or_create_suggested_object_id_conflict_register(registry): """Test that we don't generate an entity id that is already registered.""" entry = registry.async_get_or_create( @@ -91,7 +138,15 @@ async def test_loading_saving_data(hass, registry): orig_entry1 = registry.async_get_or_create("light", "hue", "1234") orig_entry2 = registry.async_get_or_create( - "light", "hue", "5678", config_entry=mock_config + "light", + "hue", + "5678", + device_id="mock-dev-id", + config_entry=mock_config, + capabilities={"max": 100}, + supported_features=5, + device_class="mock-device-class", + disabled_by=entity_registry.DISABLED_HASS, ) assert len(registry.entities) == 2 @@ -104,13 +159,17 @@ async def test_loading_saving_data(hass, registry): # Ensure same order assert list(registry.entities) == list(registry2.entities) new_entry1 = registry.async_get_or_create("light", "hue", "1234") - new_entry2 = registry.async_get_or_create( - "light", "hue", "5678", config_entry=mock_config - ) + new_entry2 = registry.async_get_or_create("light", "hue", "5678") assert orig_entry1 == new_entry1 assert orig_entry2 == new_entry2 + assert new_entry2.device_id == "mock-dev-id" + assert new_entry2.disabled_by == entity_registry.DISABLED_HASS + assert new_entry2.capabilities == {"max": 100} + assert new_entry2.supported_features == 5 + assert new_entry2.device_class == "mock-device-class" + def test_generate_entity_considers_registered_entities(registry): """Test that we don't create entity id that are already registered.""" @@ -417,3 +476,62 @@ async def test_disabled_by_system_options(registry): "light", "hue", "BBBB", config_entry=mock_config, disabled_by="user" ) assert entry2.disabled_by == "user" + + +async def test_restore_states(hass): + """Test restoring states.""" + hass.state = CoreState.not_running + + registry = await entity_registry.async_get_registry(hass) + + registry.async_get_or_create( + "light", "hue", "1234", suggested_object_id="simple", + ) + # Should not be created + registry.async_get_or_create( + "light", + "hue", + "5678", + suggested_object_id="disabled", + disabled_by=entity_registry.DISABLED_HASS, + ) + registry.async_get_or_create( + "light", + "hue", + "9012", + suggested_object_id="all_info_set", + capabilities={"max": 100}, + supported_features=5, + device_class="mock-device-class", + ) + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {}) + await hass.async_block_till_done() + + simple = hass.states.get("light.simple") + assert simple is not None + assert simple.state == STATE_UNAVAILABLE + assert simple.attributes == {"restored": True} + + disabled = hass.states.get("light.disabled") + assert disabled is None + + all_info_set = hass.states.get("light.all_info_set") + assert all_info_set is not None + assert all_info_set.state == STATE_UNAVAILABLE + assert all_info_set.attributes == { + "max": 100, + "supported_features": 5, + "device_class": "mock-device-class", + "restored": True, + } + + registry.async_remove("light.disabled") + registry.async_remove("light.simple") + registry.async_remove("light.all_info_set") + + await hass.async_block_till_done() + + assert hass.states.get("light.simple") is None + assert hass.states.get("light.disabled") is None + assert hass.states.get("light.all_info_set") is None