Simplify device registry (#77715)

* Simplify device registry

* Fix test fixture

* Update homeassistant/helpers/device_registry.py

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>

* Update device_registry.py

* Remove dead code

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2022-09-03 12:50:55 +02:00 committed by GitHub
parent 7e100b64ea
commit 56278a4421
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 129 deletions

View file

@ -1,11 +1,11 @@
"""Provide a way to connect entities belonging to one device."""
from __future__ import annotations
from collections import OrderedDict
from collections import UserDict
from collections.abc import Coroutine
import logging
import time
from typing import TYPE_CHECKING, Any, NamedTuple, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast
import attr
@ -48,11 +48,6 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
RUNTIME_ONLY_ATTRS = {"suggested_area"}
class _DeviceIndex(NamedTuple):
identifiers: dict[tuple[str, str], str]
connections: dict[tuple[str, str], str]
class DeviceEntryDisabler(StrEnum):
"""What disabled a device entry."""
@ -149,23 +144,6 @@ def format_mac(mac: str) -> str:
return mac
def _async_get_device_id_from_index(
devices_index: _DeviceIndex,
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None,
) -> str | None:
"""Check if device has previously been registered."""
for identifier in identifiers:
if identifier in devices_index.identifiers:
return devices_index.identifiers[identifier]
if not connections:
return None
for connection in _normalize_connections(connections):
if connection in devices_index.connections:
return devices_index.connections[connection]
return None
class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
"""Store entity registry data."""
@ -210,13 +188,69 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
return old_data
_EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry)
class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
"""Container for device registry items, maps device id -> entry.
Maintains two additional indexes:
- (connection_type, connection identifier) -> entry
- (DOMAIN, identifier) -> entry
"""
def __init__(self) -> None:
"""Initialize the container."""
super().__init__()
self._connections: dict[tuple[str, str], _EntryTypeT] = {}
self._identifiers: dict[tuple[str, str], _EntryTypeT] = {}
def __setitem__(self, key: str, entry: _EntryTypeT) -> None:
"""Add an item."""
if key in self:
old_entry = self[key]
for connection in old_entry.connections:
del self._connections[connection]
for identifier in old_entry.identifiers:
del self._identifiers[identifier]
# type ignore linked to mypy issue: https://github.com/python/mypy/issues/13596
super().__setitem__(key, entry) # type: ignore[assignment]
for connection in entry.connections:
self._connections[connection] = entry
for identifier in entry.identifiers:
self._identifiers[identifier] = entry
def __delitem__(self, key: str) -> None:
"""Remove an item."""
entry = self[key]
for connection in entry.connections:
del self._connections[connection]
for identifier in entry.identifiers:
del self._identifiers[identifier]
super().__delitem__(key)
def get_entry(
self,
identifiers: set[tuple[str, str]],
connections: set[tuple[str, str]] | None,
) -> _EntryTypeT | None:
"""Get entry from identifiers or connections."""
for identifier in identifiers:
if identifier in self._identifiers:
return self._identifiers[identifier]
if not connections:
return None
for connection in _normalize_connections(connections):
if connection in self._connections:
return self._connections[connection]
return None
class DeviceRegistry:
"""Class to hold a registry of devices."""
devices: dict[str, DeviceEntry]
deleted_devices: dict[str, DeletedDeviceEntry]
_registered_index: _DeviceIndex
_deleted_index: _DeviceIndex
devices: DeviceRegistryItems[DeviceEntry]
deleted_devices: DeviceRegistryItems[DeletedDeviceEntry]
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the device registry."""
@ -228,7 +262,6 @@ class DeviceRegistry:
atomic_writes=True,
minor_version=STORAGE_VERSION_MINOR,
)
self._clear_index()
@callback
def async_get(self, device_id: str) -> DeviceEntry | None:
@ -242,12 +275,7 @@ class DeviceRegistry:
connections: set[tuple[str, str]] | None = None,
) -> DeviceEntry | None:
"""Check if device is registered."""
device_id = _async_get_device_id_from_index(
self._registered_index, identifiers, connections
)
if device_id is None:
return None
return self.devices[device_id]
return self.devices.get_entry(identifiers, connections)
def _async_get_deleted_device(
self,
@ -255,55 +283,7 @@ class DeviceRegistry:
connections: set[tuple[str, str]] | None,
) -> DeletedDeviceEntry | None:
"""Check if device is deleted."""
device_id = _async_get_device_id_from_index(
self._deleted_index, identifiers, connections
)
if device_id is None:
return None
return self.deleted_devices[device_id]
def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Add a device and index it."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._deleted_index
self.deleted_devices[device.id] = device
else:
devices_index = self._registered_index
self.devices[device.id] = device
_add_device_to_index(devices_index, device)
def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
"""Remove a device and remove it from the index."""
if isinstance(device, DeletedDeviceEntry):
devices_index = self._deleted_index
self.deleted_devices.pop(device.id)
else:
devices_index = self._registered_index
self.devices.pop(device.id)
_remove_device_from_index(devices_index, device)
def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None:
"""Update a device and the index."""
self.devices[new_device.id] = new_device
devices_index = self._registered_index
_remove_device_from_index(devices_index, old_device)
_add_device_to_index(devices_index, new_device)
def _clear_index(self) -> None:
"""Clear the index."""
self._registered_index = _DeviceIndex(identifiers={}, connections={})
self._deleted_index = _DeviceIndex(identifiers={}, connections={})
def _rebuild_index(self) -> None:
"""Create the index after loading devices."""
self._clear_index()
for device in self.devices.values():
_add_device_to_index(self._registered_index, device)
for deleted_device in self.deleted_devices.values():
_add_device_to_index(self._deleted_index, deleted_device)
return self.deleted_devices.get_entry(identifiers, connections)
@callback
def async_get_or_create(
@ -346,11 +326,11 @@ class DeviceRegistry:
if deleted_device is None:
device = DeviceEntry(is_new=True)
else:
self._remove_device(deleted_device)
self.deleted_devices.pop(deleted_device.id)
device = deleted_device.to_device_entry(
config_entry_id, connections, identifiers
)
self._add_device(device)
self.devices[device.id] = device
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
manufacturer = default_manufacturer
@ -516,7 +496,7 @@ class DeviceRegistry:
return old
new = attr.evolve(old, **new_values)
self._update_device(old, new)
self.devices[device_id] = new
# If its only run time attributes (suggested_area)
# that do not get saved we do not want to write
@ -542,16 +522,13 @@ class DeviceRegistry:
@callback
def async_remove_device(self, device_id: str) -> None:
"""Remove a device from the device registry."""
device = self.devices[device_id]
self._remove_device(device)
self._add_device(
DeletedDeviceEntry(
config_entries=device.config_entries,
connections=device.connections,
identifiers=device.identifiers,
id=device.id,
orphaned_timestamp=None,
)
device = self.devices.pop(device_id)
self.deleted_devices[device_id] = DeletedDeviceEntry(
config_entries=device.config_entries,
connections=device.connections,
identifiers=device.identifiers,
id=device.id,
orphaned_timestamp=None,
)
for other_device in list(self.devices.values()):
if other_device.via_device_id == device_id:
@ -567,8 +544,8 @@ class DeviceRegistry:
data = await self._store.async_load()
devices = OrderedDict()
deleted_devices = OrderedDict()
devices: DeviceRegistryItems[DeviceEntry] = DeviceRegistryItems()
deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] = DeviceRegistryItems()
if data is not None:
for device in data["devices"]:
@ -607,7 +584,6 @@ class DeviceRegistry:
self.devices = devices
self.deleted_devices = deleted_devices
self._rebuild_index()
@callback
def async_schedule_save(self) -> None:
@ -692,7 +668,7 @@ class DeviceRegistry:
deleted_device.orphaned_timestamp + ORPHANED_DEVICE_KEEP_SECONDS
< now_time
):
self._remove_device(deleted_device)
del self.deleted_devices[deleted_device.id]
@callback
def async_clear_area_id(self, area_id: str) -> None:
@ -879,27 +855,3 @@ def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str,
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
for key, value in connections
}
def _add_device_to_index(
devices_index: _DeviceIndex,
device: DeviceEntry | DeletedDeviceEntry,
) -> None:
"""Add a device to the index."""
for identifier in device.identifiers:
devices_index.identifiers[identifier] = device.id
for connection in device.connections:
devices_index.connections[connection] = device.id
def _remove_device_from_index(
devices_index: _DeviceIndex,
device: DeviceEntry | DeletedDeviceEntry,
) -> None:
"""Remove a device from the index."""
for identifier in device.identifiers:
if identifier in devices_index.identifiers:
del devices_index.identifiers[identifier]
for connection in device.connections:
if connection in devices_index.connections:
del devices_index.connections[connection]

View file

@ -469,12 +469,15 @@ def mock_area_registry(hass, mock_entries=None):
return registry
def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None):
def mock_device_registry(hass, mock_entries=None):
"""Mock the Device Registry."""
registry = device_registry.DeviceRegistry(hass)
registry.devices = mock_entries or OrderedDict()
registry.deleted_devices = mock_deleted_entries or OrderedDict()
registry._rebuild_index()
registry.devices = device_registry.DeviceRegistryItems()
if mock_entries is None:
mock_entries = {}
for key, entry in mock_entries.items():
registry.devices[key] = entry
registry.deleted_devices = device_registry.DeviceRegistryItems()
hass.data[device_registry.DATA_REGISTRY] = registry
return registry