Add a context variable holding a HomeAssistant reference (#76303)

* Add a context variable holding a HomeAssistant reference

* Move variable setup and update test

* Refactor

* Revert "Refactor"

This reverts commit 346d005ee6.

* Set context variable when creating HomeAssistant object

* Update docstring

* Update docstring

Co-authored-by: jbouwh <jan@jbsoft.nl>
This commit is contained in:
Erik Montnemery 2022-08-22 15:58:01 +02:00 committed by GitHub
parent 58b9785485
commit 61ff1b786b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 0 deletions

View file

@ -15,6 +15,7 @@ from collections.abc import (
Iterable,
Mapping,
)
from contextvars import ContextVar
import datetime
import enum
import functools
@ -138,6 +139,8 @@ MAX_EXPECTED_ENTITY_IDS = 16384
_LOGGER = logging.getLogger(__name__)
_cv_hass: ContextVar[HomeAssistant] = ContextVar("current_entry")
@functools.lru_cache(MAX_EXPECTED_ENTITY_IDS)
def split_entity_id(entity_id: str) -> tuple[str, str]:
@ -175,6 +178,18 @@ def is_callback(func: Callable[..., Any]) -> bool:
return getattr(func, "_hass_callback", False) is True
@callback
def async_get_hass() -> HomeAssistant:
"""Return the HomeAssistant instance.
Raises LookupError if no HomeAssistant instance is available.
This should be used where it's very cumbersome or downright impossible to pass
hass to the code which needs it.
"""
return _cv_hass.get()
@enum.unique
class HassJobType(enum.Enum):
"""Represent a job type."""
@ -242,6 +257,12 @@ class HomeAssistant:
http: HomeAssistantHTTP = None # type: ignore[assignment]
config_entries: ConfigEntries = None # type: ignore[assignment]
def __new__(cls) -> HomeAssistant:
"""Set the _cv_hass context variable."""
hass = super().__new__(cls)
_cv_hass.set(hass)
return hass
def __init__(self) -> None:
"""Initialize new Home Assistant object."""
self.loop = asyncio.get_running_loop()

View file

@ -501,6 +501,8 @@ async def test_setup_hass(
assert len(mock_ensure_config_exists.mock_calls) == 1
assert len(mock_process_ha_config_upgrade.mock_calls) == 1
assert hass == core.async_get_hass()
async def test_setup_hass_takes_longer_than_log_slow_startup(
mock_enable_logging,