diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 53ad54c5ed18..0e4d80ac0801 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -355,6 +355,7 @@ class EntityPlatform: capabilities=entity.capability_attributes, supported_features=entity.supported_features, device_class=entity.device_class, + unit_of_measurement=entity.unit_of_measurement, ) entity.registry_entry = entry diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index acb155ae594a..635f7feba130 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -18,6 +18,7 @@ import attr from homeassistant.const import ( ATTR_DEVICE_CLASS, ATTR_SUPPORTED_FEATURES, + ATTR_UNIT_OF_MEASUREMENT, EVENT_HOMEASSISTANT_START, STATE_UNAVAILABLE, ) @@ -77,6 +78,7 @@ class RegistryEntry: capabilities: Optional[Dict[str, Any]] = attr.ib(default=None) supported_features: int = attr.ib(default=0) device_class: Optional[str] = attr.ib(default=None) + unit_of_measurement: Optional[str] = attr.ib(default=None) domain = attr.ib(type=str, init=False, repr=False) @domain.default @@ -164,6 +166,7 @@ class EntityRegistry: capabilities: Optional[Dict[str, Any]] = None, supported_features: Optional[int] = None, device_class: Optional[str] = None, + unit_of_measurement: Optional[str] = None, ) -> RegistryEntry: """Get entity. Create if it doesn't exist.""" config_entry_id = None @@ -180,6 +183,7 @@ class EntityRegistry: capabilities=capabilities or _UNDEF, supported_features=supported_features or _UNDEF, device_class=device_class or _UNDEF, + unit_of_measurement=unit_of_measurement or _UNDEF, # When we changed our slugify algorithm, we invalidated some # stored entity IDs with either a __ or ending in _. # Fix introduced in 0.86 (Jan 23, 2019). Next line can be @@ -210,6 +214,7 @@ class EntityRegistry: capabilities=capabilities, supported_features=supported_features or 0, device_class=device_class, + unit_of_measurement=unit_of_measurement, ) self.entities[entity_id] = entity _LOGGER.info("Registered new %s.%s entity: %s", domain, platform, entity_id) @@ -279,6 +284,7 @@ class EntityRegistry: capabilities=_UNDEF, supported_features=_UNDEF, device_class=_UNDEF, + unit_of_measurement=_UNDEF, ): """Private facing update properties method.""" old = self.entities[entity_id] @@ -293,6 +299,7 @@ class EntityRegistry: ("capabilities", capabilities), ("supported_features", supported_features), ("device_class", device_class), + ("unit_of_measurement", unit_of_measurement), ): if value is not _UNDEF and value != getattr(old, attr_name): changes[attr_name] = value @@ -369,6 +376,7 @@ class EntityRegistry: capabilities=entity.get("capabilities") or {}, supported_features=entity.get("supported_features", 0), device_class=entity.get("device_class"), + unit_of_measurement=entity.get("unit_of_measurement"), ) self.entities = entities @@ -395,6 +403,7 @@ class EntityRegistry: "capabilities": entry.capabilities, "supported_features": entry.supported_features, "device_class": entry.device_class, + "unit_of_measurement": entry.unit_of_measurement, } for entry in self.entities.values() ] @@ -511,6 +520,9 @@ def async_setup_entity_restore( if entry.device_class is not None: attrs[ATTR_DEVICE_CLASS] = entry.device_class + if entry.unit_of_measurement is not None: + attrs[ATTR_UNIT_OF_MEASUREMENT] = entry.unit_of_measurement + states.async_set(entry.entity_id, STATE_UNAVAILABLE, attrs) hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states) diff --git a/tests/common.py b/tests/common.py index fd40b08635f7..5a00a2bc7df4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -922,6 +922,11 @@ class MockEntity(entity.Entity): """Info how device should be classified.""" return self._handle("device_class") + @property + def unit_of_measurement(self): + """Info on the units the entity state is in.""" + return self._handle("unit_of_measurement") + @property def capability_attributes(self): """Info about capabilities.""" diff --git a/tests/components/homekit/test_type_sensors.py b/tests/components/homekit/test_type_sensors.py index 43533840cc6a..969ea0bddc82 100644 --- a/tests/components/homekit/test_type_sensors.py +++ b/tests/components/homekit/test_type_sensors.py @@ -1,4 +1,5 @@ """Test different accessory types: Sensors.""" +from homeassistant.components.homekit import get_accessory from homeassistant.components.homekit.const import ( PROP_CELSIUS, THRESHOLD_CO, @@ -17,6 +18,7 @@ from homeassistant.components.homekit.type_sensors import ( from homeassistant.const import ( ATTR_DEVICE_CLASS, ATTR_UNIT_OF_MEASUREMENT, + EVENT_HOMEASSISTANT_START, STATE_HOME, STATE_NOT_HOME, STATE_OFF, @@ -25,6 +27,8 @@ from homeassistant.const import ( TEMP_CELSIUS, TEMP_FAHRENHEIT, ) +from homeassistant.core import CoreState +from homeassistant.helpers import entity_registry async def test_temperature(hass, hk_driver): @@ -262,3 +266,34 @@ async def test_binary_device_classes(hass, hk_driver): acc = BinarySensor(hass, hk_driver, "Binary Sensor", entity_id, 2, None) assert acc.get_service(service).display_name == service assert acc.char_detected.display_name == char + + +async def test_sensor_restore(hass, hk_driver, events): + """Test setting up an entity from state in the event registry.""" + hass.state = CoreState.not_running + + registry = await entity_registry.async_get_registry(hass) + + registry.async_get_or_create( + "sensor", + "generic", + "1234", + suggested_object_id="temperature", + device_class="temperature", + ) + registry.async_get_or_create( + "sensor", + "generic", + "12345", + suggested_object_id="humidity", + device_class="humidity", + unit_of_measurement="%", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {}) + await hass.async_block_till_done() + + acc = get_accessory(hass, hk_driver, hass.states.get("sensor.temperature"), 2, {}) + assert acc.category == 10 + + acc = get_accessory(hass, hk_driver, hass.states.get("sensor.humidity"), 2, {}) + assert acc.category == 10 diff --git a/tests/helpers/test_entity_platform.py b/tests/helpers/test_entity_platform.py index 7797bf5057b8..8eea8ad004f8 100644 --- a/tests/helpers/test_entity_platform.py +++ b/tests/helpers/test_entity_platform.py @@ -804,6 +804,7 @@ async def test_entity_info_added_to_entity_registry(hass): capability_attributes={"max": 100}, supported_features=5, device_class="mock-device-class", + unit_of_measurement="%", ) await component.async_add_entities([entity_default]) @@ -815,6 +816,7 @@ async def test_entity_info_added_to_entity_registry(hass): assert entry_default.capabilities == {"max": 100} assert entry_default.supported_features == 5 assert entry_default.device_class == "mock-device-class" + assert entry_default.unit_of_measurement == "%" async def test_override_restored_entities(hass):