diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 1cc850f31e44..e87ee1ae2820 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -27,6 +27,7 @@ from .exceptions import HomeAssistantError from .helpers import ( area_registry, device_registry, + entity, entity_registry, issue_registry, recorder, @@ -236,6 +237,7 @@ async def load_registries(hass: core.HomeAssistant) -> None: platform.uname().processor # pylint: disable=expression-not-assigned # Load the registries and cache the result of platform.uname().processor + entity.async_setup(hass) await asyncio.gather( area_registry.async_load(hass), device_registry.async_load(hass), diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index a6cd56af7337..7070df6f9492 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -30,7 +30,7 @@ from homeassistant.const import ( MATCH_ALL, ) from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback -from homeassistant.helpers import entity_registry +from homeassistant.helpers.entity import entity_sources from homeassistant.helpers.event import ( async_track_time_change, async_track_time_interval, @@ -185,7 +185,7 @@ class Recorder(threading.Thread): self._queue_watch = threading.Event() self.engine: Engine | None = None self.run_history = RunHistory() - self._entity_registry = entity_registry.async_get(hass) + self._entity_sources = entity_sources(hass) # The entity_filter is exposed on the recorder instance so that # it can be used to see if an entity is being recorded and is called @@ -878,7 +878,7 @@ class Recorder(threading.Thread): dbstate = States.from_event(event) shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( event, - self._entity_registry, + self._entity_sources, self._exclude_attributes_by_domain, self.dialect_name, ) diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index 88a8478047f5..1cd130bbf901 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -41,7 +41,6 @@ from homeassistant.const import ( MAX_LENGTH_STATE_STATE, ) from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id -from homeassistant.helpers import entity_registry as er from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null import homeassistant.util.dt as dt_util from homeassistant.util.json import ( @@ -460,7 +459,7 @@ class StateAttributes(Base): @staticmethod def shared_attrs_bytes_from_event( event: Event, - entity_registry: er.EntityRegistry, + entity_sources: dict[str, dict[str, str]], exclude_attrs_by_domain: dict[str, set[str]], dialect: SupportedDialect | None, ) -> bytes: @@ -473,8 +472,8 @@ class StateAttributes(Base): exclude_attrs = set(ALL_DOMAIN_EXCLUDE_ATTRS) if base_platform_attrs := exclude_attrs_by_domain.get(domain): exclude_attrs |= base_platform_attrs - if (reg_ent := entity_registry.async_get(state.entity_id)) and ( - integration_attrs := exclude_attrs_by_domain.get(reg_ent.platform) + if (entity_info := entity_sources.get(state.entity_id)) and ( + integration_attrs := exclude_attrs_by_domain.get(entity_info["domain"]) ): exclude_attrs |= integration_attrs encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 85f5f381e91c..c4dfd7e9c5b5 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -57,11 +57,18 @@ SOURCE_PLATFORM_CONFIG = "platform_config" FLOAT_PRECISION = abs(int(math.floor(math.log10(abs(sys.float_info.epsilon))))) - 1 +@callback +def async_setup(hass: HomeAssistant) -> None: + """Set up entity sources.""" + hass.data[DATA_ENTITY_SOURCE] = {} + + @callback @bind_hass def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]: """Get the entity sources.""" - return hass.data.get(DATA_ENTITY_SOURCE, {}) + _entity_sources: dict[str, dict[str, str]] = hass.data[DATA_ENTITY_SOURCE] + return _entity_sources def generate_entity_id( @@ -868,7 +875,7 @@ class Entity(ABC): else: info["source"] = SOURCE_PLATFORM_CONFIG - self.hass.data.setdefault(DATA_ENTITY_SOURCE, {})[self.entity_id] = info + self.hass.data[DATA_ENTITY_SOURCE][self.entity_id] = info if self.registry_entry is not None: # This is an assert as it should never happen, but helps in tests diff --git a/tests/common.py b/tests/common.py index 380756db20f3..66875eb6e9f3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -247,6 +247,7 @@ async def async_test_home_assistant(event_loop, load_registries=True): ) # Load the registries + entity.async_setup(hass) if load_registries: with patch("homeassistant.helpers.storage.Store.async_load", return_value=None): await asyncio.gather( @@ -1087,6 +1088,11 @@ class MockEntity(entity.Entity): """Return the entity category.""" return self._handle("entity_category") + @property + def extra_state_attributes(self) -> Mapping[str, Any] | None: + """Return entity specific state attributes.""" + return self._handle("extra_state_attributes") + @property def has_entity_name(self) -> bool: """Return the has_entity_name name flag.""" diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 7d188f982f1b..3900585d544b 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -75,6 +75,8 @@ from .common import ( ) from tests.common import ( + MockEntity, + MockEntityPlatform, async_fire_time_changed, fire_time_changed, get_test_home_assistant, @@ -2037,12 +2039,6 @@ async def test_excluding_attributes_by_integration( """Test that an integration's recorder platform can exclude attributes.""" state = "restoring_from_db" attributes = {"test_attr": 5, "excluded": 10} - entry = entity_registry.async_get_or_create( - "test", - "fake_integration", - "recorder", - ) - entity_id = entry.entity_id mock_platform( hass, "fake_integration.recorder", @@ -2051,7 +2047,12 @@ async def test_excluding_attributes_by_integration( hass.config.components.add("fake_integration") hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "fake_integration"}) await hass.async_block_till_done() - hass.states.async_set(entity_id, state, attributes) + + entity_id = "test.fake_integration_recorder" + platform = MockEntityPlatform(hass, platform_name="fake_integration") + entity_platform = MockEntity(entity_id=entity_id, extra_state_attributes=attributes) + await platform.async_add_entities([entity_platform]) + await async_wait_recording_done(hass) with session_scope(hass=hass) as session: diff --git a/tests/components/recorder/test_models.py b/tests/components/recorder/test_models.py index 8089ea1ed7cb..a1ab4508042d 100644 --- a/tests/components/recorder/test_models.py +++ b/tests/components/recorder/test_models.py @@ -26,7 +26,6 @@ from homeassistant.const import EVENT_STATE_CHANGED import homeassistant.core as ha from homeassistant.core import HomeAssistant from homeassistant.exceptions import InvalidEntityFormatError -from homeassistant.helpers import entity_registry as er from homeassistant.util import dt, dt as dt_util @@ -50,7 +49,7 @@ def test_from_event_to_db_state() -> None: assert state.as_dict() == States.from_event(event).to_native().as_dict() -def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -> None: +def test_from_event_to_db_state_attributes() -> None: """Test converting event to db state attributes.""" attrs = {"this_attr": True} state = ha.State("sensor.temperature", "18", attrs) @@ -63,7 +62,7 @@ def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) - dialect = SupportedDialect.MYSQL db_attrs.shared_attrs = StateAttributes.shared_attrs_bytes_from_event( - event, entity_registry, {}, dialect + event, {}, {}, dialect ) assert db_attrs.to_native() == attrs