mirror of
https://github.com/home-assistant/core
synced 2024-10-05 15:17:19 +00:00
Storage entity registry (#16018)
* Split out storage delayed write * Update code using delayed save * Fix tests * Fix typing test * Add callback decorator * Migrate entity registry to storage helper * Make double loading protection easier * Lint * Fix tests * Ordered Dict
This commit is contained in:
parent
ef193b0f64
commit
8ec550d6e0
|
@ -6,15 +6,10 @@ identified by their domain, platform and a unique id provided by that platform.
|
|||
The Entity Registry will persist itself 10 seconds after a new entity is
|
||||
registered. Registering a new entity while a timer is in progress resets the
|
||||
timer.
|
||||
|
||||
After initializing, call EntityRegistry.async_ensure_loaded to load the data
|
||||
from disk.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
import logging
|
||||
import os
|
||||
import weakref
|
||||
|
||||
import attr
|
||||
|
@ -22,7 +17,7 @@ import attr
|
|||
from homeassistant.core import callback, split_entity_id, valid_entity_id
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import ensure_unique_string, slugify
|
||||
from homeassistant.util.yaml import load_yaml, save_yaml
|
||||
from homeassistant.util.yaml import load_yaml
|
||||
|
||||
PATH_REGISTRY = 'entity_registry.yaml'
|
||||
DATA_REGISTRY = 'entity_registry'
|
||||
|
@ -32,6 +27,9 @@ _UNDEF = object()
|
|||
DISABLED_HASS = 'hass'
|
||||
DISABLED_USER = 'user'
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = 'core.entity_registry'
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class RegistryEntry:
|
||||
|
@ -79,8 +77,7 @@ class EntityRegistry:
|
|||
"""Initialize the registry."""
|
||||
self.hass = hass
|
||||
self.entities = None
|
||||
self._load_task = None
|
||||
self._sched_save = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
def async_is_registered(self, entity_id):
|
||||
|
@ -199,71 +196,72 @@ class EntityRegistry:
|
|||
|
||||
return new
|
||||
|
||||
async def async_ensure_loaded(self):
|
||||
"""Load the registry from disk."""
|
||||
if self.entities is not None:
|
||||
return
|
||||
|
||||
if self._load_task is None:
|
||||
self._load_task = self.hass.async_add_job(self._async_load)
|
||||
|
||||
await self._load_task
|
||||
|
||||
async def _async_load(self):
|
||||
async def async_load(self):
|
||||
"""Load the entity registry."""
|
||||
path = self.hass.config.path(PATH_REGISTRY)
|
||||
data = await self.hass.helpers.storage.async_migrator(
|
||||
self.hass.config.path(PATH_REGISTRY), self._store,
|
||||
old_conf_load_func=load_yaml,
|
||||
old_conf_migrate_func=_async_migrate
|
||||
)
|
||||
entities = OrderedDict()
|
||||
|
||||
if os.path.isfile(path):
|
||||
data = await self.hass.async_add_job(load_yaml, path)
|
||||
|
||||
for entity_id, info in data.items():
|
||||
entities[entity_id] = RegistryEntry(
|
||||
entity_id=entity_id,
|
||||
config_entry_id=info.get('config_entry_id'),
|
||||
unique_id=info['unique_id'],
|
||||
platform=info['platform'],
|
||||
name=info.get('name'),
|
||||
disabled_by=info.get('disabled_by')
|
||||
if data is not None:
|
||||
for entity in data['entities']:
|
||||
entities[entity['entity_id']] = RegistryEntry(
|
||||
entity_id=entity['entity_id'],
|
||||
config_entry_id=entity.get('config_entry_id'),
|
||||
unique_id=entity['unique_id'],
|
||||
platform=entity['platform'],
|
||||
name=entity.get('name'),
|
||||
disabled_by=entity.get('disabled_by')
|
||||
)
|
||||
|
||||
self.entities = entities
|
||||
self._load_task = None
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self):
|
||||
"""Schedule saving the entity registry."""
|
||||
if self._sched_save is not None:
|
||||
self._sched_save.cancel()
|
||||
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)
|
||||
|
||||
self._sched_save = self.hass.loop.call_later(
|
||||
SAVE_DELAY, self.hass.async_add_job, self._async_save
|
||||
)
|
||||
@callback
|
||||
def _data_to_save(self):
|
||||
"""Data of entity registry to store in a file."""
|
||||
data = {}
|
||||
|
||||
async def _async_save(self):
|
||||
"""Save the entity registry to a file."""
|
||||
self._sched_save = None
|
||||
data = OrderedDict()
|
||||
|
||||
for entry in self.entities.values():
|
||||
data[entry.entity_id] = {
|
||||
data['entities'] = [
|
||||
{
|
||||
'entity_id': entry.entity_id,
|
||||
'config_entry_id': entry.config_entry_id,
|
||||
'unique_id': entry.unique_id,
|
||||
'platform': entry.platform,
|
||||
'name': entry.name,
|
||||
}
|
||||
} for entry in self.entities.values()
|
||||
]
|
||||
|
||||
await self.hass.async_add_job(
|
||||
save_yaml, self.hass.config.path(PATH_REGISTRY), data)
|
||||
return data
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_registry(hass) -> EntityRegistry:
|
||||
"""Return entity registry instance."""
|
||||
registry = hass.data.get(DATA_REGISTRY)
|
||||
task = hass.data.get(DATA_REGISTRY)
|
||||
|
||||
if registry is None:
|
||||
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
|
||||
if task is None:
|
||||
async def _load_reg():
|
||||
registry = EntityRegistry(hass)
|
||||
await registry.async_load()
|
||||
return registry
|
||||
|
||||
await registry.async_ensure_loaded()
|
||||
return registry
|
||||
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
|
||||
|
||||
return await task
|
||||
|
||||
|
||||
async def _async_migrate(entities):
|
||||
"""Migrate the YAML config file to storage helper format."""
|
||||
return {
|
||||
'entities': [
|
||||
{'entity_id': entity_id, **info}
|
||||
for entity_id, info in entities.items()
|
||||
]
|
||||
}
|
||||
|
|
|
@ -15,7 +15,9 @@ _LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
@bind_hass
|
||||
async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None):
|
||||
async def async_migrator(hass, old_path, store, *,
|
||||
old_conf_load_func=json.load_json,
|
||||
old_conf_migrate_func=None):
|
||||
"""Helper function to migrate old data to a store and then load data.
|
||||
|
||||
async def old_conf_migrate_func(old_data)
|
||||
|
@ -25,7 +27,7 @@ async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None):
|
|||
if not os.path.isfile(old_path):
|
||||
return None
|
||||
|
||||
return json.load_json(old_path)
|
||||
return old_conf_load_func(old_path)
|
||||
|
||||
config = await hass.async_add_executor_job(load_old_config)
|
||||
|
||||
|
@ -52,7 +54,7 @@ class Store:
|
|||
self._data = None
|
||||
self._unsub_delay_listener = None
|
||||
self._unsub_stop_listener = None
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._write_lock = asyncio.Lock(loop=hass.loop)
|
||||
self._load_task = None
|
||||
|
||||
@property
|
||||
|
|
|
@ -307,7 +307,12 @@ def mock_registry(hass, mock_entries=None):
|
|||
"""Mock the Entity Registry."""
|
||||
registry = entity_registry.EntityRegistry(hass)
|
||||
registry.entities = mock_entries or {}
|
||||
hass.data[entity_registry.DATA_REGISTRY] = registry
|
||||
|
||||
async def _get_reg():
|
||||
return registry
|
||||
|
||||
hass.data[entity_registry.DATA_REGISTRY] = \
|
||||
hass.loop.create_task(_get_reg())
|
||||
return registry
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from homeassistant.components import light
|
|||
from homeassistant.helpers.intent import IntentHandleError
|
||||
|
||||
from tests.common import (
|
||||
async_mock_service, mock_service, get_test_home_assistant)
|
||||
async_mock_service, mock_service, get_test_home_assistant, mock_storage)
|
||||
|
||||
|
||||
class TestLight(unittest.TestCase):
|
||||
|
@ -333,10 +333,11 @@ class TestLight(unittest.TestCase):
|
|||
"group.all_lights.default,.4,.6,99\n"
|
||||
with mock.patch('os.path.isfile', side_effect=_mock_isfile):
|
||||
with mock.patch('builtins.open', side_effect=_mock_open):
|
||||
self.assertTrue(setup_component(
|
||||
self.hass, light.DOMAIN,
|
||||
{light.DOMAIN: {CONF_PLATFORM: 'test'}}
|
||||
))
|
||||
with mock_storage():
|
||||
self.assertTrue(setup_component(
|
||||
self.hass, light.DOMAIN,
|
||||
{light.DOMAIN: {CONF_PLATFORM: 'test'}}
|
||||
))
|
||||
|
||||
dev, _, _ = platform.DEVICES
|
||||
light.turn_on(self.hass, dev.entity_id)
|
||||
|
@ -371,10 +372,11 @@ class TestLight(unittest.TestCase):
|
|||
"light.ceiling_2.default,.6,.6,100\n"
|
||||
with mock.patch('os.path.isfile', side_effect=_mock_isfile):
|
||||
with mock.patch('builtins.open', side_effect=_mock_open):
|
||||
self.assertTrue(setup_component(
|
||||
self.hass, light.DOMAIN,
|
||||
{light.DOMAIN: {CONF_PLATFORM: 'test'}}
|
||||
))
|
||||
with mock_storage():
|
||||
self.assertTrue(setup_component(
|
||||
self.hass, light.DOMAIN,
|
||||
{light.DOMAIN: {CONF_PLATFORM: 'test'}}
|
||||
))
|
||||
|
||||
dev = next(filter(lambda x: x.entity_id == 'light.ceiling_2',
|
||||
platform.DEVICES))
|
||||
|
|
|
@ -5,13 +5,14 @@ from datetime import timedelta, datetime
|
|||
from unittest.mock import patch
|
||||
|
||||
import homeassistant.core as ha
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
import homeassistant.components.sensor as sensor
|
||||
from homeassistant.const import EVENT_STATE_CHANGED, STATE_UNAVAILABLE
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from tests.common import mock_mqtt_component, fire_mqtt_message, \
|
||||
assert_setup_component
|
||||
assert_setup_component, async_fire_mqtt_message, \
|
||||
async_mock_mqtt_component
|
||||
from tests.common import get_test_home_assistant, mock_component
|
||||
|
||||
|
||||
|
@ -331,27 +332,6 @@ class TestSensorMQTT(unittest.TestCase):
|
|||
state.attributes.get('val'))
|
||||
self.assertEqual('100', state.state)
|
||||
|
||||
def test_unique_id(self):
|
||||
"""Test unique id option only creates one sensor per unique_id."""
|
||||
assert setup_component(self.hass, sensor.DOMAIN, {
|
||||
sensor.DOMAIN: [{
|
||||
'platform': 'mqtt',
|
||||
'name': 'Test 1',
|
||||
'state_topic': 'test-topic',
|
||||
'unique_id': 'TOTALLY_UNIQUE'
|
||||
}, {
|
||||
'platform': 'mqtt',
|
||||
'name': 'Test 2',
|
||||
'state_topic': 'test-topic',
|
||||
'unique_id': 'TOTALLY_UNIQUE'
|
||||
}]
|
||||
})
|
||||
|
||||
fire_mqtt_message(self.hass, 'test-topic', 'payload')
|
||||
self.hass.block_till_done()
|
||||
|
||||
assert len(self.hass.states.all()) == 1
|
||||
|
||||
def test_invalid_device_class(self):
|
||||
"""Test device_class option with invalid value."""
|
||||
with assert_setup_component(0):
|
||||
|
@ -384,3 +364,26 @@ class TestSensorMQTT(unittest.TestCase):
|
|||
assert state.attributes['device_class'] == 'temperature'
|
||||
state = self.hass.states.get('sensor.test_2')
|
||||
assert 'device_class' not in state.attributes
|
||||
|
||||
|
||||
async def test_unique_id(hass):
|
||||
"""Test unique id option only creates one sensor per unique_id."""
|
||||
await async_mock_mqtt_component(hass)
|
||||
assert await async_setup_component(hass, sensor.DOMAIN, {
|
||||
sensor.DOMAIN: [{
|
||||
'platform': 'mqtt',
|
||||
'name': 'Test 1',
|
||||
'state_topic': 'test-topic',
|
||||
'unique_id': 'TOTALLY_UNIQUE'
|
||||
}, {
|
||||
'platform': 'mqtt',
|
||||
'name': 'Test 2',
|
||||
'state_topic': 'test-topic',
|
||||
'unique_id': 'TOTALLY_UNIQUE'
|
||||
}]
|
||||
})
|
||||
|
||||
async_fire_mqtt_message(hass, 'test-topic', 'payload')
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(hass.states.async_all()) == 1
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for the Entity Registry."""
|
||||
import asyncio
|
||||
from unittest.mock import patch, mock_open
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -61,29 +61,13 @@ def test_get_or_create_suggested_object_id_conflict_existing(hass, registry):
|
|||
@asyncio.coroutine
|
||||
def test_create_triggers_save(hass, registry):
|
||||
"""Test that registering entry triggers a save."""
|
||||
with patch.object(hass.loop, 'call_later') as mock_call_later:
|
||||
with patch.object(registry, 'async_schedule_save') as mock_schedule_save:
|
||||
registry.async_get_or_create('light', 'hue', '1234')
|
||||
|
||||
assert len(mock_call_later.mock_calls) == 1
|
||||
assert len(mock_schedule_save.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_save_timer_reset_on_subsequent_save(hass, registry):
|
||||
"""Test we reset the save timer on a new create."""
|
||||
with patch.object(hass.loop, 'call_later') as mock_call_later:
|
||||
registry.async_get_or_create('light', 'hue', '1234')
|
||||
|
||||
assert len(mock_call_later.mock_calls) == 1
|
||||
|
||||
with patch.object(hass.loop, 'call_later') as mock_call_later_2:
|
||||
registry.async_get_or_create('light', 'hue', '5678')
|
||||
|
||||
assert len(mock_call_later().cancel.mock_calls) == 1
|
||||
assert len(mock_call_later_2.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_loading_saving_data(hass, registry):
|
||||
async def test_loading_saving_data(hass, registry):
|
||||
"""Test that we load/save data correctly."""
|
||||
orig_entry1 = registry.async_get_or_create('light', 'hue', '1234')
|
||||
orig_entry2 = registry.async_get_or_create(
|
||||
|
@ -91,18 +75,11 @@ def test_loading_saving_data(hass, registry):
|
|||
|
||||
assert len(registry.entities) == 2
|
||||
|
||||
with patch(YAML__OPEN_PATH, mock_open(), create=True) as mock_write:
|
||||
yield from registry._async_save()
|
||||
|
||||
# Mock open calls are: open file, context enter, write, context leave
|
||||
written = mock_write.mock_calls[2][1][0]
|
||||
|
||||
# Now load written data in new registry
|
||||
registry2 = entity_registry.EntityRegistry(hass)
|
||||
registry2._store = registry._store
|
||||
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch(YAML__OPEN_PATH, mock_open(read_data=written), create=True):
|
||||
yield from registry2._async_load()
|
||||
await registry2.async_load()
|
||||
|
||||
# Ensure same order
|
||||
assert list(registry.entities) == list(registry2.entities)
|
||||
|
@ -139,32 +116,37 @@ def test_is_registered(registry):
|
|||
assert not registry.async_is_registered('light.non_existing')
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_loading_extra_values(hass):
|
||||
async def test_loading_extra_values(hass, hass_storage):
|
||||
"""Test we load extra data from the registry."""
|
||||
written = """
|
||||
test.named:
|
||||
platform: super_platform
|
||||
unique_id: with-name
|
||||
name: registry override
|
||||
test.no_name:
|
||||
platform: super_platform
|
||||
unique_id: without-name
|
||||
test.disabled_user:
|
||||
platform: super_platform
|
||||
unique_id: disabled-user
|
||||
disabled_by: user
|
||||
test.disabled_hass:
|
||||
platform: super_platform
|
||||
unique_id: disabled-hass
|
||||
disabled_by: hass
|
||||
"""
|
||||
hass_storage[entity_registry.STORAGE_KEY] = {
|
||||
'version': entity_registry.STORAGE_VERSION,
|
||||
'data': {
|
||||
'entities': [
|
||||
{
|
||||
'entity_id': 'test.named',
|
||||
'platform': 'super_platform',
|
||||
'unique_id': 'with-name',
|
||||
'name': 'registry override',
|
||||
}, {
|
||||
'entity_id': 'test.no_name',
|
||||
'platform': 'super_platform',
|
||||
'unique_id': 'without-name',
|
||||
}, {
|
||||
'entity_id': 'test.disabled_user',
|
||||
'platform': 'super_platform',
|
||||
'unique_id': 'disabled-user',
|
||||
'disabled_by': 'user',
|
||||
}, {
|
||||
'entity_id': 'test.disabled_hass',
|
||||
'platform': 'super_platform',
|
||||
'unique_id': 'disabled-hass',
|
||||
'disabled_by': 'hass',
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
registry = entity_registry.EntityRegistry(hass)
|
||||
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch(YAML__OPEN_PATH, mock_open(read_data=written), create=True):
|
||||
yield from registry._async_load()
|
||||
registry = await entity_registry.async_get_registry(hass)
|
||||
|
||||
entry_with_name = registry.async_get_or_create(
|
||||
'test', 'super_platform', 'with-name')
|
||||
|
@ -202,3 +184,31 @@ async def test_updating_config_entry_id(registry):
|
|||
'light', 'hue', '5678', config_entry_id='mock-id-2')
|
||||
assert entry.entity_id == entry2.entity_id
|
||||
assert entry2.config_entry_id == 'mock-id-2'
|
||||
|
||||
|
||||
async def test_migration(hass):
|
||||
"""Test migration from old data to new."""
|
||||
old_conf = {
|
||||
'light.kitchen': {
|
||||
'config_entry_id': 'test-config-id',
|
||||
'unique_id': 'test-unique',
|
||||
'platform': 'test-platform',
|
||||
'name': 'Test Name',
|
||||
'disabled_by': 'hass',
|
||||
}
|
||||
}
|
||||
with patch('os.path.isfile', return_value=True), patch('os.remove'), \
|
||||
patch('homeassistant.helpers.entity_registry.load_yaml',
|
||||
return_value=old_conf):
|
||||
registry = await entity_registry.async_get_registry(hass)
|
||||
|
||||
assert registry.async_is_registered('light.kitchen')
|
||||
entry = registry.async_get_or_create(
|
||||
domain='light',
|
||||
platform='test-platform',
|
||||
unique_id='test-unique',
|
||||
config_entry_id='test-config-id',
|
||||
)
|
||||
assert entry.name == 'Test Name'
|
||||
assert entry.disabled_by == 'hass'
|
||||
assert entry.config_entry_id == 'test-config-id'
|
||||
|
|
|
@ -141,11 +141,10 @@ async def test_migrator_no_existing_config(hass, store, hass_storage):
|
|||
async def test_migrator_existing_config(hass, store, hass_storage):
|
||||
"""Test migrating existing config."""
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch('os.remove') as mock_remove, \
|
||||
patch('homeassistant.util.json.load_json',
|
||||
return_value={'old': 'config'}):
|
||||
patch('os.remove') as mock_remove:
|
||||
data = await storage.async_migrator(
|
||||
hass, 'old-path', store)
|
||||
hass, 'old-path', store,
|
||||
old_conf_load_func=lambda _: {'old': 'config'})
|
||||
|
||||
assert len(mock_remove.mock_calls) == 1
|
||||
assert data == {'old': 'config'}
|
||||
|
@ -163,12 +162,11 @@ async def test_migrator_transforming_config(hass, store, hass_storage):
|
|||
return {'new': old_config['old']}
|
||||
|
||||
with patch('os.path.isfile', return_value=True), \
|
||||
patch('os.remove') as mock_remove, \
|
||||
patch('homeassistant.util.json.load_json',
|
||||
return_value={'old': 'config'}):
|
||||
patch('os.remove') as mock_remove:
|
||||
data = await storage.async_migrator(
|
||||
hass, 'old-path', store,
|
||||
old_conf_migrate_func=old_conf_migrate_func)
|
||||
old_conf_migrate_func=old_conf_migrate_func,
|
||||
old_conf_load_func=lambda _: {'old': 'config'})
|
||||
|
||||
assert len(mock_remove.mock_calls) == 1
|
||||
assert data == {'new': 'config'}
|
||||
|
|
Loading…
Reference in a new issue