diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index 8383e0cdc7da..88aa5727a972 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -40,6 +40,7 @@ def websocket_list_devices(hass, connection, msg): 'name': entry.name, 'sw_version': entry.sw_version, 'id': entry.id, + 'hub_device_id': entry.hub_device_id, } for entry in registry.devices.values()] )) diff --git a/homeassistant/components/deconz/__init__.py b/homeassistant/components/deconz/__init__.py index 6ed0a6e2c11a..82f4233a7da7 100644 --- a/homeassistant/components/deconz/__init__.py +++ b/homeassistant/components/deconz/__init__.py @@ -127,7 +127,7 @@ async def async_setup_entry(hass, config_entry): device_registry = await \ hass.helpers.device_registry.async_get_registry() device_registry.async_get_or_create( - config_entry=config_entry.entry_id, + config_entry_id=config_entry.entry_id, connections={(CONNECTION_NETWORK_MAC, deconz.config.mac)}, identifiers={(DOMAIN, deconz.config.bridgeid)}, manufacturer='Dresden Elektronik', model=deconz.config.modelid, diff --git a/homeassistant/components/hue/__init__.py b/homeassistant/components/hue/__init__.py index 38b521078f42..7a781c99f538 100644 --- a/homeassistant/components/hue/__init__.py +++ b/homeassistant/components/hue/__init__.py @@ -140,7 +140,7 @@ async def async_setup_entry(hass, entry): config = bridge.api.config device_registry = await dr.async_get_registry(hass) device_registry.async_get_or_create( - config_entry=entry.entry_id, + config_entry_id=entry.entry_id, connections={ (dr.CONNECTION_NETWORK_MAC, config.mac) }, diff --git a/homeassistant/components/light/hue.py b/homeassistant/components/light/hue.py index 6f6e0ed617e0..958abaca0331 100644 --- a/homeassistant/components/light/hue.py +++ b/homeassistant/components/light/hue.py @@ -302,6 +302,7 @@ class HueLight(Light): 'model': self.light.productname or self.light.modelid, # Not yet exposed as properties in aiohue 'sw_version': self.light.raw['swversion'], + 'via_hub': (hue.DOMAIN, self.bridge.api.config.bridgeid), } async def async_turn_on(self, **kwargs): diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index e6ff45af2fe0..8d4cd0a5bbf6 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -10,6 +10,7 @@ from homeassistant.core import callback from homeassistant.loader import bind_hass _LOGGER = logging.getLogger(__name__) +_UNDEF = object() DATA_REGISTRY = 'device_registry' @@ -32,6 +33,7 @@ class DeviceEntry: model = attr.ib(type=str) name = attr.ib(type=str, default=None) sw_version = attr.ib(type=str, default=None) + hub_device_id = attr.ib(type=str, default=None) id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) @@ -54,28 +56,36 @@ class DeviceRegistry: return None @callback - def async_get_or_create(self, *, config_entry, connections, identifiers, - manufacturer, model, name=None, sw_version=None): + def async_get_or_create(self, *, config_entry_id, connections, identifiers, + manufacturer, model, name=None, sw_version=None, + via_hub=None): """Get device. Create if it doesn't exist.""" if not identifiers and not connections: return None device = self.async_get_device(identifiers, connections) + if via_hub is not None: + hub_device = self.async_get_device({via_hub}, set()) + hub_device_id = hub_device.id if hub_device else None + else: + hub_device_id = None + if device is not None: - if config_entry not in device.config_entries: - device.config_entries.add(config_entry) - self.async_schedule_save() - return device + return self._async_update_device( + device.id, config_entry_id=config_entry_id, + hub_device_id=hub_device_id + ) device = DeviceEntry( - config_entries=[config_entry], + config_entries={config_entry_id}, connections=connections, identifiers=identifiers, manufacturer=manufacturer, model=model, name=name, - sw_version=sw_version + sw_version=sw_version, + hub_device_id=hub_device_id ) self.devices[device.id] = device @@ -83,24 +93,64 @@ class DeviceRegistry: return device + @callback + def _async_update_device(self, device_id, *, config_entry_id=_UNDEF, + remove_config_entry_id=_UNDEF, + hub_device_id=_UNDEF): + """Update device attributes.""" + old = self.devices[device_id] + + changes = {} + + config_entries = old.config_entries + + if (config_entry_id is not _UNDEF and + config_entry_id not in old.config_entries): + config_entries = old.config_entries | {config_entry_id} + + if (remove_config_entry_id is not _UNDEF and + remove_config_entry_id in config_entries): + config_entries = set(config_entries) + config_entries.remove(remove_config_entry_id) + + if config_entries is not old.config_entries: + changes['config_entries'] = config_entries + + if (hub_device_id is not _UNDEF and + hub_device_id != old.hub_device_id): + changes['hub_device_id'] = hub_device_id + + if not changes: + return old + + new = self.devices[device_id] = attr.evolve(old, **changes) + self.async_schedule_save() + return new + async def async_load(self): """Load the device registry.""" - devices = await self._store.async_load() + data = await self._store.async_load() - if devices is None: - self.devices = OrderedDict() - return + devices = OrderedDict() - self.devices = {device['id']: DeviceEntry( - config_entries=device['config_entries'], - connections={tuple(conn) for conn in device['connections']}, - identifiers={tuple(iden) for iden in device['identifiers']}, - manufacturer=device['manufacturer'], - model=device['model'], - name=device['name'], - sw_version=device['sw_version'], - id=device['id'], - ) for device in devices['devices']} + if data is not None: + for device in data['devices']: + devices[device['id']] = DeviceEntry( + config_entries=set(device['config_entries']), + connections={tuple(conn) for conn + in device['connections']}, + identifiers={tuple(iden) for iden + in device['identifiers']}, + manufacturer=device['manufacturer'], + model=device['model'], + name=device['name'], + sw_version=device['sw_version'], + id=device['id'], + # Introduced in 0.79 + hub_device_id=device.get('hub_device_id'), + ) + + self.devices = devices @callback def async_schedule_save(self): @@ -122,18 +172,19 @@ class DeviceRegistry: 'name': entry.name, 'sw_version': entry.sw_version, 'id': entry.id, + 'hub_device_id': entry.hub_device_id, } for entry in self.devices.values() ] return data @callback - def async_clear_config_entry(self, config_entry): + def async_clear_config_entry(self, config_entry_id): """Clear config entry from registry entries.""" - for device in self.devices.values(): - if config_entry in device.config_entries: - device.config_entries.remove(config_entry) - self.async_schedule_save() + for dev_id, device in self.devices.items(): + if config_entry_id in device.config_entries: + self._async_update_device( + dev_id, remove_config_entry_id=config_entry_id) @bind_hass diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 083a2946122f..f2913e373391 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -273,16 +273,19 @@ class EntityPlatform: config_entry_id = None device_info = entity.device_info + if config_entry_id is not None and device_info is not None: device = device_registry.async_get_or_create( - config_entry=config_entry_id, - connections=device_info.get('connections', []), - identifiers=device_info.get('identifiers', []), + config_entry_id=config_entry_id, + connections=device_info.get('connections') or set(), + identifiers=device_info.get('identifiers') or set(), manufacturer=device_info.get('manufacturer'), model=device_info.get('model'), name=device_info.get('name'), - sw_version=device_info.get('sw_version')) - device_id = device.id + sw_version=device_info.get('sw_version'), + via_hub=device_info.get('via_hub')) + if device: + device_id = device.id else: device_id = None diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index da3645a96fe8..01c8419dc040 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -31,7 +31,7 @@ STORAGE_VERSION = 1 STORAGE_KEY = 'core.entity_registry' -@attr.s(slots=True) +@attr.s(slots=True, frozen=True) class RegistryEntry: """Entity Registry Entry.""" @@ -113,14 +113,9 @@ class EntityRegistry: """Get entity. Create if it doesn't exist.""" entity_id = self.async_get_entity_id(domain, platform, unique_id) if entity_id: - entry = self.entities[entity_id] - if entry.config_entry_id == config_entry_id: - return entry - - self._async_update_entity( + return self._async_update_entity( entity_id, config_entry_id=config_entry_id, device_id=device_id) - return self.entities[entity_id] entity_id = self.async_generate_entity_id( domain, suggested_object_id or '{}_{}'.format(platform, unique_id)) @@ -253,10 +248,9 @@ class EntityRegistry: @callback def async_clear_config_entry(self, config_entry): """Clear config entry from registry entries.""" - for entry in self.entities.values(): + for entity_id, entry in self.entities.items(): if config_entry == entry.config_entry_id: - entry.config_entry_id = None - self.async_schedule_save() + self._async_update_entity(entity_id, config_entry_id=None) @bind_hass diff --git a/tests/common.py b/tests/common.py index 6629207b2885..56e86a4cd5c1 100644 --- a/tests/common.py +++ b/tests/common.py @@ -763,6 +763,11 @@ class MockEntity(entity.Entity): """Return True if entity is available.""" return self._handle('available') + @property + def device_info(self): + """Info how it links to a device.""" + return self._handle('device_info') + def _handle(self, attr): """Return attribute value.""" if attr in self._values: diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index 491319bf9279..f8ea51cfdc82 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -21,15 +21,16 @@ def registry(hass): async def test_list_devices(hass, client, registry): """Test list entries.""" registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={}, identifiers={('bridgeid', '1234')}, - manufacturer='manufacturer', model='model') + manufacturer='manufacturer', model='model', + via_hub=('bridgeid', '0123')) await client.send_json({ 'id': 5, @@ -37,8 +38,7 @@ async def test_list_devices(hass, client, registry): }) msg = await client.receive_json() - for entry in msg['result']: - entry.pop('id') + dev1, dev2 = [entry.pop('id') for entry in msg['result']] assert msg['result'] == [ { @@ -47,7 +47,8 @@ async def test_list_devices(hass, client, registry): 'manufacturer': 'manufacturer', 'model': 'model', 'name': None, - 'sw_version': None + 'sw_version': None, + 'hub_device_id': None, }, { 'config_entries': ['1234'], @@ -55,6 +56,7 @@ async def test_list_devices(hass, client, registry): 'manufacturer': 'manufacturer', 'model': 'model', 'name': None, - 'sw_version': None + 'sw_version': None, + 'hub_device_id': dev1, } ] diff --git a/tests/components/hue/test_init.py b/tests/components/hue/test_init.py index 1c4768746d5b..5da6d5b709aa 100644 --- a/tests/components/hue/test_init.py +++ b/tests/components/hue/test_init.py @@ -182,7 +182,7 @@ async def test_config_passed_to_config_entry(hass): assert len(mock_registry.mock_calls) == 1 assert mock_registry.mock_calls[0][2] == { - 'config_entry': entry.entry_id, + 'config_entry_id': entry.entry_id, 'connections': { ('mac', 'mock-mac') }, @@ -192,7 +192,7 @@ async def test_config_passed_to_config_entry(hass): 'manufacturer': 'Signify', 'name': 'mock-name', 'model': 'mock-modelid', - 'sw_version': 'mock-swversion' + 'sw_version': 'mock-swversion', } diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 5ae6b4df651f..b251846c4912 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -2,7 +2,7 @@ import pytest from homeassistant.helpers import device_registry -from tests.common import mock_device_registry +from tests.common import mock_device_registry, flush_store @pytest.fixture @@ -14,41 +14,41 @@ def registry(hass): async def test_get_or_create_returns_same_entry(registry): """Make sure we do not duplicate entries.""" entry = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry2 = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={('ethernet', '11:22:33:44:55:66:77:88')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry3 = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '1234')}, manufacturer='manufacturer', model='model') assert len(registry.devices) == 1 - assert entry is entry2 - assert entry is entry3 + assert entry.id == entry2.id + assert entry.id == entry3.id assert entry.identifiers == {('bridgeid', '0123')} async def test_requirement_for_identifier_or_connection(registry): """Make sure we do require some descriptor of device.""" entry = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers=set(), manufacturer='manufacturer', model='model') entry2 = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections=set(), identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry3 = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections=set(), identifiers=set(), manufacturer='manufacturer', model='model') @@ -62,25 +62,25 @@ async def test_requirement_for_identifier_or_connection(registry): async def test_multiple_config_entries(registry): """Make sure we do not get duplicate entries.""" entry = registry.async_get_or_create( - config_entry='123', + config_entry_id='123', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry2 = registry.async_get_or_create( - config_entry='456', + config_entry_id='456', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry3 = registry.async_get_or_create( - config_entry='123', + config_entry_id='123', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') assert len(registry.devices) == 1 - assert entry is entry2 - assert entry is entry3 - assert entry.config_entries == {'123', '456'} + assert entry.id == entry2.id + assert entry.id == entry3.id + assert entry2.config_entries == {'123', '456'} async def test_loading_from_storage(hass, hass_storage): @@ -118,7 +118,7 @@ async def test_loading_from_storage(hass, hass_storage): registry = await device_registry.async_get_registry(hass) entry = registry.async_get_or_create( - config_entry='1234', + config_entry_id='1234', connections={('Zigbee', '01.23.45.67.89')}, identifiers={('serial', '12:34:56:78:90:AB:CD:EF')}, manufacturer='manufacturer', model='model') @@ -129,25 +129,106 @@ async def test_loading_from_storage(hass, hass_storage): async def test_removing_config_entries(registry): """Make sure we do not get duplicate entries.""" entry = registry.async_get_or_create( - config_entry='123', + config_entry_id='123', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry2 = registry.async_get_or_create( - config_entry='456', + config_entry_id='456', connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, identifiers={('bridgeid', '0123')}, manufacturer='manufacturer', model='model') entry3 = registry.async_get_or_create( - config_entry='123', + config_entry_id='123', connections={('ethernet', '34:56:78:90:AB:CD:EF:12')}, identifiers={('bridgeid', '4567')}, manufacturer='manufacturer', model='model') assert len(registry.devices) == 2 - assert entry is entry2 - assert entry is not entry3 - assert entry.config_entries == {'123', '456'} + assert entry.id == entry2.id + assert entry.id != entry3.id + assert entry2.config_entries == {'123', '456'} + registry.async_clear_config_entry('123') + entry = registry.async_get_device({('bridgeid', '0123')}, set()) + entry3 = registry.async_get_device({('bridgeid', '4567')}, set()) + assert entry.config_entries == {'456'} assert entry3.config_entries == set() + + +async def test_specifying_hub_device_create(registry): + """Test specifying a hub and updating.""" + hub = registry.async_get_or_create( + config_entry_id='123', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('hue', '0123')}, + manufacturer='manufacturer', model='hub') + + light = registry.async_get_or_create( + config_entry_id='456', + connections=set(), + identifiers={('hue', '456')}, + manufacturer='manufacturer', model='light', + via_hub=('hue', '0123')) + + assert light.hub_device_id == hub.id + + +async def test_specifying_hub_device_update(registry): + """Test specifying a hub and updating.""" + light = registry.async_get_or_create( + config_entry_id='456', + connections=set(), + identifiers={('hue', '456')}, + manufacturer='manufacturer', model='light', + via_hub=('hue', '0123')) + + assert light.hub_device_id is None + + hub = registry.async_get_or_create( + config_entry_id='123', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('hue', '0123')}, + manufacturer='manufacturer', model='hub') + + light = registry.async_get_or_create( + config_entry_id='456', + connections=set(), + identifiers={('hue', '456')}, + manufacturer='manufacturer', model='light', + via_hub=('hue', '0123')) + + assert light.hub_device_id == hub.id + + +async def test_loading_saving_data(hass, registry): + """Test that we load/save data correctly.""" + orig_hub = registry.async_get_or_create( + config_entry_id='123', + connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}, + identifiers={('hue', '0123')}, + manufacturer='manufacturer', model='hub') + + orig_light = registry.async_get_or_create( + config_entry_id='456', + connections=set(), + identifiers={('hue', '456')}, + manufacturer='manufacturer', model='light', + via_hub=('hue', '0123')) + + assert len(registry.devices) == 2 + + # Now load written data in new registry + registry2 = device_registry.DeviceRegistry(hass) + await flush_store(registry._store) + await registry2.async_load() + + # Ensure same order + assert list(registry.devices) == list(registry2.devices) + + new_hub = registry2.async_get_device({('hue', '0123')}, set()) + new_light = registry2.async_get_device({('hue', '456')}, set()) + + assert orig_hub == new_hub + assert orig_light == new_light diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index b51219ddbed0..631d446d1861 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -676,3 +676,55 @@ async def test_entity_registry_updates_invalid_entity_id(hass): assert hass.states.get('test_domain.world') is not None assert hass.states.get('invalid_entity_id') is None assert hass.states.get('diff_domain.world') is None + + +async def test_device_info_called(hass): + """Test device info is forwarded correctly.""" + registry = await hass.helpers.device_registry.async_get_registry() + hub = registry.async_get_or_create( + config_entry_id='123', + connections=set(), + identifiers={('hue', 'hub-id')}, + manufacturer='manufacturer', model='hub' + ) + + async def async_setup_entry(hass, config_entry, async_add_entities): + """Mock setup entry method.""" + async_add_entities([ + # Invalid device info + MockEntity(unique_id='abcd', device_info={}), + # Valid device info + MockEntity(unique_id='qwer', device_info={ + 'identifiers': {('hue', '1234')}, + 'connections': {('mac', 'abcd')}, + 'manufacturer': 'test-manuf', + 'model': 'test-model', + 'name': 'test-name', + 'sw_version': 'test-sw', + 'via_hub': ('hue', 'hub-id'), + }), + ]) + return True + + platform = MockPlatform( + async_setup_entry=async_setup_entry + ) + config_entry = MockConfigEntry(entry_id='super-mock-id') + entity_platform = MockEntityPlatform( + hass, + platform_name=config_entry.domain, + platform=platform + ) + + assert await entity_platform.async_setup_entry(config_entry) + await hass.async_block_till_done() + + device = registry.async_get_device({('hue', '1234')}, set()) + assert device is not None + assert device.identifiers == {('hue', '1234')} + assert device.connections == {('mac', 'abcd')} + assert device.manufacturer == 'test-manuf' + assert device.model == 'test-model' + assert device.name == 'test-name' + assert device.sw_version == 'test-sw' + assert device.hub_device_id == hub.id diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index bb28287ddd81..a8c9086b2d2d 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -6,7 +6,7 @@ import pytest from homeassistant.helpers import entity_registry -from tests.common import mock_registry +from tests.common import mock_registry, flush_store YAML__OPEN_PATH = 'homeassistant.util.yaml.open' @@ -77,8 +77,7 @@ async def test_loading_saving_data(hass, registry): # Now load written data in new registry registry2 = entity_registry.EntityRegistry(hass) - registry2._store = registry._store - + await flush_store(registry._store) await registry2.async_load() # Ensure same order @@ -192,6 +191,8 @@ async def test_removing_config_entry_id(registry): 'light', 'hue', '5678', config_entry_id='mock-id-1') assert entry.config_entry_id == 'mock-id-1' registry.async_clear_config_entry('mock-id-1') + + entry = registry.entities[entry.entity_id] assert entry.config_entry_id is None