mirror of
https://github.com/home-assistant/core
synced 2024-10-05 09:12:11 +00:00
Detach aiohttp.ClientSession created by config entry setup on unload (#48908)
This commit is contained in:
parent
8e2b5b36b5
commit
40450b9cfd
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextvars import ContextVar
|
||||
import functools
|
||||
import logging
|
||||
from types import MappingProxyType, MethodType
|
||||
|
@ -133,6 +134,7 @@ class ConfigEntry:
|
|||
"_setup_lock",
|
||||
"update_listeners",
|
||||
"_async_cancel_retry_setup",
|
||||
"_on_unload",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
@ -198,6 +200,9 @@ class ConfigEntry:
|
|||
# Function to cancel a scheduled retry
|
||||
self._async_cancel_retry_setup: Callable[[], Any] | None = None
|
||||
|
||||
# Hold list for functions to call on unload.
|
||||
self._on_unload: list[CALLBACK_TYPE] | None = None
|
||||
|
||||
async def async_setup(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
|
@ -206,6 +211,7 @@ class ConfigEntry:
|
|||
tries: int = 0,
|
||||
) -> None:
|
||||
"""Set up an entry."""
|
||||
current_entry.set(self)
|
||||
if self.source == SOURCE_IGNORE or self.disabled_by:
|
||||
return
|
||||
|
||||
|
@ -290,6 +296,8 @@ class ConfigEntry:
|
|||
self._async_cancel_retry_setup = hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_STARTED, setup_again
|
||||
)
|
||||
|
||||
self._async_process_on_unload()
|
||||
return
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
|
@ -358,6 +366,8 @@ class ConfigEntry:
|
|||
if result and integration.domain == self.domain:
|
||||
self.state = ENTRY_STATE_NOT_LOADED
|
||||
|
||||
self._async_process_on_unload()
|
||||
|
||||
return result
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
|
@ -470,6 +480,25 @@ class ConfigEntry:
|
|||
"disabled_by": self.disabled_by,
|
||||
}
|
||||
|
||||
@callback
|
||||
def async_on_unload(self, func: CALLBACK_TYPE) -> None:
|
||||
"""Add a function to call when config entry is unloaded."""
|
||||
if self._on_unload is None:
|
||||
self._on_unload = []
|
||||
self._on_unload.append(func)
|
||||
|
||||
@callback
|
||||
def _async_process_on_unload(self) -> None:
|
||||
"""Process the on_unload callbacks."""
|
||||
if self._on_unload is not None:
|
||||
while self._on_unload:
|
||||
self._on_unload.pop()()
|
||||
|
||||
|
||||
current_entry: ContextVar[ConfigEntry | None] = ContextVar(
|
||||
"current_entry", default=None
|
||||
)
|
||||
|
||||
|
||||
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
|
||||
"""Manage all the config entry flows that are in progress."""
|
||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
from contextlib import suppress
|
||||
from ssl import SSLContext
|
||||
import sys
|
||||
from typing import Any, Awaitable, cast
|
||||
from typing import Any, Awaitable, Callable, cast
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
@ -13,6 +13,7 @@ from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
|
|||
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
|
||||
import async_timeout
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.frame import warn_use
|
||||
|
@ -27,6 +28,8 @@ SERVER_SOFTWARE = "HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}".format(
|
|||
__version__, aiohttp.__version__, sys.version_info
|
||||
)
|
||||
|
||||
WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session"
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
|
@ -37,12 +40,14 @@ def async_get_clientsession(
|
|||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
key = DATA_CLIENTSESSION_NOTVERIFY
|
||||
if verify_ssl:
|
||||
key = DATA_CLIENTSESSION
|
||||
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
|
||||
|
||||
if key not in hass.data:
|
||||
hass.data[key] = async_create_clientsession(hass, verify_ssl)
|
||||
hass.data[key] = _async_create_clientsession(
|
||||
hass,
|
||||
verify_ssl,
|
||||
auto_cleanup_method=_async_register_default_clientsession_shutdown,
|
||||
)
|
||||
|
||||
return cast(aiohttp.ClientSession, hass.data[key])
|
||||
|
||||
|
@ -59,24 +64,44 @@ def async_create_clientsession(
|
|||
|
||||
If auto_cleanup is False, you need to call detach() after the session
|
||||
returned is no longer used. Default is True, the session will be
|
||||
automatically detached on homeassistant_stop.
|
||||
automatically detached on homeassistant_stop or when being created
|
||||
in config entry setup, the config entry is unloaded.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
connector = _async_get_connector(hass, verify_ssl)
|
||||
auto_cleanup_method = None
|
||||
if auto_cleanup:
|
||||
auto_cleanup_method = _async_register_clientsession_shutdown
|
||||
|
||||
clientsession = _async_create_clientsession(
|
||||
hass,
|
||||
verify_ssl,
|
||||
auto_cleanup_method=auto_cleanup_method,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return clientsession
|
||||
|
||||
|
||||
@callback
|
||||
def _async_create_clientsession(
|
||||
hass: HomeAssistant,
|
||||
verify_ssl: bool = True,
|
||||
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
|
||||
| None = None,
|
||||
**kwargs: Any,
|
||||
) -> aiohttp.ClientSession:
|
||||
"""Create a new ClientSession with kwargs, i.e. for cookies."""
|
||||
clientsession = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
connector=_async_get_connector(hass, verify_ssl),
|
||||
headers={USER_AGENT: SERVER_SOFTWARE},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
clientsession.close = warn_use( # type: ignore
|
||||
clientsession.close, "closes the Home Assistant aiohttp session"
|
||||
)
|
||||
clientsession.close = warn_use(clientsession.close, WARN_CLOSE_MSG) # type: ignore
|
||||
|
||||
if auto_cleanup:
|
||||
_async_register_clientsession_shutdown(hass, clientsession)
|
||||
if auto_cleanup_method:
|
||||
auto_cleanup_method(hass, clientsession)
|
||||
|
||||
return clientsession
|
||||
|
||||
|
@ -146,7 +171,33 @@ async def async_aiohttp_proxy_stream(
|
|||
def _async_register_clientsession_shutdown(
|
||||
hass: HomeAssistant, clientsession: aiohttp.ClientSession
|
||||
) -> None:
|
||||
"""Register ClientSession close on Home Assistant shutdown.
|
||||
"""Register ClientSession close on Home Assistant shutdown or config entry unload.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def _async_close_websession(*_: Any) -> None:
|
||||
"""Close websession."""
|
||||
clientsession.detach()
|
||||
|
||||
unsub = hass.bus.async_listen_once(
|
||||
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
|
||||
)
|
||||
|
||||
config_entry = config_entries.current_entry.get()
|
||||
if not config_entry:
|
||||
return
|
||||
|
||||
config_entry.async_on_unload(unsub)
|
||||
config_entry.async_on_unload(_async_close_websession)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_register_default_clientsession_shutdown(
|
||||
hass: HomeAssistant, clientsession: aiohttp.ClientSession
|
||||
) -> None:
|
||||
"""Register default ClientSession close on Home Assistant shutdown.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
"""Test the config manager."""
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow, loader
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_STARTED
|
||||
from homeassistant.core import CoreState, callback
|
||||
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.aiohttp_client import async_create_clientsession
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt
|
||||
|
||||
|
@ -2489,3 +2490,97 @@ async def test_updating_entry_with_and_without_changes(manager):
|
|||
assert manager.async_update_entry(entry, title="newtitle") is True
|
||||
assert manager.async_update_entry(entry, unique_id="abc123") is False
|
||||
assert manager.async_update_entry(entry, unique_id="abc1234") is True
|
||||
|
||||
|
||||
async def test_entry_reload_calls_on_unload_listeners(hass, manager):
|
||||
"""Test reload calls the on unload listeners."""
|
||||
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
async_setup = AsyncMock(return_value=True)
|
||||
mock_setup_entry = AsyncMock(return_value=True)
|
||||
async_unload_entry = AsyncMock(return_value=True)
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup=async_setup,
|
||||
async_setup_entry=mock_setup_entry,
|
||||
async_unload_entry=async_unload_entry,
|
||||
),
|
||||
)
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
mock_unload_callback = Mock()
|
||||
|
||||
entry.async_on_unload(mock_unload_callback)
|
||||
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
assert len(async_unload_entry.mock_calls) == 1
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
assert len(mock_unload_callback.mock_calls) == 1
|
||||
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
assert len(async_unload_entry.mock_calls) == 2
|
||||
assert len(mock_setup_entry.mock_calls) == 2
|
||||
# Since we did not register another async_on_unload it should
|
||||
# have only been called once
|
||||
assert len(mock_unload_callback.mock_calls) == 1
|
||||
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
|
||||
async def test_entry_reload_cleans_up_aiohttp_session(hass, manager):
|
||||
"""Test reload cleans up aiohttp sessions their close listener created by the config entry."""
|
||||
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
|
||||
entry.add_to_hass(hass)
|
||||
async_setup_calls = 0
|
||||
|
||||
async def async_setup_entry(hass, _):
|
||||
"""Mock setup entry."""
|
||||
nonlocal async_setup_calls
|
||||
async_setup_calls += 1
|
||||
async_create_clientsession(hass)
|
||||
return True
|
||||
|
||||
async_setup = AsyncMock(return_value=True)
|
||||
async_unload_entry = AsyncMock(return_value=True)
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
"comp",
|
||||
async_setup=async_setup,
|
||||
async_setup_entry=async_setup_entry,
|
||||
async_unload_entry=async_unload_entry,
|
||||
),
|
||||
)
|
||||
mock_entity_platform(hass, "config_flow.comp", None)
|
||||
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
assert len(async_unload_entry.mock_calls) == 1
|
||||
assert async_setup_calls == 1
|
||||
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
original_close_listeners = hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
|
||||
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
assert len(async_unload_entry.mock_calls) == 2
|
||||
assert async_setup_calls == 2
|
||||
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
assert (
|
||||
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
|
||||
== original_close_listeners
|
||||
)
|
||||
|
||||
assert await manager.async_reload(entry.entry_id)
|
||||
assert len(async_unload_entry.mock_calls) == 3
|
||||
assert async_setup_calls == 3
|
||||
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
assert (
|
||||
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
|
||||
== original_close_listeners
|
||||
)
|
||||
|
|
|
@ -288,7 +288,7 @@ def mock_aiohttp_client():
|
|||
return session
|
||||
|
||||
with mock.patch(
|
||||
"homeassistant.helpers.aiohttp_client.async_create_clientsession",
|
||||
"homeassistant.helpers.aiohttp_client._async_create_clientsession",
|
||||
side_effect=create_session,
|
||||
):
|
||||
yield mocker
|
||||
|
|
Loading…
Reference in a new issue