Clean up device update, add via-hub (#16659)

* Clean up device update, add via-hub

* Test loading/saving data

* Lint

* Add to Hue"

* Lint + tests
This commit is contained in:
Paulus Schoutsen 2018-09-17 13:39:30 +02:00 committed by GitHub
parent 849a93e0a6
commit b8257866f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 269 additions and 78 deletions

View file

@ -40,6 +40,7 @@ def websocket_list_devices(hass, connection, msg):
'name': entry.name,
'sw_version': entry.sw_version,
'id': entry.id,
'hub_device_id': entry.hub_device_id,
} for entry in registry.devices.values()]
))

View file

@ -127,7 +127,7 @@ async def async_setup_entry(hass, config_entry):
device_registry = await \
hass.helpers.device_registry.async_get_registry()
device_registry.async_get_or_create(
config_entry=config_entry.entry_id,
config_entry_id=config_entry.entry_id,
connections={(CONNECTION_NETWORK_MAC, deconz.config.mac)},
identifiers={(DOMAIN, deconz.config.bridgeid)},
manufacturer='Dresden Elektronik', model=deconz.config.modelid,

View file

@ -140,7 +140,7 @@ async def async_setup_entry(hass, entry):
config = bridge.api.config
device_registry = await dr.async_get_registry(hass)
device_registry.async_get_or_create(
config_entry=entry.entry_id,
config_entry_id=entry.entry_id,
connections={
(dr.CONNECTION_NETWORK_MAC, config.mac)
},

View file

@ -302,6 +302,7 @@ class HueLight(Light):
'model': self.light.productname or self.light.modelid,
# Not yet exposed as properties in aiohue
'sw_version': self.light.raw['swversion'],
'via_hub': (hue.DOMAIN, self.bridge.api.config.bridgeid),
}
async def async_turn_on(self, **kwargs):

View file

@ -10,6 +10,7 @@ from homeassistant.core import callback
from homeassistant.loader import bind_hass
_LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DATA_REGISTRY = 'device_registry'
@ -32,6 +33,7 @@ class DeviceEntry:
model = attr.ib(type=str)
name = attr.ib(type=str, default=None)
sw_version = attr.ib(type=str, default=None)
hub_device_id = attr.ib(type=str, default=None)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
@ -54,28 +56,36 @@ class DeviceRegistry:
return None
@callback
def async_get_or_create(self, *, config_entry, connections, identifiers,
manufacturer, model, name=None, sw_version=None):
def async_get_or_create(self, *, config_entry_id, connections, identifiers,
manufacturer, model, name=None, sw_version=None,
via_hub=None):
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
return None
device = self.async_get_device(identifiers, connections)
if via_hub is not None:
hub_device = self.async_get_device({via_hub}, set())
hub_device_id = hub_device.id if hub_device else None
else:
hub_device_id = None
if device is not None:
if config_entry not in device.config_entries:
device.config_entries.add(config_entry)
self.async_schedule_save()
return device
return self._async_update_device(
device.id, config_entry_id=config_entry_id,
hub_device_id=hub_device_id
)
device = DeviceEntry(
config_entries=[config_entry],
config_entries={config_entry_id},
connections=connections,
identifiers=identifiers,
manufacturer=manufacturer,
model=model,
name=name,
sw_version=sw_version
sw_version=sw_version,
hub_device_id=hub_device_id
)
self.devices[device.id] = device
@ -83,24 +93,64 @@ class DeviceRegistry:
return device
@callback
def _async_update_device(self, device_id, *, config_entry_id=_UNDEF,
remove_config_entry_id=_UNDEF,
hub_device_id=_UNDEF):
"""Update device attributes."""
old = self.devices[device_id]
changes = {}
config_entries = old.config_entries
if (config_entry_id is not _UNDEF and
config_entry_id not in old.config_entries):
config_entries = old.config_entries | {config_entry_id}
if (remove_config_entry_id is not _UNDEF and
remove_config_entry_id in config_entries):
config_entries = set(config_entries)
config_entries.remove(remove_config_entry_id)
if config_entries is not old.config_entries:
changes['config_entries'] = config_entries
if (hub_device_id is not _UNDEF and
hub_device_id != old.hub_device_id):
changes['hub_device_id'] = hub_device_id
if not changes:
return old
new = self.devices[device_id] = attr.evolve(old, **changes)
self.async_schedule_save()
return new
async def async_load(self):
"""Load the device registry."""
devices = await self._store.async_load()
data = await self._store.async_load()
if devices is None:
self.devices = OrderedDict()
return
devices = OrderedDict()
self.devices = {device['id']: DeviceEntry(
config_entries=device['config_entries'],
connections={tuple(conn) for conn in device['connections']},
identifiers={tuple(iden) for iden in device['identifiers']},
if data is not None:
for device in data['devices']:
devices[device['id']] = DeviceEntry(
config_entries=set(device['config_entries']),
connections={tuple(conn) for conn
in device['connections']},
identifiers={tuple(iden) for iden
in device['identifiers']},
manufacturer=device['manufacturer'],
model=device['model'],
name=device['name'],
sw_version=device['sw_version'],
id=device['id'],
) for device in devices['devices']}
# Introduced in 0.79
hub_device_id=device.get('hub_device_id'),
)
self.devices = devices
@callback
def async_schedule_save(self):
@ -122,18 +172,19 @@ class DeviceRegistry:
'name': entry.name,
'sw_version': entry.sw_version,
'id': entry.id,
'hub_device_id': entry.hub_device_id,
} for entry in self.devices.values()
]
return data
@callback
def async_clear_config_entry(self, config_entry):
def async_clear_config_entry(self, config_entry_id):
"""Clear config entry from registry entries."""
for device in self.devices.values():
if config_entry in device.config_entries:
device.config_entries.remove(config_entry)
self.async_schedule_save()
for dev_id, device in self.devices.items():
if config_entry_id in device.config_entries:
self._async_update_device(
dev_id, remove_config_entry_id=config_entry_id)
@bind_hass

View file

@ -273,15 +273,18 @@ class EntityPlatform:
config_entry_id = None
device_info = entity.device_info
if config_entry_id is not None and device_info is not None:
device = device_registry.async_get_or_create(
config_entry=config_entry_id,
connections=device_info.get('connections', []),
identifiers=device_info.get('identifiers', []),
config_entry_id=config_entry_id,
connections=device_info.get('connections') or set(),
identifiers=device_info.get('identifiers') or set(),
manufacturer=device_info.get('manufacturer'),
model=device_info.get('model'),
name=device_info.get('name'),
sw_version=device_info.get('sw_version'))
sw_version=device_info.get('sw_version'),
via_hub=device_info.get('via_hub'))
if device:
device_id = device.id
else:
device_id = None

View file

@ -31,7 +31,7 @@ STORAGE_VERSION = 1
STORAGE_KEY = 'core.entity_registry'
@attr.s(slots=True)
@attr.s(slots=True, frozen=True)
class RegistryEntry:
"""Entity Registry Entry."""
@ -113,14 +113,9 @@ class EntityRegistry:
"""Get entity. Create if it doesn't exist."""
entity_id = self.async_get_entity_id(domain, platform, unique_id)
if entity_id:
entry = self.entities[entity_id]
if entry.config_entry_id == config_entry_id:
return entry
self._async_update_entity(
return self._async_update_entity(
entity_id, config_entry_id=config_entry_id,
device_id=device_id)
return self.entities[entity_id]
entity_id = self.async_generate_entity_id(
domain, suggested_object_id or '{}_{}'.format(platform, unique_id))
@ -253,10 +248,9 @@ class EntityRegistry:
@callback
def async_clear_config_entry(self, config_entry):
"""Clear config entry from registry entries."""
for entry in self.entities.values():
for entity_id, entry in self.entities.items():
if config_entry == entry.config_entry_id:
entry.config_entry_id = None
self.async_schedule_save()
self._async_update_entity(entity_id, config_entry_id=None)
@bind_hass

View file

@ -763,6 +763,11 @@ class MockEntity(entity.Entity):
"""Return True if entity is available."""
return self._handle('available')
@property
def device_info(self):
"""Info how it links to a device."""
return self._handle('device_info')
def _handle(self, attr):
"""Return attribute value."""
if attr in self._values:

View file

@ -21,15 +21,16 @@ def registry(hass):
async def test_list_devices(hass, client, registry):
"""Test list entries."""
registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={},
identifiers={('bridgeid', '1234')},
manufacturer='manufacturer', model='model')
manufacturer='manufacturer', model='model',
via_hub=('bridgeid', '0123'))
await client.send_json({
'id': 5,
@ -37,8 +38,7 @@ async def test_list_devices(hass, client, registry):
})
msg = await client.receive_json()
for entry in msg['result']:
entry.pop('id')
dev1, dev2 = [entry.pop('id') for entry in msg['result']]
assert msg['result'] == [
{
@ -47,7 +47,8 @@ async def test_list_devices(hass, client, registry):
'manufacturer': 'manufacturer',
'model': 'model',
'name': None,
'sw_version': None
'sw_version': None,
'hub_device_id': None,
},
{
'config_entries': ['1234'],
@ -55,6 +56,7 @@ async def test_list_devices(hass, client, registry):
'manufacturer': 'manufacturer',
'model': 'model',
'name': None,
'sw_version': None
'sw_version': None,
'hub_device_id': dev1,
}
]

View file

@ -182,7 +182,7 @@ async def test_config_passed_to_config_entry(hass):
assert len(mock_registry.mock_calls) == 1
assert mock_registry.mock_calls[0][2] == {
'config_entry': entry.entry_id,
'config_entry_id': entry.entry_id,
'connections': {
('mac', 'mock-mac')
},
@ -192,7 +192,7 @@ async def test_config_passed_to_config_entry(hass):
'manufacturer': 'Signify',
'name': 'mock-name',
'model': 'mock-modelid',
'sw_version': 'mock-swversion'
'sw_version': 'mock-swversion',
}

View file

@ -2,7 +2,7 @@
import pytest
from homeassistant.helpers import device_registry
from tests.common import mock_device_registry
from tests.common import mock_device_registry, flush_store
@pytest.fixture
@ -14,41 +14,41 @@ def registry(hass):
async def test_get_or_create_returns_same_entry(registry):
"""Make sure we do not duplicate entries."""
entry = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry2 = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={('ethernet', '11:22:33:44:55:66:77:88')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry3 = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '1234')},
manufacturer='manufacturer', model='model')
assert len(registry.devices) == 1
assert entry is entry2
assert entry is entry3
assert entry.id == entry2.id
assert entry.id == entry3.id
assert entry.identifiers == {('bridgeid', '0123')}
async def test_requirement_for_identifier_or_connection(registry):
"""Make sure we do require some descriptor of device."""
entry = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers=set(),
manufacturer='manufacturer', model='model')
entry2 = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections=set(),
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry3 = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections=set(),
identifiers=set(),
manufacturer='manufacturer', model='model')
@ -62,25 +62,25 @@ async def test_requirement_for_identifier_or_connection(registry):
async def test_multiple_config_entries(registry):
"""Make sure we do not get duplicate entries."""
entry = registry.async_get_or_create(
config_entry='123',
config_entry_id='123',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry2 = registry.async_get_or_create(
config_entry='456',
config_entry_id='456',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry3 = registry.async_get_or_create(
config_entry='123',
config_entry_id='123',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
assert len(registry.devices) == 1
assert entry is entry2
assert entry is entry3
assert entry.config_entries == {'123', '456'}
assert entry.id == entry2.id
assert entry.id == entry3.id
assert entry2.config_entries == {'123', '456'}
async def test_loading_from_storage(hass, hass_storage):
@ -118,7 +118,7 @@ async def test_loading_from_storage(hass, hass_storage):
registry = await device_registry.async_get_registry(hass)
entry = registry.async_get_or_create(
config_entry='1234',
config_entry_id='1234',
connections={('Zigbee', '01.23.45.67.89')},
identifiers={('serial', '12:34:56:78:90:AB:CD:EF')},
manufacturer='manufacturer', model='model')
@ -129,25 +129,106 @@ async def test_loading_from_storage(hass, hass_storage):
async def test_removing_config_entries(registry):
"""Make sure we do not get duplicate entries."""
entry = registry.async_get_or_create(
config_entry='123',
config_entry_id='123',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry2 = registry.async_get_or_create(
config_entry='456',
config_entry_id='456',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
entry3 = registry.async_get_or_create(
config_entry='123',
config_entry_id='123',
connections={('ethernet', '34:56:78:90:AB:CD:EF:12')},
identifiers={('bridgeid', '4567')},
manufacturer='manufacturer', model='model')
assert len(registry.devices) == 2
assert entry is entry2
assert entry is not entry3
assert entry.config_entries == {'123', '456'}
assert entry.id == entry2.id
assert entry.id != entry3.id
assert entry2.config_entries == {'123', '456'}
registry.async_clear_config_entry('123')
entry = registry.async_get_device({('bridgeid', '0123')}, set())
entry3 = registry.async_get_device({('bridgeid', '4567')}, set())
assert entry.config_entries == {'456'}
assert entry3.config_entries == set()
async def test_specifying_hub_device_create(registry):
"""Test specifying a hub and updating."""
hub = registry.async_get_or_create(
config_entry_id='123',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('hue', '0123')},
manufacturer='manufacturer', model='hub')
light = registry.async_get_or_create(
config_entry_id='456',
connections=set(),
identifiers={('hue', '456')},
manufacturer='manufacturer', model='light',
via_hub=('hue', '0123'))
assert light.hub_device_id == hub.id
async def test_specifying_hub_device_update(registry):
"""Test specifying a hub and updating."""
light = registry.async_get_or_create(
config_entry_id='456',
connections=set(),
identifiers={('hue', '456')},
manufacturer='manufacturer', model='light',
via_hub=('hue', '0123'))
assert light.hub_device_id is None
hub = registry.async_get_or_create(
config_entry_id='123',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('hue', '0123')},
manufacturer='manufacturer', model='hub')
light = registry.async_get_or_create(
config_entry_id='456',
connections=set(),
identifiers={('hue', '456')},
manufacturer='manufacturer', model='light',
via_hub=('hue', '0123'))
assert light.hub_device_id == hub.id
async def test_loading_saving_data(hass, registry):
"""Test that we load/save data correctly."""
orig_hub = registry.async_get_or_create(
config_entry_id='123',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('hue', '0123')},
manufacturer='manufacturer', model='hub')
orig_light = registry.async_get_or_create(
config_entry_id='456',
connections=set(),
identifiers={('hue', '456')},
manufacturer='manufacturer', model='light',
via_hub=('hue', '0123'))
assert len(registry.devices) == 2
# Now load written data in new registry
registry2 = device_registry.DeviceRegistry(hass)
await flush_store(registry._store)
await registry2.async_load()
# Ensure same order
assert list(registry.devices) == list(registry2.devices)
new_hub = registry2.async_get_device({('hue', '0123')}, set())
new_light = registry2.async_get_device({('hue', '456')}, set())
assert orig_hub == new_hub
assert orig_light == new_light

View file

@ -676,3 +676,55 @@ async def test_entity_registry_updates_invalid_entity_id(hass):
assert hass.states.get('test_domain.world') is not None
assert hass.states.get('invalid_entity_id') is None
assert hass.states.get('diff_domain.world') is None
async def test_device_info_called(hass):
"""Test device info is forwarded correctly."""
registry = await hass.helpers.device_registry.async_get_registry()
hub = registry.async_get_or_create(
config_entry_id='123',
connections=set(),
identifiers={('hue', 'hub-id')},
manufacturer='manufacturer', model='hub'
)
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Mock setup entry method."""
async_add_entities([
# Invalid device info
MockEntity(unique_id='abcd', device_info={}),
# Valid device info
MockEntity(unique_id='qwer', device_info={
'identifiers': {('hue', '1234')},
'connections': {('mac', 'abcd')},
'manufacturer': 'test-manuf',
'model': 'test-model',
'name': 'test-name',
'sw_version': 'test-sw',
'via_hub': ('hue', 'hub-id'),
}),
])
return True
platform = MockPlatform(
async_setup_entry=async_setup_entry
)
config_entry = MockConfigEntry(entry_id='super-mock-id')
entity_platform = MockEntityPlatform(
hass,
platform_name=config_entry.domain,
platform=platform
)
assert await entity_platform.async_setup_entry(config_entry)
await hass.async_block_till_done()
device = registry.async_get_device({('hue', '1234')}, set())
assert device is not None
assert device.identifiers == {('hue', '1234')}
assert device.connections == {('mac', 'abcd')}
assert device.manufacturer == 'test-manuf'
assert device.model == 'test-model'
assert device.name == 'test-name'
assert device.sw_version == 'test-sw'
assert device.hub_device_id == hub.id

View file

@ -6,7 +6,7 @@ import pytest
from homeassistant.helpers import entity_registry
from tests.common import mock_registry
from tests.common import mock_registry, flush_store
YAML__OPEN_PATH = 'homeassistant.util.yaml.open'
@ -77,8 +77,7 @@ async def test_loading_saving_data(hass, registry):
# Now load written data in new registry
registry2 = entity_registry.EntityRegistry(hass)
registry2._store = registry._store
await flush_store(registry._store)
await registry2.async_load()
# Ensure same order
@ -192,6 +191,8 @@ async def test_removing_config_entry_id(registry):
'light', 'hue', '5678', config_entry_id='mock-id-1')
assert entry.config_entry_id == 'mock-id-1'
registry.async_clear_config_entry('mock-id-1')
entry = registry.entities[entry.entity_id]
assert entry.config_entry_id is None