Support non-live database migration (#72433)

* Support non-live database migration

* Tweak startup order, add test

* Address review comments

* Fix typo

* Clarify comment about promoting dependencies

* Tweak

* Fix merge mistake

* Fix some tests

* Fix additional test

* Fix additional test

* Adjust tests

* Improve test coverage
This commit is contained in:
Erik Montnemery 2022-07-22 15:11:34 +02:00 committed by GitHub
parent 9d0a252ca7
commit fd6ffef52f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 993 additions and 61 deletions

View file

@ -24,7 +24,7 @@ from .const import (
SIGNAL_BOOTSTRAP_INTEGRATONS,
)
from .exceptions import HomeAssistantError
from .helpers import area_registry, device_registry, entity_registry
from .helpers import area_registry, device_registry, entity_registry, recorder
from .helpers.dispatcher import async_dispatcher_send
from .helpers.typing import ConfigType
from .setup import (
@ -66,6 +66,15 @@ LOGGING_INTEGRATIONS = {
# Error logging
"system_log",
"sentry",
}
FRONTEND_INTEGRATIONS = {
# Get the frontend up and running as soon as possible so problem
# integrations can be removed and database migration status is
# visible in frontend
"frontend",
}
RECORDER_INTEGRATIONS = {
# Setup after frontend
# To record data
"recorder",
}
@ -83,10 +92,6 @@ STAGE_1_INTEGRATIONS = {
"cloud",
# Ensure supervisor is available
"hassio",
# Get the frontend up and running as soon
# as possible so problem integrations can
# be removed
"frontend",
}
@ -504,11 +509,43 @@ async def _async_set_up_integrations(
_LOGGER.info("Domains to be set up: %s", domains_to_setup)
def _cache_uname_processor() -> None:
"""Cache the result of platform.uname().processor in the executor.
Multiple modules call this function at startup which
executes a blocking subprocess call. This is a problem for the
asyncio event loop. By primeing the cache of uname we can
avoid the blocking call in the event loop.
"""
platform.uname().processor # pylint: disable=expression-not-assigned
# Load the registries and cache the result of platform.uname().processor
await asyncio.gather(
device_registry.async_load(hass),
entity_registry.async_load(hass),
area_registry.async_load(hass),
hass.async_add_executor_job(_cache_uname_processor),
)
# Initialize recorder
if "recorder" in domains_to_setup:
recorder.async_initialize_recorder(hass)
# Load logging as soon as possible
if logging_domains := domains_to_setup & LOGGING_INTEGRATIONS:
_LOGGER.info("Setting up logging: %s", logging_domains)
await async_setup_multi_components(hass, logging_domains, config)
# Setup frontend
if frontend_domains := domains_to_setup & FRONTEND_INTEGRATIONS:
_LOGGER.info("Setting up frontend: %s", frontend_domains)
await async_setup_multi_components(hass, frontend_domains, config)
# Setup recorder
if recorder_domains := domains_to_setup & RECORDER_INTEGRATIONS:
_LOGGER.info("Setting up recorder: %s", recorder_domains)
await async_setup_multi_components(hass, recorder_domains, config)
# Start up debuggers. Start these first in case they want to wait.
if debuggers := domains_to_setup & DEBUGGER_INTEGRATIONS:
_LOGGER.debug("Setting up debuggers: %s", debuggers)
@ -518,7 +555,8 @@ async def _async_set_up_integrations(
stage_1_domains: set[str] = set()
# Find all dependencies of any dependency of any stage 1 integration that
# we plan on loading and promote them to stage 1
# we plan on loading and promote them to stage 1. This is done only to not
# get misleading log messages
deps_promotion: set[str] = STAGE_1_INTEGRATIONS
while deps_promotion:
old_deps_promotion = deps_promotion
@ -535,24 +573,13 @@ async def _async_set_up_integrations(
deps_promotion.update(dep_itg.all_dependencies)
stage_2_domains = domains_to_setup - logging_domains - debuggers - stage_1_domains
def _cache_uname_processor() -> None:
"""Cache the result of platform.uname().processor in the executor.
Multiple modules call this function at startup which
executes a blocking subprocess call. This is a problem for the
asyncio event loop. By primeing the cache of uname we can
avoid the blocking call in the event loop.
"""
platform.uname().processor # pylint: disable=expression-not-assigned
# Load the registries
await asyncio.gather(
device_registry.async_load(hass),
entity_registry.async_load(hass),
area_registry.async_load(hass),
hass.async_add_executor_job(_cache_uname_processor),
stage_2_domains = (
domains_to_setup
- logging_domains
- frontend_domains
- recorder_domains
- debuggers
- stage_1_domains
)
# Start setup

View file

@ -123,7 +123,6 @@ def is_entity_recorded(hass: HomeAssistant, entity_id: str) -> bool:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the recorder."""
hass.data[DOMAIN] = {}
exclude_attributes_by_domain: dict[str, set[str]] = {}
hass.data[EXCLUDE_ATTRIBUTES] = exclude_attributes_by_domain
conf = config[DOMAIN]

View file

@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util
from . import migration, statistics
from .const import (
DB_WORKER_PREFIX,
DOMAIN,
KEEPALIVE_TIME,
MAX_QUEUE_BACKLOG,
MYSQLDB_URL_PREFIX,
@ -166,7 +167,12 @@ class Recorder(threading.Thread):
self.db_max_retries = db_max_retries
self.db_retry_wait = db_retry_wait
self.engine_version: AwesomeVersion | None = None
# Database connection is ready, but non-live migration may be in progress
db_connected: asyncio.Future[bool] = hass.data[DOMAIN].db_connected
self.async_db_connected: asyncio.Future[bool] = db_connected
# Database is ready to use but live migration may be in progress
self.async_db_ready: asyncio.Future[bool] = asyncio.Future()
# Database is ready to use and all migration steps completed (used by tests)
self.async_recorder_ready = asyncio.Event()
self._queue_watch = threading.Event()
self.engine: Engine | None = None
@ -188,6 +194,7 @@ class Recorder(threading.Thread):
self._completed_first_database_setup: bool | None = None
self.async_migration_event = asyncio.Event()
self.migration_in_progress = False
self.migration_is_live = False
self._database_lock_task: DatabaseLockTask | None = None
self._db_executor: DBInterruptibleThreadPoolExecutor | None = None
self._exclude_attributes_by_domain = exclude_attributes_by_domain
@ -289,7 +296,8 @@ class Recorder(threading.Thread):
def _stop_executor(self) -> None:
"""Stop the executor."""
assert self._db_executor is not None
if self._db_executor is None:
return
self._db_executor.shutdown()
self._db_executor = None
@ -410,6 +418,7 @@ class Recorder(threading.Thread):
@callback
def async_connection_failed(self) -> None:
"""Connect failed tasks."""
self.async_db_connected.set_result(False)
self.async_db_ready.set_result(False)
persistent_notification.async_create(
self.hass,
@ -420,13 +429,29 @@ class Recorder(threading.Thread):
@callback
def async_connection_success(self) -> None:
"""Connect success tasks."""
"""Connect to the database succeeded, schema version and migration need known.
The database may not yet be ready for use in case of a non-live migration.
"""
self.async_db_connected.set_result(True)
@callback
def async_set_recorder_ready(self) -> None:
"""Database live and ready for use.
Called after non-live migration steps are finished.
"""
if self.async_db_ready.done():
return
self.async_db_ready.set_result(True)
self.async_start_executor()
@callback
def _async_recorder_ready(self) -> None:
"""Finish start and mark recorder ready."""
def _async_set_recorder_ready_migration_done(self) -> None:
"""Finish start and mark recorder ready.
Called after all migration steps are finished.
"""
self._async_setup_periodic_tasks()
self.async_recorder_ready.set()
@ -548,6 +573,7 @@ class Recorder(threading.Thread):
self._setup_run()
else:
self.migration_in_progress = True
self.migration_is_live = migration.live_migration(current_version)
self.hass.add_job(self.async_connection_success)
@ -557,6 +583,7 @@ class Recorder(threading.Thread):
# Make sure we cleanly close the run if
# we restart before startup finishes
self._shutdown()
self.hass.add_job(self.async_set_recorder_ready)
return
# We wait to start the migration until startup has finished
@ -577,11 +604,14 @@ class Recorder(threading.Thread):
"Database Migration Failed",
"recorder_database_migration",
)
self.hass.add_job(self.async_set_recorder_ready)
self._shutdown()
return
self.hass.add_job(self.async_set_recorder_ready)
_LOGGER.debug("Recorder processing the queue")
self.hass.add_job(self._async_recorder_ready)
self.hass.add_job(self._async_set_recorder_ready_migration_done)
self._run_event_loop()
def _run_event_loop(self) -> None:
@ -659,7 +689,7 @@ class Recorder(threading.Thread):
try:
migration.migrate_schema(
self.hass, self.engine, self.get_session, current_version
self, self.hass, self.engine, self.get_session, current_version
)
except exc.DatabaseError as err:
if self._handle_database_error(err):

View file

@ -3,7 +3,7 @@ from collections.abc import Callable, Iterable
import contextlib
from datetime import timedelta
import logging
from typing import cast
from typing import Any, cast
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
@ -40,6 +40,8 @@ from .statistics import (
)
from .util import session_scope
LIVE_MIGRATION_MIN_SCHEMA_VERSION = 0
_LOGGER = logging.getLogger(__name__)
@ -78,7 +80,13 @@ def schema_is_current(current_version: int) -> bool:
return current_version == SCHEMA_VERSION
def live_migration(current_version: int) -> bool:
"""Check if live migration is possible."""
return current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
def migrate_schema(
instance: Any,
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
@ -86,7 +94,12 @@ def migrate_schema(
) -> None:
"""Check if the schema needs to be upgraded."""
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
db_ready = False
for version in range(current_version, SCHEMA_VERSION):
if live_migration(version) and not db_ready:
db_ready = True
instance.migration_is_live = True
hass.add_job(instance.async_set_recorder_ready)
new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s", new_version)
_apply_update(hass, engine, session_maker, new_version, current_version)

View file

@ -1,6 +1,8 @@
"""Models for Recorder."""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
import logging
from typing import Any, TypedDict, overload
@ -30,6 +32,14 @@ class UnsupportedDialect(Exception):
"""The dialect or its version is not supported."""
@dataclass
class RecorderData:
"""Recorder data stored in hass.data."""
recorder_platforms: dict[str, Any] = field(default_factory=dict)
db_connected: asyncio.Future = field(default_factory=asyncio.Future)
class StatisticResult(TypedDict):
"""Statistic result data class.

View file

@ -576,7 +576,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
platform_stats: list[StatisticResult] = []
current_metadata: dict[str, tuple[int, StatisticMetaData]] = {}
# Collect statistics from all platforms implementing support
for domain, platform in instance.hass.data[DOMAIN].items():
for domain, platform in instance.hass.data[DOMAIN].recorder_platforms.items():
if not hasattr(platform, "compile_statistics"):
continue
compiled: PlatformCompiledStatistics = platform.compile_statistics(
@ -851,7 +851,7 @@ def list_statistic_ids(
}
# Query all integrations with a registered recorder platform
for platform in hass.data[DOMAIN].values():
for platform in hass.data[DOMAIN].recorder_platforms.values():
if not hasattr(platform, "list_statistic_ids"):
continue
platform_statistic_ids = platform.list_statistic_ids(
@ -1339,7 +1339,7 @@ def _sorted_statistics_to_dict(
def validate_statistics(hass: HomeAssistant) -> dict[str, list[ValidationIssue]]:
"""Validate statistics."""
platform_validation: dict[str, list[ValidationIssue]] = {}
for platform in hass.data[DOMAIN].values():
for platform in hass.data[DOMAIN].recorder_platforms.values():
if not hasattr(platform, "validate_statistics"):
continue
platform_validation.update(platform.validate_statistics(hass))

View file

@ -249,7 +249,7 @@ class AddRecorderPlatformTask(RecorderTask):
domain = self.domain
platform = self.platform
platforms: dict[str, Any] = hass.data[DOMAIN]
platforms: dict[str, Any] = hass.data[DOMAIN].recorder_platforms
platforms[domain] = platform
if hasattr(self.platform, "exclude_attributes"):
hass.data[EXCLUDE_ATTRIBUTES][domain] = platform.exclude_attributes(hass)

View file

@ -552,7 +552,7 @@ def write_lock_db_sqlite(instance: Recorder) -> Generator[None, None, None]:
def async_migration_in_progress(hass: HomeAssistant) -> bool:
"""Determine is a migration is in progress.
"""Determine if a migration is in progress.
This is a thin wrapper that allows us to change
out the implementation later.
@ -563,6 +563,18 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool:
return instance.migration_in_progress
def async_migration_is_live(hass: HomeAssistant) -> bool:
"""Determine if a migration is live.
This is a thin wrapper that allows us to change
out the implementation later.
"""
if DATA_INSTANCE not in hass.data:
return False
instance: Recorder = hass.data[DATA_INSTANCE]
return instance.migration_is_live
def second_sunday(year: int, month: int) -> date:
"""Return the datetime.date for the second sunday of a month."""
second = date(year, month, FIRST_POSSIBLE_SUNDAY)

View file

@ -17,7 +17,7 @@ from .statistics import (
list_statistic_ids,
validate_statistics,
)
from .util import async_migration_in_progress, get_instance
from .util import async_migration_in_progress, async_migration_is_live, get_instance
_LOGGER: logging.Logger = logging.getLogger(__package__)
@ -193,6 +193,7 @@ def ws_info(
backlog = instance.backlog if instance else None
migration_in_progress = async_migration_in_progress(hass)
migration_is_live = async_migration_is_live(hass)
recording = instance.recording if instance else False
thread_alive = instance.is_alive() if instance else False
@ -200,6 +201,7 @@ def ws_info(
"backlog": backlog,
"max_backlog": MAX_QUEUE_BACKLOG,
"migration_in_progress": migration_in_progress,
"migration_is_live": migration_is_live,
"recording": recording,
"thread_running": thread_alive,
}

View file

@ -1,7 +1,8 @@
"""Helpers to check recorder."""
import asyncio
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
def async_migration_in_progress(hass: HomeAssistant) -> bool:
@ -12,3 +13,26 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool:
from homeassistant.components import recorder
return recorder.util.async_migration_in_progress(hass)
@callback
def async_initialize_recorder(hass: HomeAssistant) -> None:
"""Initialize recorder data."""
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.recorder import const, models
hass.data[const.DOMAIN] = models.RecorderData()
async def async_wait_recorder(hass: HomeAssistant) -> bool:
"""Wait for recorder to initialize and return connection status.
Returns False immediately if the recorder is not enabled.
"""
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.recorder import const
if const.DOMAIN not in hass.data:
return False
db_connected: asyncio.Future[bool] = hass.data[const.DOMAIN].db_connected
return await db_connected

View file

@ -50,6 +50,7 @@ from homeassistant.helpers import (
entity_platform,
entity_registry,
intent,
recorder as recorder_helper,
restore_state,
storage,
)
@ -914,6 +915,8 @@ def init_recorder_component(hass, add_config=None):
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.migration.migrate_schema"
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: config})
assert recorder.DOMAIN in hass.config.components
_LOGGER.info(

View file

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import async_setup_component
from tests.components.blueprint.conftest import stub_blueprint_populate # noqa: F401
@ -24,4 +25,5 @@ def recorder_url_mock():
async def test_setup(hass, mock_zeroconf, mock_get_source_ip):
"""Test setup."""
recorder_helper.async_initialize_recorder(hass)
assert await async_setup_component(hass, "default_config", {"foo": "bar"})

View file

@ -0,0 +1,673 @@
"""Models for SQLAlchemy."""
from __future__ import annotations
from datetime import datetime, timedelta
import json
import logging
from typing import Any, TypedDict, cast, overload
from fnvhash import fnv1a_32
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
Float,
ForeignKey,
Identity,
Index,
Integer,
String,
Text,
distinct,
)
from sqlalchemy.dialects import mysql, oracle, postgresql
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm.session import Session
from homeassistant.components.recorder.const import ALL_DOMAIN_EXCLUDE_ATTRS, JSON_DUMP
from homeassistant.const import (
MAX_LENGTH_EVENT_CONTEXT_ID,
MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_EVENT_ORIGIN,
MAX_LENGTH_STATE_ENTITY_ID,
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
import homeassistant.util.dt as dt_util
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 25
_LOGGER = logging.getLogger(__name__)
DB_TIMEZONE = "+00:00"
TABLE_EVENTS = "events"
TABLE_STATES = "states"
TABLE_STATE_ATTRIBUTES = "state_attributes"
TABLE_RECORDER_RUNS = "recorder_runs"
TABLE_SCHEMA_CHANGES = "schema_changes"
TABLE_STATISTICS = "statistics"
TABLE_STATISTICS_META = "statistics_meta"
TABLE_STATISTICS_RUNS = "statistics_runs"
TABLE_STATISTICS_SHORT_TERM = "statistics_short_term"
ALL_TABLES = [
TABLE_STATES,
TABLE_EVENTS,
TABLE_RECORDER_RUNS,
TABLE_SCHEMA_CHANGES,
TABLE_STATISTICS,
TABLE_STATISTICS_META,
TABLE_STATISTICS_RUNS,
TABLE_STATISTICS_SHORT_TERM,
]
EMPTY_JSON_OBJECT = "{}"
DATETIME_TYPE = DateTime(timezone=True).with_variant(
mysql.DATETIME(timezone=True, fsp=6), "mysql"
)
DOUBLE_TYPE = (
Float()
.with_variant(mysql.DOUBLE(asdecimal=False), "mysql")
.with_variant(oracle.DOUBLE_PRECISION(), "oracle")
.with_variant(postgresql.DOUBLE_PRECISION(), "postgresql")
)
class Events(Base): # type: ignore[misc,valid-type]
"""Event history data."""
__table_args__ = (
# Used for fetching events at a specific time
# see logbook
Index("ix_events_event_type_time_fired", "event_type", "time_fired"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_EVENTS
event_id = Column(Integer, Identity(), primary_key=True)
event_type = Column(String(MAX_LENGTH_EVENT_EVENT_TYPE))
event_data = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
origin = Column(String(MAX_LENGTH_EVENT_ORIGIN))
time_fired = Column(DATETIME_TYPE, index=True)
context_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_user_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
context_parent_id = Column(String(MAX_LENGTH_EVENT_CONTEXT_ID), index=True)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.Events("
f"id={self.event_id}, type='{self.event_type}', data='{self.event_data}', "
f"origin='{self.origin}', time_fired='{self.time_fired}'"
f")>"
)
@staticmethod
def from_event(
event: Event, event_data: UndefinedType | None = UNDEFINED
) -> Events:
"""Create an event database object from a native event."""
return Events(
event_type=event.event_type,
event_data=JSON_DUMP(event.data) if event_data is UNDEFINED else event_data,
origin=str(event.origin.value),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
context_parent_id=event.context.parent_id,
)
def to_native(self, validate_entity_id: bool = True) -> Event | None:
"""Convert to a native HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id,
parent_id=self.context_parent_id,
)
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting to event: %s", self)
return None
class States(Base): # type: ignore[misc,valid-type]
"""State change history."""
__table_args__ = (
# Used for fetching the state of entities at a specific time
# (get_states in history.py)
Index("ix_states_entity_id_last_updated", "entity_id", "last_updated"),
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATES
state_id = Column(Integer, Identity(), primary_key=True)
entity_id = Column(String(MAX_LENGTH_STATE_ENTITY_ID))
state = Column(String(MAX_LENGTH_STATE_STATE))
attributes = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
event_id = Column(
Integer, ForeignKey("events.event_id", ondelete="CASCADE"), index=True
)
last_changed = Column(DATETIME_TYPE, default=dt_util.utcnow)
last_updated = Column(DATETIME_TYPE, default=dt_util.utcnow, index=True)
old_state_id = Column(Integer, ForeignKey("states.state_id"), index=True)
attributes_id = Column(
Integer, ForeignKey("state_attributes.attributes_id"), index=True
)
event = relationship("Events", uselist=False)
old_state = relationship("States", remote_side=[state_id])
state_attributes = relationship("StateAttributes")
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.States("
f"id={self.state_id}, entity_id='{self.entity_id}', "
f"state='{self.state}', event_id='{self.event_id}', "
f"last_updated='{self.last_updated.isoformat(sep=' ', timespec='seconds')}', "
f"old_state_id={self.old_state_id}, attributes_id={self.attributes_id}"
f")>"
)
@staticmethod
def from_event(event: Event) -> States:
"""Create object from a state_changed event."""
entity_id = event.data["entity_id"]
state: State | None = event.data.get("new_state")
dbstate = States(entity_id=entity_id, attributes=None)
# None state means the state was removed from the state machine
if state is None:
dbstate.state = ""
dbstate.last_changed = event.time_fired
dbstate.last_updated = event.time_fired
else:
dbstate.state = state.state
dbstate.last_changed = state.last_changed
dbstate.last_updated = state.last_updated
return dbstate
def to_native(self, validate_entity_id: bool = True) -> State | None:
"""Convert to an HA state object."""
try:
return State(
self.entity_id,
self.state,
# Join the state_attributes table on attributes_id to get the attributes
# for newer states
json.loads(self.attributes) if self.attributes else {},
process_timestamp(self.last_changed),
process_timestamp(self.last_updated),
# Join the events table on event_id to get the context instead
# as it will always be there for state_changed events
context=Context(id=None), # type: ignore[arg-type]
validate_entity_id=validate_entity_id,
)
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state: %s", self)
return None
class StateAttributes(Base): # type: ignore[misc,valid-type]
"""State attribute change history."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATE_ATTRIBUTES
attributes_id = Column(Integer, Identity(), primary_key=True)
hash = Column(BigInteger, index=True)
# Note that this is not named attributes to avoid confusion with the states table
shared_attrs = Column(Text().with_variant(mysql.LONGTEXT, "mysql"))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StateAttributes("
f"id={self.attributes_id}, hash='{self.hash}', attributes='{self.shared_attrs}'"
f")>"
)
@staticmethod
def from_event(event: Event) -> StateAttributes:
"""Create object from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
dbstate = StateAttributes(
shared_attrs="{}" if state is None else JSON_DUMP(state.attributes)
)
dbstate.hash = StateAttributes.hash_shared_attrs(dbstate.shared_attrs)
return dbstate
@staticmethod
def shared_attrs_from_event(
event: Event, exclude_attrs_by_domain: dict[str, set[str]]
) -> str:
"""Create shared_attrs from a state_changed event."""
state: State | None = event.data.get("new_state")
# None state means the state was removed from the state machine
if state is None:
return "{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
return JSON_DUMP(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}
)
@staticmethod
def hash_shared_attrs(shared_attrs: str) -> int:
"""Return the hash of json encoded shared attributes."""
return cast(int, fnv1a_32(shared_attrs.encode("utf-8")))
def to_native(self) -> dict[str, Any]:
"""Convert to an HA state object."""
try:
return cast(dict[str, Any], json.loads(self.shared_attrs))
except ValueError:
# When json.loads fails
_LOGGER.exception("Error converting row to state attributes: %s", self)
return {}
class StatisticResult(TypedDict):
"""Statistic result data class.
Allows multiple datapoints for the same statistic_id.
"""
meta: StatisticMetaData
stat: StatisticData
class StatisticDataBase(TypedDict):
"""Mandatory fields for statistic data class."""
start: datetime
class StatisticData(StatisticDataBase, total=False):
"""Statistic data class."""
mean: float
min: float
max: float
last_reset: datetime | None
state: float
sum: float
class StatisticsBase:
"""Statistics base class."""
id = Column(Integer, Identity(), primary_key=True)
created = Column(DATETIME_TYPE, default=dt_util.utcnow)
@declared_attr # type: ignore[misc]
def metadata_id(self) -> Column:
"""Define the metadata_id column for sub classes."""
return Column(
Integer,
ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
index=True,
)
start = Column(DATETIME_TYPE, index=True)
mean = Column(DOUBLE_TYPE)
min = Column(DOUBLE_TYPE)
max = Column(DOUBLE_TYPE)
last_reset = Column(DATETIME_TYPE)
state = Column(DOUBLE_TYPE)
sum = Column(DOUBLE_TYPE)
@classmethod
def from_stats(cls, metadata_id: int, stats: StatisticData) -> StatisticsBase:
"""Create object from a statistics."""
return cls( # type: ignore[call-arg,misc]
metadata_id=metadata_id,
**stats,
)
class Statistics(Base, StatisticsBase): # type: ignore[misc,valid-type]
"""Long term statistics."""
duration = timedelta(hours=1)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index("ix_statistics_statistic_id_start", "metadata_id", "start", unique=True),
)
__tablename__ = TABLE_STATISTICS
class StatisticsShortTerm(Base, StatisticsBase): # type: ignore[misc,valid-type]
"""Short term statistics."""
duration = timedelta(minutes=5)
__table_args__ = (
# Used for fetching statistics for a certain entity at a specific time
Index(
"ix_statistics_short_term_statistic_id_start",
"metadata_id",
"start",
unique=True,
),
)
__tablename__ = TABLE_STATISTICS_SHORT_TERM
class StatisticMetaData(TypedDict):
"""Statistic meta data class."""
has_mean: bool
has_sum: bool
name: str | None
source: str
statistic_id: str
unit_of_measurement: str | None
class StatisticsMeta(Base): # type: ignore[misc,valid-type]
"""Statistics meta data."""
__table_args__ = (
{"mysql_default_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
)
__tablename__ = TABLE_STATISTICS_META
id = Column(Integer, Identity(), primary_key=True)
statistic_id = Column(String(255), index=True)
source = Column(String(32))
unit_of_measurement = Column(String(255))
has_mean = Column(Boolean)
has_sum = Column(Boolean)
name = Column(String(255))
@staticmethod
def from_meta(meta: StatisticMetaData) -> StatisticsMeta:
"""Create object from meta data."""
return StatisticsMeta(**meta)
class RecorderRuns(Base): # type: ignore[misc,valid-type]
"""Representation of recorder run."""
__table_args__ = (Index("ix_recorder_runs_start_end", "start", "end"),)
__tablename__ = TABLE_RECORDER_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True), default=dt_util.utcnow)
end = Column(DateTime(timezone=True))
closed_incorrect = Column(Boolean, default=False)
created = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
end = (
f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None
)
return (
f"<recorder.RecorderRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f"end={end}, closed_incorrect={self.closed_incorrect}, "
f"created='{self.created.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
def entity_ids(self, point_in_time: datetime | None = None) -> list[str]:
"""Return the entity ids that existed in this run.
Specify point_in_time if you want to know which existed at that point
in time inside the run.
"""
session = Session.object_session(self)
assert session is not None, "RecorderRuns need to be persisted"
query = session.query(distinct(States.entity_id)).filter(
States.last_updated >= self.start
)
if point_in_time is not None:
query = query.filter(States.last_updated < point_in_time)
elif self.end is not None:
query = query.filter(States.last_updated < self.end)
return [row[0] for row in query]
def to_native(self, validate_entity_id: bool = True) -> RecorderRuns:
"""Return self, native format is this model."""
return self
class SchemaChanges(Base): # type: ignore[misc,valid-type]
"""Representation of schema version changes."""
__tablename__ = TABLE_SCHEMA_CHANGES
change_id = Column(Integer, Identity(), primary_key=True)
schema_version = Column(Integer)
changed = Column(DateTime(timezone=True), default=dt_util.utcnow)
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.SchemaChanges("
f"id={self.change_id}, schema_version={self.schema_version}, "
f"changed='{self.changed.isoformat(sep=' ', timespec='seconds')}'"
f")>"
)
class StatisticsRuns(Base): # type: ignore[misc,valid-type]
"""Representation of statistics run."""
__tablename__ = TABLE_STATISTICS_RUNS
run_id = Column(Integer, Identity(), primary_key=True)
start = Column(DateTime(timezone=True))
def __repr__(self) -> str:
"""Return string representation of instance for debugging."""
return (
f"<recorder.StatisticsRuns("
f"id={self.run_id}, start='{self.start.isoformat(sep=' ', timespec='seconds')}', "
f")>"
)
@overload
def process_timestamp(ts: None) -> None:
...
@overload
def process_timestamp(ts: datetime) -> datetime:
...
def process_timestamp(ts: datetime | None) -> datetime | None:
"""Process a timestamp into datetime object."""
if ts is None:
return None
if ts.tzinfo is None:
return ts.replace(tzinfo=dt_util.UTC)
return dt_util.as_utc(ts)
@overload
def process_timestamp_to_utc_isoformat(ts: None) -> None:
...
@overload
def process_timestamp_to_utc_isoformat(ts: datetime) -> str:
...
def process_timestamp_to_utc_isoformat(ts: datetime | None) -> str | None:
"""Process a timestamp into UTC isotime."""
if ts is None:
return None
if ts.tzinfo == dt_util.UTC:
return ts.isoformat()
if ts.tzinfo is None:
return f"{ts.isoformat()}{DB_TIMEZONE}"
return ts.astimezone(dt_util.UTC).isoformat()
class LazyState(State):
"""A lazy version of core State."""
__slots__ = [
"_row",
"_attributes",
"_last_changed",
"_last_updated",
"_context",
"_attr_cache",
]
def __init__( # pylint: disable=super-init-not-called
self, row: Row, attr_cache: dict[str, dict[str, Any]] | None = None
) -> None:
"""Init the lazy state."""
self._row = row
self.entity_id: str = self._row.entity_id
self.state = self._row.state or ""
self._attributes: dict[str, Any] | None = None
self._last_changed: datetime | None = None
self._last_updated: datetime | None = None
self._context: Context | None = None
self._attr_cache = attr_cache
@property # type: ignore[override]
def attributes(self) -> dict[str, Any]: # type: ignore[override]
"""State attributes."""
if self._attributes is None:
source = self._row.shared_attrs or self._row.attributes
if self._attr_cache is not None and (
attributes := self._attr_cache.get(source)
):
self._attributes = attributes
return attributes
if source == EMPTY_JSON_OBJECT or source is None:
self._attributes = {}
return self._attributes
try:
self._attributes = json.loads(source)
except ValueError:
# When json.loads fails
_LOGGER.exception(
"Error converting row to state attributes: %s", self._row
)
self._attributes = {}
if self._attr_cache is not None:
self._attr_cache[source] = self._attributes
return self._attributes
@attributes.setter
def attributes(self, value: dict[str, Any]) -> None:
"""Set attributes."""
self._attributes = value
@property # type: ignore[override]
def context(self) -> Context: # type: ignore[override]
"""State context."""
if self._context is None:
self._context = Context(id=None) # type: ignore[arg-type]
return self._context
@context.setter
def context(self, value: Context) -> None:
"""Set context."""
self._context = value
@property # type: ignore[override]
def last_changed(self) -> datetime: # type: ignore[override]
"""Last changed datetime."""
if self._last_changed is None:
self._last_changed = process_timestamp(self._row.last_changed)
return self._last_changed
@last_changed.setter
def last_changed(self, value: datetime) -> None:
"""Set last changed datetime."""
self._last_changed = value
@property # type: ignore[override]
def last_updated(self) -> datetime: # type: ignore[override]
"""Last updated datetime."""
if self._last_updated is None:
if (last_updated := self._row.last_updated) is not None:
self._last_updated = process_timestamp(last_updated)
else:
self._last_updated = self.last_changed
return self._last_updated
@last_updated.setter
def last_updated(self, value: datetime) -> None:
"""Set last updated datetime."""
self._last_updated = value
def as_dict(self) -> dict[str, Any]: # type: ignore[override]
"""Return a dict representation of the LazyState.
Async friendly.
To be used for JSON serialization.
"""
if self._last_changed is None and self._last_updated is None:
last_changed_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_changed
)
if (
self._row.last_updated is None
or self._row.last_changed == self._row.last_updated
):
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = process_timestamp_to_utc_isoformat(
self._row.last_updated
)
else:
last_changed_isoformat = self.last_changed.isoformat()
if self.last_changed == self.last_updated:
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = self.last_updated.isoformat()
return {
"entity_id": self.entity_id,
"state": self.state,
"attributes": self._attributes or self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
}
def __eq__(self, other: Any) -> bool:
"""Return the comparison."""
return (
other.__class__ in [self.__class__, State]
and self.entity_id == other.entity_id
and self.state == other.state
and self.attributes == other.attributes
)

View file

@ -51,6 +51,7 @@ from homeassistant.const import (
STATE_UNLOCKED,
)
from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import async_setup_component, setup_component
from homeassistant.util import dt as dt_util
@ -100,9 +101,10 @@ async def test_shutdown_before_startup_finishes(
}
hass.state = CoreState.not_running
instance = await async_setup_recorder_instance(hass, config)
await instance.async_db_ready
await hass.async_block_till_done()
recorder_helper.async_initialize_recorder(hass)
hass.create_task(async_setup_recorder_instance(hass, config))
await recorder_helper.async_wait_recorder(hass)
instance = get_instance(hass)
session = await hass.async_add_executor_job(instance.get_session)
@ -125,9 +127,11 @@ async def test_canceled_before_startup_finishes(
):
"""Test recorder shuts down when its startup future is canceled out from under it."""
hass.state = CoreState.not_running
await async_setup_recorder_instance(hass)
recorder_helper.async_initialize_recorder(hass)
hass.create_task(async_setup_recorder_instance(hass))
await recorder_helper.async_wait_recorder(hass)
instance = get_instance(hass)
await instance.async_db_ready
instance._hass_started.cancel()
with patch.object(instance, "engine"):
await hass.async_block_till_done()
@ -170,7 +174,9 @@ async def test_state_gets_saved_when_set_before_start_event(
hass.state = CoreState.not_running
await async_setup_recorder_instance(hass)
recorder_helper.async_initialize_recorder(hass)
hass.create_task(async_setup_recorder_instance(hass))
await recorder_helper.async_wait_recorder(hass)
entity_id = "test.recorder"
state = "restoring_from_db"
@ -643,6 +649,7 @@ def test_saving_state_and_removing_entity(hass, hass_recorder):
def test_recorder_setup_failure(hass):
"""Test some exceptions."""
recorder_helper.async_initialize_recorder(hass)
with patch.object(Recorder, "_setup_connection") as setup, patch(
"homeassistant.components.recorder.core.time.sleep"
):
@ -657,6 +664,7 @@ def test_recorder_setup_failure(hass):
def test_recorder_setup_failure_without_event_listener(hass):
"""Test recorder setup failure when the event listener is not setup."""
recorder_helper.async_initialize_recorder(hass)
with patch.object(Recorder, "_setup_connection") as setup, patch(
"homeassistant.components.recorder.core.time.sleep"
):
@ -985,6 +993,7 @@ def test_compile_missing_statistics(tmpdir):
):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
hass.start()
wait_recording_done(hass)
@ -1006,6 +1015,7 @@ def test_compile_missing_statistics(tmpdir):
):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
hass.start()
wait_recording_done(hass)
@ -1197,6 +1207,7 @@ def test_service_disable_run_information_recorded(tmpdir):
dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
hass.start()
wait_recording_done(hass)
@ -1218,6 +1229,7 @@ def test_service_disable_run_information_recorded(tmpdir):
hass.stop()
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl}})
hass.start()
wait_recording_done(hass)
@ -1246,6 +1258,7 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
test_db_file = await hass.async_add_executor_job(_create_tmpdir_for_test_db)
dburl = f"{SQLITE_URL_PREFIX}//{test_db_file}"
recorder_helper.async_initialize_recorder(hass)
assert await async_setup_component(
hass, DOMAIN, {DOMAIN: {CONF_DB_URL: dburl, CONF_COMMIT_INTERVAL: 0}}
)

View file

@ -27,6 +27,7 @@ from homeassistant.components.recorder.db_schema import (
States,
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.helpers import recorder as recorder_helper
import homeassistant.util.dt as dt_util
from .common import async_wait_recording_done, create_engine_test
@ -53,6 +54,7 @@ async def test_schema_update_calls(hass):
"homeassistant.components.recorder.migration._apply_update",
wraps=migration._apply_update,
) as update:
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
@ -74,10 +76,11 @@ async def test_migration_in_progress(hass):
"""Test that we can check for migration in progress."""
assert recorder.util.async_migration_in_progress(hass) is False
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True,), patch(
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.core.create_engine",
new=create_engine_test,
):
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
@ -105,6 +108,7 @@ async def test_database_migration_failed(hass):
"homeassistant.components.persistent_notification.dismiss",
side_effect=pn.dismiss,
) as mock_dismiss:
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
@ -136,6 +140,7 @@ async def test_database_migration_encounters_corruption(hass):
), patch(
"homeassistant.components.recorder.core.move_away_broken_database"
) as move_away:
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
@ -165,6 +170,7 @@ async def test_database_migration_encounters_corruption_not_sqlite(hass):
"homeassistant.components.persistent_notification.dismiss",
side_effect=pn.dismiss,
) as mock_dismiss:
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
@ -189,6 +195,7 @@ async def test_events_during_migration_are_queued(hass):
"homeassistant.components.recorder.core.create_engine",
new=create_engine_test,
):
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass,
"recorder",
@ -219,6 +226,7 @@ async def test_events_during_migration_queue_exhausted(hass):
"homeassistant.components.recorder.core.create_engine",
new=create_engine_test,
), patch.object(recorder.core, "MAX_QUEUE_BACKLOG", 1):
recorder_helper.async_initialize_recorder(hass)
await async_setup_component(
hass,
"recorder",
@ -247,8 +255,11 @@ async def test_events_during_migration_queue_exhausted(hass):
assert len(db_states) == 2
@pytest.mark.parametrize("start_version", [0, 16, 18, 22])
async def test_schema_migrate(hass, start_version):
@pytest.mark.parametrize(
"start_version,live",
[(0, True), (16, True), (18, True), (22, True), (25, True)],
)
async def test_schema_migrate(hass, start_version, live):
"""Test the full schema migration logic.
We're just testing that the logic can execute successfully here without
@ -259,7 +270,8 @@ async def test_schema_migrate(hass, start_version):
migration_done = threading.Event()
migration_stall = threading.Event()
migration_version = None
real_migration = recorder.migration.migrate_schema
real_migrate_schema = recorder.migration.migrate_schema
real_apply_update = recorder.migration._apply_update
def _create_engine_test(*args, **kwargs):
"""Test version of create_engine that initializes with old schema.
@ -284,14 +296,12 @@ async def test_schema_migrate(hass, start_version):
start=self.run_history.recording_start, created=dt_util.utcnow()
)
def _instrument_migration(*args):
def _instrument_migrate_schema(*args):
"""Control migration progress and check results."""
nonlocal migration_done
nonlocal migration_version
nonlocal migration_stall
migration_stall.wait()
try:
real_migration(*args)
real_migrate_schema(*args)
except Exception:
migration_done.set()
raise
@ -307,6 +317,12 @@ async def test_schema_migrate(hass, start_version):
migration_version = res.schema_version
migration_done.set()
def _instrument_apply_update(*args):
"""Control migration progress."""
nonlocal migration_stall
migration_stall.wait()
real_apply_update(*args)
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.core.create_engine",
new=_create_engine_test,
@ -316,12 +332,21 @@ async def test_schema_migrate(hass, start_version):
autospec=True,
) as setup_run, patch(
"homeassistant.components.recorder.migration.migrate_schema",
wraps=_instrument_migration,
wraps=_instrument_migrate_schema,
), patch(
"homeassistant.components.recorder.migration._apply_update",
wraps=_instrument_apply_update,
):
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
recorder_helper.async_initialize_recorder(hass)
hass.async_create_task(
async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
)
await recorder_helper.async_wait_recorder(hass)
assert recorder.util.async_migration_in_progress(hass) is True
assert recorder.util.async_migration_is_live(hass) == live
migration_stall.set()
await hass.async_block_till_done()
await hass.async_add_executor_job(migration_done.wait)

View file

@ -31,6 +31,7 @@ from homeassistant.components.recorder.util import session_scope
from homeassistant.const import TEMP_CELSIUS
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util
@ -1128,6 +1129,7 @@ def test_delete_metadata_duplicates(caplog, tmpdir):
"homeassistant.components.recorder.core.create_engine", new=_create_engine_28
):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
@ -1158,6 +1160,7 @@ def test_delete_metadata_duplicates(caplog, tmpdir):
# Test that the duplicates are removed during migration from schema 28
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
@ -1217,6 +1220,7 @@ def test_delete_metadata_duplicates_many(caplog, tmpdir):
"homeassistant.components.recorder.core.create_engine", new=_create_engine_28
):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
@ -1249,6 +1253,7 @@ def test_delete_metadata_duplicates_many(caplog, tmpdir):
# Test that the duplicates are removed during migration from schema 28
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)

View file

@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from homeassistant.components import recorder
from homeassistant.components.recorder import SQLITE_URL_PREFIX, statistics
from homeassistant.components.recorder.util import session_scope
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util
@ -179,6 +180,7 @@ def test_delete_duplicates(caplog, tmpdir):
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
), patch(CREATE_ENGINE_TARGET, new=_create_engine_test):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
@ -206,6 +208,7 @@ def test_delete_duplicates(caplog, tmpdir):
# Test that the duplicates are removed during migration from schema 23
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
@ -347,6 +350,7 @@ def test_delete_duplicates_many(caplog, tmpdir):
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
), patch(CREATE_ENGINE_TARGET, new=_create_engine_test):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
@ -380,6 +384,7 @@ def test_delete_duplicates_many(caplog, tmpdir):
# Test that the duplicates are removed during migration from schema 23
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
@ -492,6 +497,7 @@ def test_delete_duplicates_non_identical(caplog, tmpdir):
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
), patch(CREATE_ENGINE_TARGET, new=_create_engine_test):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
@ -515,6 +521,7 @@ def test_delete_duplicates_non_identical(caplog, tmpdir):
# Test that the duplicates are removed during migration from schema 23
hass = get_test_home_assistant()
hass.config.config_dir = tmpdir
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)
@ -592,6 +599,7 @@ def test_delete_duplicates_short_term(caplog, tmpdir):
recorder.migration, "SCHEMA_VERSION", old_db_schema.SCHEMA_VERSION
), patch(CREATE_ENGINE_TARGET, new=_create_engine_test):
hass = get_test_home_assistant()
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
wait_recording_done(hass)
wait_recording_done(hass)
@ -614,6 +622,7 @@ def test_delete_duplicates_short_term(caplog, tmpdir):
# Test that the duplicates are removed during migration from schema 23
hass = get_test_home_assistant()
hass.config.config_dir = tmpdir
recorder_helper.async_initialize_recorder(hass)
setup_component(hass, "recorder", {"recorder": {"db_url": dburl}})
hass.start()
wait_recording_done(hass)

View file

@ -15,6 +15,7 @@ from homeassistant.components.recorder.statistics import (
list_statistic_ids,
statistics_during_period,
)
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM
@ -274,6 +275,7 @@ async def test_recorder_info(hass, hass_ws_client, recorder_mock):
"backlog": 0,
"max_backlog": 40000,
"migration_in_progress": False,
"migration_is_live": False,
"recording": True,
"thread_running": True,
}
@ -296,6 +298,7 @@ async def test_recorder_info_bad_recorder_config(hass, hass_ws_client):
client = await hass_ws_client()
with patch("homeassistant.components.recorder.migration.migrate_schema"):
recorder_helper.async_initialize_recorder(hass)
assert not await async_setup_component(
hass, recorder.DOMAIN, {recorder.DOMAIN: config}
)
@ -318,7 +321,7 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client):
migration_done = threading.Event()
real_migration = recorder.migration.migrate_schema
real_migration = recorder.migration._apply_update
def stalled_migration(*args):
"""Make migration stall."""
@ -334,12 +337,16 @@ async def test_recorder_info_migration_queue_exhausted(hass, hass_ws_client):
), patch.object(
recorder.core, "MAX_QUEUE_BACKLOG", 1
), patch(
"homeassistant.components.recorder.migration.migrate_schema",
"homeassistant.components.recorder.migration._apply_update",
wraps=stalled_migration,
):
await async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
recorder_helper.async_initialize_recorder(hass)
hass.create_task(
async_setup_component(
hass, "recorder", {"recorder": {"db_url": "sqlite://"}}
)
)
await recorder_helper.async_wait_recorder(hass)
hass.states.async_set("my.entity", "on", {})
await hass.async_block_till_done()

View file

@ -31,7 +31,7 @@ from homeassistant.components.websocket_api.auth import (
from homeassistant.components.websocket_api.http import URL
from homeassistant.const import HASSIO_USER_NAME
from homeassistant.core import CoreState, HomeAssistant
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers import config_entry_oauth2_flow, recorder as recorder_helper
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util, location
@ -790,6 +790,8 @@ async def _async_init_recorder_component(hass, add_config=None):
with patch("homeassistant.components.recorder.ALLOW_IN_MEMORY_DB", True), patch(
"homeassistant.components.recorder.migration.migrate_schema"
):
if recorder.DOMAIN not in hass.data:
recorder_helper.async_initialize_recorder(hass)
assert await async_setup_component(
hass, recorder.DOMAIN, {recorder.DOMAIN: config}
)

View file

@ -211,6 +211,82 @@ async def test_setup_after_deps_in_stage_1_ignored(hass):
assert order == ["cloud", "an_after_dep", "normal_integration"]
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_frontend_before_recorder(hass):
"""Test frontend is setup before recorder."""
order = []
def gen_domain_setup(domain):
async def async_setup(hass, config):
order.append(domain)
return True
return async_setup
mock_integration(
hass,
MockModule(
domain="normal_integration",
async_setup=gen_domain_setup("normal_integration"),
partial_manifest={"after_dependencies": ["an_after_dep"]},
),
)
mock_integration(
hass,
MockModule(
domain="an_after_dep",
async_setup=gen_domain_setup("an_after_dep"),
),
)
mock_integration(
hass,
MockModule(
domain="frontend",
async_setup=gen_domain_setup("frontend"),
partial_manifest={
"dependencies": ["http"],
"after_dependencies": ["an_after_dep"],
},
),
)
mock_integration(
hass,
MockModule(
domain="http",
async_setup=gen_domain_setup("http"),
),
)
mock_integration(
hass,
MockModule(
domain="recorder",
async_setup=gen_domain_setup("recorder"),
),
)
await bootstrap._async_set_up_integrations(
hass,
{
"frontend": {},
"http": {},
"recorder": {},
"normal_integration": {},
"an_after_dep": {},
},
)
assert "frontend" in hass.config.components
assert "normal_integration" in hass.config.components
assert "recorder" in hass.config.components
assert order == [
"http",
"frontend",
"recorder",
"an_after_dep",
"normal_integration",
]
@pytest.mark.parametrize("load_registries", [False])
async def test_setup_after_deps_via_platform(hass):
"""Test after_dependencies set up via platform."""