Allow entities to indicate they should be disabled by default (#26011)

This commit is contained in:
Paulus Schoutsen 2019-08-16 16:17:16 -07:00 committed by GitHub
parent b5893a8a6e
commit 6c292846be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 1 deletions

View file

@ -216,6 +216,11 @@ class Entity:
"""Time that a context is considered recent."""
return timedelta(seconds=5)
@property
def entity_registry_enabled_default(self):
"""Return if the entity should be enabled when first added to the entity registry."""
return True
# DO NOT OVERWRITE
# These properties and methods are either managed by Home Assistant or they
# are used to perform a very specific function. Overwriting these may

View file

@ -8,6 +8,7 @@ from homeassistant.core import callback, valid_entity_id, split_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.util.async_ import run_callback_threadsafe, run_coroutine_threadsafe
from .entity_registry import DISABLED_INTEGRATION
from .event import async_track_time_interval, async_call_later
@ -333,6 +334,10 @@ class EntityPlatform:
if device:
device_id = device.id
disabled_by: Optional[str] = None
if not entity.entity_registry_enabled_default:
disabled_by = DISABLED_INTEGRATION
entry = entity_registry.async_get_or_create(
self.domain,
self.platform_name,
@ -341,6 +346,7 @@ class EntityPlatform:
config_entry_id=config_entry_id,
device_id=device_id,
known_object_ids=self.entities.keys(),
disabled_by=disabled_by,
)
if entry.disabled:

View file

@ -35,6 +35,7 @@ _LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DISABLED_HASS = "hass"
DISABLED_USER = "user"
DISABLED_INTEGRATION = "integration"
STORAGE_VERSION = 1
STORAGE_KEY = "core.entity_registry"
@ -53,7 +54,9 @@ class RegistryEntry:
disabled_by = attr.ib(
type=str,
default=None,
validator=attr.validators.in_((DISABLED_HASS, DISABLED_USER, None)),
validator=attr.validators.in_(
(DISABLED_HASS, DISABLED_USER, DISABLED_INTEGRATION, None)
),
) # type: Optional[str]
domain = attr.ib(type=str, init=False, repr=False)
@ -132,6 +135,7 @@ class EntityRegistry:
config_entry_id=None,
device_id=None,
known_object_ids=None,
disabled_by=None,
):
"""Get entity. Create if it doesn't exist."""
entity_id = self.async_get_entity_id(domain, platform, unique_id)
@ -161,6 +165,7 @@ class EntityRegistry:
device_id=device_id,
unique_id=unique_id,
platform=platform,
disabled_by=disabled_by,
)
self.entities[entity_id] = entity
_LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id)

View file

@ -908,6 +908,11 @@ class MockEntity(entity.Entity):
"""Info how it links to a device."""
return self._handle("device_info")
@property
def entity_registry_enabled_default(self):
"""Return if the entity should be enabled when first added to the entity registry."""
return self._handle("entity_registry_enabled_default")
def _handle(self, attr):
"""Return attribute value."""
if attr in self._values:

View file

@ -775,3 +775,22 @@ async def test_device_info_not_overrides(hass):
assert device.id == device2.id
assert device2.manufacturer == "test-manufacturer"
assert device2.model == "test-model"
async def test_entity_disabled_by_integration(hass):
"""Test entity disabled by integration."""
component = EntityComponent(_LOGGER, DOMAIN, hass, timedelta(seconds=20))
entity_default = MockEntity(unique_id="default")
entity_disabled = MockEntity(
unique_id="disabled", entity_registry_enabled_default=False
)
await component.async_add_entities([entity_default, entity_disabled])
registry = await hass.helpers.entity_registry.async_get_registry()
entry_default = registry.async_get_or_create(DOMAIN, DOMAIN, "default")
assert entry_default.disabled_by is None
entry_disabled = registry.async_get_or_create(DOMAIN, DOMAIN, "disabled")
assert entry_disabled.disabled_by == "integration"

View file

@ -352,3 +352,17 @@ async def test_update_entity_unique_id_conflict(registry):
) as mock_schedule_save, pytest.raises(ValueError):
registry.async_update_entity(entry.entity_id, new_unique_id=entry2.unique_id)
assert mock_schedule_save.call_count == 0
async def test_disabled_by(registry):
"""Test that we can disable an entry when we create it."""
entry = registry.async_get_or_create("light", "hue", "5678", disabled_by="hass")
assert entry.disabled_by == "hass"
entry = registry.async_get_or_create(
"light", "hue", "5678", disabled_by="integration"
)
assert entry.disabled_by == "hass"
entry2 = registry.async_get_or_create("light", "hue", "1234")
assert entry2.disabled_by is None