Avoid fetching metadata multiple times during stat compile (#70397)

J. Nick Koston 2022-04-22 00:25:42 -10:00 committed by GitHub
parent be0fbba523
commit 3737b58e85
4 changed files with 56 additions and 26 deletions
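
The change in a nutshell: statistics platforms already look up metadata while compiling, so they now return it alongside their results instead of letting the recorder fetch the same rows again. A minimal standalone sketch of that shape (simplified stand-in types and a hypothetical compile_platform helper, not the real recorder classes):

from __future__ import annotations

import dataclasses
from typing import Any

# Stand-ins for the recorder's StatisticResult / StatisticMetaData types.
StatisticResult = dict[str, Any]
StatisticMetaData = dict[str, Any]


@dataclasses.dataclass
class PlatformCompiledStatistics:
    """Compiled statistics plus the metadata used while compiling them."""

    platform_stats: list[StatisticResult]
    current_metadata: dict[str, tuple[int, StatisticMetaData]]


def compile_platform(
    metadata: dict[str, tuple[int, StatisticMetaData]]
) -> PlatformCompiledStatistics:
    """Hypothetical platform hook: compile stats and pass the metadata along."""
    stats = [{"meta": {"statistic_id": sid}, "stat": {}} for sid in metadata]
    return PlatformCompiledStatistics(stats, metadata)


compiled = compile_platform({"sensor.test1": (1, {"statistic_id": "sensor.test1"})})
# The recorder can now update metadata rows from compiled.current_metadata
# instead of issuing a second metadata query for the same statistic_ids.
assert compiled.current_metadata["sensor.test1"][0] == 1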

homeassistant/components/recorder/statistics.py

@@ -157,6 +157,14 @@ DISPLAY_UNIT_TO_STATISTIC_UNIT_CONVERSIONS: dict[
 _LOGGER = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class PlatformCompiledStatistics:
+    """Compiled Statistics from a platform."""
+
+    platform_stats: list[StatisticResult]
+    current_metadata: dict[str, tuple[int, StatisticMetaData]]
+
+
 def split_statistic_id(entity_id: str) -> list[str]:
     """Split a state entity ID into domain and object ID."""
     return entity_id.split(":", 1)
@@ -550,28 +558,32 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
     _LOGGER.debug("Compiling statistics for %s-%s", start, end)
     platform_stats: list[StatisticResult] = []
+    current_metadata: dict[str, tuple[int, StatisticMetaData]] = {}
     # Collect statistics from all platforms implementing support
     for domain, platform in instance.hass.data[DOMAIN].items():
         if not hasattr(platform, "compile_statistics"):
             continue
-        platform_stat = platform.compile_statistics(instance.hass, start, end)
-        _LOGGER.debug(
-            "Statistics for %s during %s-%s: %s", domain, start, end, platform_stat
+        compiled: PlatformCompiledStatistics = platform.compile_statistics(
+            instance.hass, start, end
         )
-        platform_stats.extend(platform_stat)
+        _LOGGER.debug(
+            "Statistics for %s during %s-%s: %s",
+            domain,
+            start,
+            end,
+            compiled.platform_stats,
+        )
+        platform_stats.extend(compiled.platform_stats)
+        current_metadata.update(compiled.current_metadata)
 
     # Insert collected statistics in the database
     with session_scope(
         session=instance.get_session(),  # type: ignore[misc]
         exception_filter=_filter_unique_constraint_integrity_error(instance),
     ) as session:
-        statistic_ids = [stats["meta"]["statistic_id"] for stats in platform_stats]
-        old_metadata_dict = get_metadata_with_session(
-            instance.hass, session, statistic_ids=statistic_ids
-        )
         for stats in platform_stats:
             metadata_id = _update_or_add_metadata(
-                session, stats["meta"], old_metadata_dict
+                session, stats["meta"], current_metadata
             )
             _insert_statistics(
                 session,
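
For clarity, current_metadata is keyed by statistic_id and maps to (metadata_id, metadata) pairs, so merging the per-platform dicts with dict.update() is all that is needed to replace the old per-compile get_metadata_with_session() call. A small sketch of that merge, with hypothetical sensor ids and plain dicts standing in for StatisticMetaData:

current_metadata: dict[str, tuple[int, dict]] = {}

# Each platform contributes its own slice of metadata, keyed by statistic_id.
sensor_platform_metadata = {"sensor.energy": (7, {"statistic_id": "sensor.energy"})}
other_platform_metadata = {"sensor.gas": (9, {"statistic_id": "sensor.gas"})}

for platform_metadata in (sensor_platform_metadata, other_platform_metadata):
    current_metadata.update(platform_metadata)

# _update_or_add_metadata() can now look up existing metadata_ids from this
# merged dict instead of re-reading them from the statistics_meta table.
assert current_metadata["sensor.gas"][0] == 9
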
@@ -1102,14 +1114,19 @@ def get_last_short_term_statistics(
 def get_latest_short_term_statistics(
-    hass: HomeAssistant, statistic_ids: list[str]
+    hass: HomeAssistant,
+    statistic_ids: list[str],
+    metadata: dict[str, tuple[int, StatisticMetaData]] | None = None,
 ) -> dict[str, list[dict]]:
     """Return the latest short term statistics for a list of statistic_ids."""
     # This function doesn't use a baked query, we instead rely on the
     # "Transparent SQL Compilation Caching" feature introduced in SQLAlchemy 1.4
     with session_scope(hass=hass) as session:
         # Fetch metadata for the given statistic_ids
-        metadata = get_metadata_with_session(hass, session, statistic_ids=statistic_ids)
         if not metadata:
+            metadata = get_metadata_with_session(
+                hass, session, statistic_ids=statistic_ids
+            )
+        if not metadata:
             return {}
         metadata_ids = [
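
The new metadata parameter is optional, so existing callers keep the old behaviour while callers that already hold the metadata (as the sensor platform does below) skip the second lookup. A standalone sketch of that fallback, with stand-in functions rather than the real session-bound helpers:

from __future__ import annotations


def fetch_metadata(statistic_ids: list[str]) -> dict[str, tuple[int, dict]]:
    """Stand-in for the database lookup done by get_metadata_with_session()."""
    return {sid: (1, {"statistic_id": sid}) for sid in statistic_ids}


def latest_short_term_statistics(
    statistic_ids: list[str],
    metadata: dict[str, tuple[int, dict]] | None = None,
) -> dict[str, list[dict]]:
    if not metadata:  # only hit the database when the caller supplied nothing
        metadata = fetch_metadata(statistic_ids)
    if not metadata:
        return {}
    return {sid: [] for sid in metadata}  # the real code queries the stats table


# A caller that already fetched metadata avoids the second lookup entirely.
prefetched = fetch_metadata(["sensor.test1"])
assert latest_short_term_statistics(["sensor.test1"], metadata=prefetched) == {
    "sensor.test1": []
}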

homeassistant/components/sensor/recorder.py

@@ -387,14 +387,14 @@ def _last_reset_as_utc_isoformat(last_reset_s: Any, entity_id: str) -> str | Non
 def compile_statistics(
     hass: HomeAssistant, start: datetime.datetime, end: datetime.datetime
-) -> list[StatisticResult]:
+) -> statistics.PlatformCompiledStatistics:
     """Compile statistics for all entities during start-end.
 
     Note: This will query the database and must not be run in the event loop
     """
     with recorder_util.session_scope(hass=hass) as session:
-        result = _compile_statistics(hass, session, start, end)
-    return result
+        compiled = _compile_statistics(hass, session, start, end)
+    return compiled
 
 
 def _compile_statistics(  # noqa: C901
@@ -402,7 +402,7 @@ def _compile_statistics(  # noqa: C901
     session: Session,
     start: datetime.datetime,
     end: datetime.datetime,
-) -> list[StatisticResult]:
+) -> statistics.PlatformCompiledStatistics:
     """Compile statistics for all entities during start-end."""
     result: list[StatisticResult] = []
@@ -473,7 +473,9 @@ def _compile_statistics(  # noqa: C901
         if "sum" in wanted_statistics[entity_id]:
             to_query.append(entity_id)
 
-    last_stats = statistics.get_latest_short_term_statistics(hass, to_query)
+    last_stats = statistics.get_latest_short_term_statistics(
+        hass, to_query, metadata=old_metadatas
+    )
     for (  # pylint: disable=too-many-nested-blocks
         entity_id,
         unit,
@@ -609,7 +611,7 @@ def _compile_statistics(  # noqa: C901
         result.append({"meta": meta, "stat": stat})
 
-    return result
+    return statistics.PlatformCompiledStatistics(result, old_metadatas)
 
 
 def list_statistic_ids(
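
Taken together, the sensor platform now fetches metadata once (old_metadatas) and reuses it twice: as the metadata argument to get_latest_short_term_statistics and in the PlatformCompiledStatistics it returns to the recorder. A condensed, illustrative version of that flow with stand-in helpers (the real _compile_statistics also computes the statistics themselves):

def compile_sensor_statistics(entity_ids, fetch_metadata, fetch_last_stats):
    old_metadatas = fetch_metadata(entity_ids)  # the single metadata query
    last_stats = fetch_last_stats(entity_ids, metadata=old_metadatas)  # reused here
    result = [{"meta": {"statistic_id": eid}, "stat": {"sum": 0}} for eid in entity_ids]
    return result, old_metadatas  # and reused again by the recorder core


stats, metadatas = compile_sensor_statistics(
    ["sensor.energy"],
    fetch_metadata=lambda ids: {i: (1, {"statistic_id": i}) for i in ids},
    fetch_last_stats=lambda ids, metadata: {i: [] for i in metadata},
)
assert "sensor.energy" in metadatas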

tests/components/energy/test_sensor.py

@@ -106,7 +106,7 @@ async def test_cost_sensor_price_entity_total_increasing(
     """Test energy cost price from total_increasing type sensor entity."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
@@ -311,7 +311,7 @@ async def test_cost_sensor_price_entity_total(
     """Test energy cost price from total type sensor entity."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,
@@ -518,7 +518,7 @@ async def test_cost_sensor_price_entity_total_no_reset(
     """Test energy cost price from total type sensor entity with no last_reset."""
 
     def _compile_statistics(_):
-        return compile_statistics(hass, now, now + timedelta(seconds=1))
+        return compile_statistics(hass, now, now + timedelta(seconds=1)).platform_stats
 
     energy_attributes = {
         ATTR_UNIT_OF_MEASUREMENT: ENERGY_KILO_WATT_HOUR,

tests/components/recorder/test_statistics.py

@@ -124,6 +124,11 @@ def test_compile_hourly_statistics(hass_recorder):
     stats = get_latest_short_term_statistics(hass, ["sensor.test1"])
     assert stats == {"sensor.test1": [{**expected_2, "statistic_id": "sensor.test1"}]}
 
+    metadata = get_metadata(hass, statistic_ids=['sensor.test1"'])
+
+    stats = get_latest_short_term_statistics(hass, ["sensor.test1"], metadata=metadata)
+    assert stats == {"sensor.test1": [{**expected_2, "statistic_id": "sensor.test1"}]}
+
     stats = get_last_short_term_statistics(hass, 2, "sensor.test1", True)
     assert stats == {"sensor.test1": expected_stats1[::-1]}
@@ -156,11 +161,16 @@ def mock_sensor_statistics():
         }
 
     def get_fake_stats(_hass, start, _end):
-        return [
-            sensor_stats("sensor.test1", start),
-            sensor_stats("sensor.test2", start),
-            sensor_stats("sensor.test3", start),
-        ]
+        return statistics.PlatformCompiledStatistics(
+            [
+                sensor_stats("sensor.test1", start),
+                sensor_stats("sensor.test2", start),
+                sensor_stats("sensor.test3", start),
+            ],
+            get_metadata(
+                _hass, statistic_ids=["sensor.test1", "sensor.test2", "sensor.test3"]
+            ),
+        )
 
     with patch(
         "homeassistant.components.sensor.recorder.compile_statistics",
@@ -327,7 +337,8 @@ def test_statistics_duplicated(hass_recorder, caplog):
     assert "Statistics already compiled" not in caplog.text
 
     with patch(
-        "homeassistant.components.sensor.recorder.compile_statistics"
+        "homeassistant.components.sensor.recorder.compile_statistics",
+        return_value=statistics.PlatformCompiledStatistics([], {}),
     ) as compile_statistics:
         recorder.do_adhoc_statistics(start=zero)
         wait_recording_done(hass)
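
Because the recorder now reads .platform_stats and .current_metadata off whatever compile_statistics returns, a bare MagicMock return value would feed mock attributes into the compile loop; the patched test above therefore supplies an empty but correctly shaped PlatformCompiledStatistics. A minimal standalone illustration of the difference, using a local stand-in dataclass rather than the recorder import:

import dataclasses
from unittest.mock import MagicMock


@dataclasses.dataclass
class PlatformCompiledStatistics:  # stand-in for the recorder dataclass
    platform_stats: list
    current_metadata: dict


compile_statistics = MagicMock()
# Without an explicit return_value, attribute access just yields more mocks.
assert isinstance(compile_statistics().platform_stats, MagicMock)

compile_statistics = MagicMock(return_value=PlatformCompiledStatistics([], {}))
# With the explicit return_value, the compile loop sees real (empty) containers.
assert compile_statistics().platform_stats == []
assert compile_statistics().current_metadata == {}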