Reload config entry when entity enabled in entity registry, remove entity if disabled. (#26120)

* Reload config entry when disabled_by updated in entity registry

* Add types

* Remove entities that get disabled

* Remove unnecessary domain checks.

* Attach handler in async_setup

* Remove unused var

* Type

* Fix test

* Fix tests
This commit is contained in:
Paulus Schoutsen 2019-08-22 17:32:43 -07:00 committed by Andrew Sayre
parent a4eeaac24c
commit f704a8e90e
7 changed files with 219 additions and 12 deletions

View file

@ -3,13 +3,7 @@ import asyncio
import logging
import functools
import uuid
from typing import (
Any,
Callable,
List,
Optional,
Set, # noqa pylint: disable=unused-import
)
from typing import Any, Callable, List, Optional, Set
import weakref
import attr
@ -19,6 +13,7 @@ from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ConfigEntryNotReady
from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry
from homeassistant.helpers import entity_registry
# mypy: allow-untyped-defs
@ -161,8 +156,6 @@ class ConfigEntry:
try:
component = integration.get_component()
if self.domain == integration.domain:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error importing integration %s to set up %s config entry: %s",
@ -174,8 +167,20 @@ class ConfigEntry:
self.state = ENTRY_STATE_SETUP_ERROR
return
# Perform migration
if integration.domain == self.domain:
if self.domain == integration.domain:
try:
integration.get_platform("config_flow")
except ImportError as err:
_LOGGER.error(
"Error importing platform config_flow from integration %s to set up %s config entry: %s",
integration.domain,
self.domain,
err,
)
self.state = ENTRY_STATE_SETUP_ERROR
return
# Perform migration
if not await self.async_migrate(hass):
self.state = ENTRY_STATE_MIGRATION_ERROR
return
@ -383,6 +388,7 @@ class ConfigEntries:
self._hass_config = hass_config
self._entries = [] # type: List[ConfigEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
EntityRegistryDisabledHandler(hass).async_setup()
@callback
def async_domains(self) -> List[str]:
@ -757,3 +763,91 @@ class SystemOptions:
def as_dict(self):
"""Return dictionary version of this config entrys system options."""
return {"disable_new_entities": self.disable_new_entities}
class EntityRegistryDisabledHandler:
"""Handler to handle when entities related to config entries updating disabled_by."""
RELOAD_AFTER_UPDATE_DELAY = 30
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the handler."""
self.hass = hass
self.registry: Optional[entity_registry.EntityRegistry] = None
self.changed: Set[str] = set()
self._remove_call_later: Optional[Callable[[], None]] = None
@callback
def async_setup(self) -> None:
"""Set up the disable handler."""
self.hass.bus.async_listen(
entity_registry.EVENT_ENTITY_REGISTRY_UPDATED, self._handle_entry_updated
)
async def _handle_entry_updated(self, event):
"""Handle entity registry entry update."""
if (
event.data["action"] != "update"
or "disabled_by" not in event.data["changes"]
):
return
if self.registry is None:
self.registry = await entity_registry.async_get_registry(self.hass)
entity_entry = self.registry.async_get(event.data["entity_id"])
if (
# Stop if no entry found
entity_entry is None
# Stop if entry not connected to config entry
or entity_entry.config_entry_id is None
# Stop if the entry got disabled. In that case the entity handles it
# themselves.
or entity_entry.disabled_by
):
return
config_entry = self.hass.config_entries.async_get_entry(
entity_entry.config_entry_id
)
if config_entry.entry_id not in self.changed and await support_entry_unload(
self.hass, config_entry.domain
):
self.changed.add(config_entry.entry_id)
if not self.changed:
return
# We are going to delay reloading on *every* entity registry change so that
# if a user is happily clicking along, it will only reload at the end.
if self._remove_call_later:
self._remove_call_later()
self._remove_call_later = self.hass.helpers.event.async_call_later(
self.RELOAD_AFTER_UPDATE_DELAY, self._handle_reload
)
async def _handle_reload(self, _now):
"""Handle a reload."""
self._remove_call_later = None
to_reload = self.changed
self.changed = set()
_LOGGER.info(
"Reloading config entries because disabled_by changed in entity registry: %s",
", ".join(self.changed),
)
await asyncio.gather(
*[self.hass.config_entries.async_reload(entry_id) for entry_id in to_reload]
)
async def support_entry_unload(hass: HomeAssistant, domain: str) -> bool:
"""Test if a domain supports entry unloading."""
integration = await loader.async_get_integration(hass, domain)
component = integration.get_component()
return hasattr(component, "async_unload_entry")

View file

@ -503,6 +503,10 @@ class Entity:
old = self.registry_entry
self.registry_entry = ent_reg.async_get(data["entity_id"])
if self.registry_entry.disabled_by is not None:
await self.async_remove()
return
if self.registry_entry.entity_id == old.entity_id:
self.async_write_ha_state()
return

View file

@ -302,7 +302,7 @@ class EntityRegistry:
self.async_schedule_save()
data = {"action": "update", "entity_id": entity_id}
data = {"action": "update", "entity_id": entity_id, "changes": list(changes)}
if old.entity_id != entity_id:
data["old_entity_id"] = old.entity_id

View file

@ -163,6 +163,7 @@ async def test_update_entity(hass, client):
msg = await client.receive_json()
assert hass.states.get("test_domain.world") is None
assert registry.entities["test_domain.world"].disabled_by == "user"
# UPDATE DISABLED_BY TO NONE

View file

@ -526,3 +526,34 @@ async def test_warn_disabled(hass, caplog):
ent.async_write_ha_state()
assert hass.states.get("hello.world") is None
assert caplog.text == ""
async def test_disabled_in_entity_registry(hass):
"""Test entity is removed if we disable entity registry entry."""
entry = entity_registry.RegistryEntry(
entity_id="hello.world",
unique_id="test-unique-id",
platform="test-platform",
disabled_by="user",
)
registry = mock_registry(hass, {"hello.world": entry})
ent = entity.Entity()
ent.hass = hass
ent.entity_id = "hello.world"
ent.registry_entry = entry
ent.platform = MagicMock(platform_name="test-platform")
await ent.async_internal_added_to_hass()
ent.async_write_ha_state()
assert hass.states.get("hello.world") is None
entry2 = registry.async_update_entity("hello.world", disabled_by=None)
await hass.async_block_till_done()
assert entry2 != entry
assert ent.registry_entry == entry2
entry3 = registry.async_update_entity("hello.world", disabled_by="user")
await hass.async_block_till_done()
assert entry3 != entry2
assert ent.registry_entry == entry3

View file

@ -219,6 +219,7 @@ async def test_updating_config_entry_id(hass, registry, update_events):
assert update_events[0]["entity_id"] == entry.entity_id
assert update_events[1]["action"] == "update"
assert update_events[1]["entity_id"] == entry.entity_id
assert update_events[1]["changes"] == ["config_entry_id"]
async def test_removing_config_entry_id(hass, registry, update_events):

View file

@ -20,6 +20,7 @@ from tests.common import (
MockEntity,
mock_integration,
mock_entity_platform,
mock_registry,
)
@ -925,3 +926,78 @@ async def test_init_custom_integration(hass):
return_value=mock_coro(integration),
):
await hass.config_entries.flow.async_init("bla")
async def test_support_entry_unload(hass):
"""Test unloading entry."""
assert await config_entries.support_entry_unload(hass, "light")
assert not await config_entries.support_entry_unload(hass, "auth")
async def test_reload_entry_entity_registry_ignores_no_entry(hass):
"""Test reloading entry in entity registry skips if no config entry linked."""
handler = config_entries.EntityRegistryDisabledHandler(hass)
registry = mock_registry(hass)
# Test we ignore entities without config entry
entry = registry.async_get_or_create("light", "hue", "123")
registry.async_update_entity(entry.entity_id, disabled_by="user")
await hass.async_block_till_done()
assert not handler.changed
assert handler._remove_call_later is None
async def test_reload_entry_entity_registry_works(hass):
"""Test we schedule an entry to be reloaded if disabled_by is updated."""
handler = config_entries.EntityRegistryDisabledHandler(hass)
handler.async_setup()
registry = mock_registry(hass)
config_entry = MockConfigEntry(
domain="comp", state=config_entries.ENTRY_STATE_LOADED
)
config_entry.add_to_hass(hass)
mock_setup_entry = MagicMock(return_value=mock_coro(True))
mock_unload_entry = MagicMock(return_value=mock_coro(True))
mock_integration(
hass,
MockModule(
"comp",
async_setup_entry=mock_setup_entry,
async_unload_entry=mock_unload_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
# Only changing disabled_by should update trigger
entity_entry = registry.async_get_or_create(
"light", "hue", "123", config_entry=config_entry
)
registry.async_update_entity(entity_entry.entity_id, name="yo")
await hass.async_block_till_done()
assert not handler.changed
assert handler._remove_call_later is None
# Disable entity, we should not do anything, only act when enabled.
registry.async_update_entity(entity_entry.entity_id, disabled_by="user")
await hass.async_block_till_done()
assert not handler.changed
assert handler._remove_call_later is None
# Enable entity, check we are reloading config entry.
registry.async_update_entity(entity_entry.entity_id, disabled_by=None)
await hass.async_block_till_done()
assert handler.changed == {config_entry.entry_id}
assert handler._remove_call_later is not None
async_fire_time_changed(
hass,
dt.utcnow()
+ timedelta(
seconds=config_entries.EntityRegistryDisabledHandler.RELOAD_AFTER_UPDATE_DELAY
+ 1
),
)
await hass.async_block_till_done()
assert len(mock_unload_entry.mock_calls) == 1