mirror of
https://github.com/home-assistant/core
synced 2024-10-07 09:48:00 +00:00
Simplify device registry (#77715)
* Simplify device registry * Fix test fixture * Update homeassistant/helpers/device_registry.py Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> * Update device_registry.py * Remove dead code Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
parent
7e100b64ea
commit
56278a4421
|
@ -1,11 +1,11 @@
|
|||
"""Provide a way to connect entities belonging to one device."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections import UserDict
|
||||
from collections.abc import Coroutine
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -48,11 +48,6 @@ ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30
|
|||
RUNTIME_ONLY_ATTRS = {"suggested_area"}
|
||||
|
||||
|
||||
class _DeviceIndex(NamedTuple):
|
||||
identifiers: dict[tuple[str, str], str]
|
||||
connections: dict[tuple[str, str], str]
|
||||
|
||||
|
||||
class DeviceEntryDisabler(StrEnum):
|
||||
"""What disabled a device entry."""
|
||||
|
||||
|
@ -149,23 +144,6 @@ def format_mac(mac: str) -> str:
|
|||
return mac
|
||||
|
||||
|
||||
def _async_get_device_id_from_index(
|
||||
devices_index: _DeviceIndex,
|
||||
identifiers: set[tuple[str, str]],
|
||||
connections: set[tuple[str, str]] | None,
|
||||
) -> str | None:
|
||||
"""Check if device has previously been registered."""
|
||||
for identifier in identifiers:
|
||||
if identifier in devices_index.identifiers:
|
||||
return devices_index.identifiers[identifier]
|
||||
if not connections:
|
||||
return None
|
||||
for connection in _normalize_connections(connections):
|
||||
if connection in devices_index.connections:
|
||||
return devices_index.connections[connection]
|
||||
return None
|
||||
|
||||
|
||||
class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
|
||||
"""Store entity registry data."""
|
||||
|
||||
|
@ -210,13 +188,69 @@ class DeviceRegistryStore(storage.Store[dict[str, list[dict[str, Any]]]]):
|
|||
return old_data
|
||||
|
||||
|
||||
_EntryTypeT = TypeVar("_EntryTypeT", DeviceEntry, DeletedDeviceEntry)
|
||||
|
||||
|
||||
class DeviceRegistryItems(UserDict[str, _EntryTypeT]):
|
||||
"""Container for device registry items, maps device id -> entry.
|
||||
|
||||
Maintains two additional indexes:
|
||||
- (connection_type, connection identifier) -> entry
|
||||
- (DOMAIN, identifier) -> entry
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the container."""
|
||||
super().__init__()
|
||||
self._connections: dict[tuple[str, str], _EntryTypeT] = {}
|
||||
self._identifiers: dict[tuple[str, str], _EntryTypeT] = {}
|
||||
|
||||
def __setitem__(self, key: str, entry: _EntryTypeT) -> None:
|
||||
"""Add an item."""
|
||||
if key in self:
|
||||
old_entry = self[key]
|
||||
for connection in old_entry.connections:
|
||||
del self._connections[connection]
|
||||
for identifier in old_entry.identifiers:
|
||||
del self._identifiers[identifier]
|
||||
# type ignore linked to mypy issue: https://github.com/python/mypy/issues/13596
|
||||
super().__setitem__(key, entry) # type: ignore[assignment]
|
||||
for connection in entry.connections:
|
||||
self._connections[connection] = entry
|
||||
for identifier in entry.identifiers:
|
||||
self._identifiers[identifier] = entry
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""Remove an item."""
|
||||
entry = self[key]
|
||||
for connection in entry.connections:
|
||||
del self._connections[connection]
|
||||
for identifier in entry.identifiers:
|
||||
del self._identifiers[identifier]
|
||||
super().__delitem__(key)
|
||||
|
||||
def get_entry(
|
||||
self,
|
||||
identifiers: set[tuple[str, str]],
|
||||
connections: set[tuple[str, str]] | None,
|
||||
) -> _EntryTypeT | None:
|
||||
"""Get entry from identifiers or connections."""
|
||||
for identifier in identifiers:
|
||||
if identifier in self._identifiers:
|
||||
return self._identifiers[identifier]
|
||||
if not connections:
|
||||
return None
|
||||
for connection in _normalize_connections(connections):
|
||||
if connection in self._connections:
|
||||
return self._connections[connection]
|
||||
return None
|
||||
|
||||
|
||||
class DeviceRegistry:
|
||||
"""Class to hold a registry of devices."""
|
||||
|
||||
devices: dict[str, DeviceEntry]
|
||||
deleted_devices: dict[str, DeletedDeviceEntry]
|
||||
_registered_index: _DeviceIndex
|
||||
_deleted_index: _DeviceIndex
|
||||
devices: DeviceRegistryItems[DeviceEntry]
|
||||
deleted_devices: DeviceRegistryItems[DeletedDeviceEntry]
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the device registry."""
|
||||
|
@ -228,7 +262,6 @@ class DeviceRegistry:
|
|||
atomic_writes=True,
|
||||
minor_version=STORAGE_VERSION_MINOR,
|
||||
)
|
||||
self._clear_index()
|
||||
|
||||
@callback
|
||||
def async_get(self, device_id: str) -> DeviceEntry | None:
|
||||
|
@ -242,12 +275,7 @@ class DeviceRegistry:
|
|||
connections: set[tuple[str, str]] | None = None,
|
||||
) -> DeviceEntry | None:
|
||||
"""Check if device is registered."""
|
||||
device_id = _async_get_device_id_from_index(
|
||||
self._registered_index, identifiers, connections
|
||||
)
|
||||
if device_id is None:
|
||||
return None
|
||||
return self.devices[device_id]
|
||||
return self.devices.get_entry(identifiers, connections)
|
||||
|
||||
def _async_get_deleted_device(
|
||||
self,
|
||||
|
@ -255,55 +283,7 @@ class DeviceRegistry:
|
|||
connections: set[tuple[str, str]] | None,
|
||||
) -> DeletedDeviceEntry | None:
|
||||
"""Check if device is deleted."""
|
||||
device_id = _async_get_device_id_from_index(
|
||||
self._deleted_index, identifiers, connections
|
||||
)
|
||||
if device_id is None:
|
||||
return None
|
||||
return self.deleted_devices[device_id]
|
||||
|
||||
def _add_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
|
||||
"""Add a device and index it."""
|
||||
if isinstance(device, DeletedDeviceEntry):
|
||||
devices_index = self._deleted_index
|
||||
self.deleted_devices[device.id] = device
|
||||
else:
|
||||
devices_index = self._registered_index
|
||||
self.devices[device.id] = device
|
||||
|
||||
_add_device_to_index(devices_index, device)
|
||||
|
||||
def _remove_device(self, device: DeviceEntry | DeletedDeviceEntry) -> None:
|
||||
"""Remove a device and remove it from the index."""
|
||||
if isinstance(device, DeletedDeviceEntry):
|
||||
devices_index = self._deleted_index
|
||||
self.deleted_devices.pop(device.id)
|
||||
else:
|
||||
devices_index = self._registered_index
|
||||
self.devices.pop(device.id)
|
||||
|
||||
_remove_device_from_index(devices_index, device)
|
||||
|
||||
def _update_device(self, old_device: DeviceEntry, new_device: DeviceEntry) -> None:
|
||||
"""Update a device and the index."""
|
||||
self.devices[new_device.id] = new_device
|
||||
|
||||
devices_index = self._registered_index
|
||||
_remove_device_from_index(devices_index, old_device)
|
||||
_add_device_to_index(devices_index, new_device)
|
||||
|
||||
def _clear_index(self) -> None:
|
||||
"""Clear the index."""
|
||||
self._registered_index = _DeviceIndex(identifiers={}, connections={})
|
||||
self._deleted_index = _DeviceIndex(identifiers={}, connections={})
|
||||
|
||||
def _rebuild_index(self) -> None:
|
||||
"""Create the index after loading devices."""
|
||||
self._clear_index()
|
||||
for device in self.devices.values():
|
||||
_add_device_to_index(self._registered_index, device)
|
||||
for deleted_device in self.deleted_devices.values():
|
||||
_add_device_to_index(self._deleted_index, deleted_device)
|
||||
return self.deleted_devices.get_entry(identifiers, connections)
|
||||
|
||||
@callback
|
||||
def async_get_or_create(
|
||||
|
@ -346,11 +326,11 @@ class DeviceRegistry:
|
|||
if deleted_device is None:
|
||||
device = DeviceEntry(is_new=True)
|
||||
else:
|
||||
self._remove_device(deleted_device)
|
||||
self.deleted_devices.pop(deleted_device.id)
|
||||
device = deleted_device.to_device_entry(
|
||||
config_entry_id, connections, identifiers
|
||||
)
|
||||
self._add_device(device)
|
||||
self.devices[device.id] = device
|
||||
|
||||
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
|
||||
manufacturer = default_manufacturer
|
||||
|
@ -516,7 +496,7 @@ class DeviceRegistry:
|
|||
return old
|
||||
|
||||
new = attr.evolve(old, **new_values)
|
||||
self._update_device(old, new)
|
||||
self.devices[device_id] = new
|
||||
|
||||
# If its only run time attributes (suggested_area)
|
||||
# that do not get saved we do not want to write
|
||||
|
@ -542,16 +522,13 @@ class DeviceRegistry:
|
|||
@callback
|
||||
def async_remove_device(self, device_id: str) -> None:
|
||||
"""Remove a device from the device registry."""
|
||||
device = self.devices[device_id]
|
||||
self._remove_device(device)
|
||||
self._add_device(
|
||||
DeletedDeviceEntry(
|
||||
config_entries=device.config_entries,
|
||||
connections=device.connections,
|
||||
identifiers=device.identifiers,
|
||||
id=device.id,
|
||||
orphaned_timestamp=None,
|
||||
)
|
||||
device = self.devices.pop(device_id)
|
||||
self.deleted_devices[device_id] = DeletedDeviceEntry(
|
||||
config_entries=device.config_entries,
|
||||
connections=device.connections,
|
||||
identifiers=device.identifiers,
|
||||
id=device.id,
|
||||
orphaned_timestamp=None,
|
||||
)
|
||||
for other_device in list(self.devices.values()):
|
||||
if other_device.via_device_id == device_id:
|
||||
|
@ -567,8 +544,8 @@ class DeviceRegistry:
|
|||
|
||||
data = await self._store.async_load()
|
||||
|
||||
devices = OrderedDict()
|
||||
deleted_devices = OrderedDict()
|
||||
devices: DeviceRegistryItems[DeviceEntry] = DeviceRegistryItems()
|
||||
deleted_devices: DeviceRegistryItems[DeletedDeviceEntry] = DeviceRegistryItems()
|
||||
|
||||
if data is not None:
|
||||
for device in data["devices"]:
|
||||
|
@ -607,7 +584,6 @@ class DeviceRegistry:
|
|||
|
||||
self.devices = devices
|
||||
self.deleted_devices = deleted_devices
|
||||
self._rebuild_index()
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self) -> None:
|
||||
|
@ -692,7 +668,7 @@ class DeviceRegistry:
|
|||
deleted_device.orphaned_timestamp + ORPHANED_DEVICE_KEEP_SECONDS
|
||||
< now_time
|
||||
):
|
||||
self._remove_device(deleted_device)
|
||||
del self.deleted_devices[deleted_device.id]
|
||||
|
||||
@callback
|
||||
def async_clear_area_id(self, area_id: str) -> None:
|
||||
|
@ -879,27 +855,3 @@ def _normalize_connections(connections: set[tuple[str, str]]) -> set[tuple[str,
|
|||
(key, format_mac(value)) if key == CONNECTION_NETWORK_MAC else (key, value)
|
||||
for key, value in connections
|
||||
}
|
||||
|
||||
|
||||
def _add_device_to_index(
|
||||
devices_index: _DeviceIndex,
|
||||
device: DeviceEntry | DeletedDeviceEntry,
|
||||
) -> None:
|
||||
"""Add a device to the index."""
|
||||
for identifier in device.identifiers:
|
||||
devices_index.identifiers[identifier] = device.id
|
||||
for connection in device.connections:
|
||||
devices_index.connections[connection] = device.id
|
||||
|
||||
|
||||
def _remove_device_from_index(
|
||||
devices_index: _DeviceIndex,
|
||||
device: DeviceEntry | DeletedDeviceEntry,
|
||||
) -> None:
|
||||
"""Remove a device from the index."""
|
||||
for identifier in device.identifiers:
|
||||
if identifier in devices_index.identifiers:
|
||||
del devices_index.identifiers[identifier]
|
||||
for connection in device.connections:
|
||||
if connection in devices_index.connections:
|
||||
del devices_index.connections[connection]
|
||||
|
|
|
@ -469,12 +469,15 @@ def mock_area_registry(hass, mock_entries=None):
|
|||
return registry
|
||||
|
||||
|
||||
def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None):
|
||||
def mock_device_registry(hass, mock_entries=None):
|
||||
"""Mock the Device Registry."""
|
||||
registry = device_registry.DeviceRegistry(hass)
|
||||
registry.devices = mock_entries or OrderedDict()
|
||||
registry.deleted_devices = mock_deleted_entries or OrderedDict()
|
||||
registry._rebuild_index()
|
||||
registry.devices = device_registry.DeviceRegistryItems()
|
||||
if mock_entries is None:
|
||||
mock_entries = {}
|
||||
for key, entry in mock_entries.items():
|
||||
registry.devices[key] = entry
|
||||
registry.deleted_devices = device_registry.DeviceRegistryItems()
|
||||
|
||||
hass.data[device_registry.DATA_REGISTRY] = registry
|
||||
return registry
|
||||
|
|
Loading…
Reference in a new issue