Detach aiohttp.ClientSession created by config entry setup on unload (#48908)

This commit is contained in:
J. Nick Koston 2021-04-09 07:14:33 -10:00 committed by GitHub
parent 8e2b5b36b5
commit 40450b9cfd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 192 additions and 17 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from contextvars import ContextVar
import functools
import logging
from types import MappingProxyType, MethodType
@ -133,6 +134,7 @@ class ConfigEntry:
"_setup_lock",
"update_listeners",
"_async_cancel_retry_setup",
"_on_unload",
)
def __init__(
@ -198,6 +200,9 @@ class ConfigEntry:
# Function to cancel a scheduled retry
self._async_cancel_retry_setup: Callable[[], Any] | None = None
# Hold list for functions to call on unload.
self._on_unload: list[CALLBACK_TYPE] | None = None
async def async_setup(
self,
hass: HomeAssistant,
@ -206,6 +211,7 @@ class ConfigEntry:
tries: int = 0,
) -> None:
"""Set up an entry."""
current_entry.set(self)
if self.source == SOURCE_IGNORE or self.disabled_by:
return
@ -290,6 +296,8 @@ class ConfigEntry:
self._async_cancel_retry_setup = hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STARTED, setup_again
)
self._async_process_on_unload()
return
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
@ -358,6 +366,8 @@ class ConfigEntry:
if result and integration.domain == self.domain:
self.state = ENTRY_STATE_NOT_LOADED
self._async_process_on_unload()
return result
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
@ -470,6 +480,25 @@ class ConfigEntry:
"disabled_by": self.disabled_by,
}
@callback
def async_on_unload(self, func: CALLBACK_TYPE) -> None:
"""Add a function to call when config entry is unloaded."""
if self._on_unload is None:
self._on_unload = []
self._on_unload.append(func)
@callback
def _async_process_on_unload(self) -> None:
"""Process the on_unload callbacks."""
if self._on_unload is not None:
while self._on_unload:
self._on_unload.pop()()
current_entry: ContextVar[ConfigEntry | None] = ContextVar(
"current_entry", default=None
)
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
"""Manage all the config entry flows that are in progress."""

View file

@ -5,7 +5,7 @@ import asyncio
from contextlib import suppress
from ssl import SSLContext
import sys
from typing import Any, Awaitable, cast
from typing import Any, Awaitable, Callable, cast
import aiohttp
from aiohttp import web
@ -13,6 +13,7 @@ from aiohttp.hdrs import CONTENT_TYPE, USER_AGENT
from aiohttp.web_exceptions import HTTPBadGateway, HTTPGatewayTimeout
import async_timeout
from homeassistant import config_entries
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.frame import warn_use
@ -27,6 +28,8 @@ SERVER_SOFTWARE = "HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}".format(
__version__, aiohttp.__version__, sys.version_info
)
WARN_CLOSE_MSG = "closes the Home Assistant aiohttp session"
@callback
@bind_hass
@ -37,12 +40,14 @@ def async_get_clientsession(
This method must be run in the event loop.
"""
key = DATA_CLIENTSESSION_NOTVERIFY
if verify_ssl:
key = DATA_CLIENTSESSION
key = DATA_CLIENTSESSION if verify_ssl else DATA_CLIENTSESSION_NOTVERIFY
if key not in hass.data:
hass.data[key] = async_create_clientsession(hass, verify_ssl)
hass.data[key] = _async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=_async_register_default_clientsession_shutdown,
)
return cast(aiohttp.ClientSession, hass.data[key])
@ -59,24 +64,44 @@ def async_create_clientsession(
If auto_cleanup is False, you need to call detach() after the session
returned is no longer used. Default is True, the session will be
automatically detached on homeassistant_stop.
automatically detached on homeassistant_stop or when being created
in config entry setup, the config entry is unloaded.
This method must be run in the event loop.
"""
connector = _async_get_connector(hass, verify_ssl)
auto_cleanup_method = None
if auto_cleanup:
auto_cleanup_method = _async_register_clientsession_shutdown
clientsession = _async_create_clientsession(
hass,
verify_ssl,
auto_cleanup_method=auto_cleanup_method,
**kwargs,
)
return clientsession
@callback
def _async_create_clientsession(
hass: HomeAssistant,
verify_ssl: bool = True,
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
| None = None,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession(
connector=connector,
connector=_async_get_connector(hass, verify_ssl),
headers={USER_AGENT: SERVER_SOFTWARE},
**kwargs,
)
clientsession.close = warn_use( # type: ignore
clientsession.close, "closes the Home Assistant aiohttp session"
)
clientsession.close = warn_use(clientsession.close, WARN_CLOSE_MSG) # type: ignore
if auto_cleanup:
_async_register_clientsession_shutdown(hass, clientsession)
if auto_cleanup_method:
auto_cleanup_method(hass, clientsession)
return clientsession
@ -146,7 +171,33 @@ async def async_aiohttp_proxy_stream(
def _async_register_clientsession_shutdown(
hass: HomeAssistant, clientsession: aiohttp.ClientSession
) -> None:
"""Register ClientSession close on Home Assistant shutdown.
"""Register ClientSession close on Home Assistant shutdown or config entry unload.
This method must be run in the event loop.
"""
@callback
def _async_close_websession(*_: Any) -> None:
"""Close websession."""
clientsession.detach()
unsub = hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
)
config_entry = config_entries.current_entry.get()
if not config_entry:
return
config_entry.async_on_unload(unsub)
config_entry.async_on_unload(_async_close_websession)
@callback
def _async_register_default_clientsession_shutdown(
hass: HomeAssistant, clientsession: aiohttp.ClientSession
) -> None:
"""Register default ClientSession close on Home Assistant shutdown.
This method must be run in the event loop.
"""

View file

@ -1,15 +1,16 @@
"""Test the config manager."""
import asyncio
from datetime import timedelta
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from homeassistant import config_entries, data_entry_flow, loader
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, callback
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_create_clientsession
from homeassistant.setup import async_setup_component
from homeassistant.util import dt
@ -2489,3 +2490,97 @@ async def test_updating_entry_with_and_without_changes(manager):
assert manager.async_update_entry(entry, title="newtitle") is True
assert manager.async_update_entry(entry, unique_id="abc123") is False
assert manager.async_update_entry(entry, unique_id="abc1234") is True
async def test_entry_reload_calls_on_unload_listeners(hass, manager):
"""Test reload calls the on unload listeners."""
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
entry.add_to_hass(hass)
async_setup = AsyncMock(return_value=True)
mock_setup_entry = AsyncMock(return_value=True)
async_unload_entry = AsyncMock(return_value=True)
mock_integration(
hass,
MockModule(
"comp",
async_setup=async_setup,
async_setup_entry=mock_setup_entry,
async_unload_entry=async_unload_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
mock_unload_callback = Mock()
entry.async_on_unload(mock_unload_callback)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
assert len(mock_unload_callback.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 2
assert len(mock_setup_entry.mock_calls) == 2
# Since we did not register another async_on_unload it should
# have only been called once
assert len(mock_unload_callback.mock_calls) == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
async def test_entry_reload_cleans_up_aiohttp_session(hass, manager):
"""Test reload cleans up aiohttp sessions their close listener created by the config entry."""
entry = MockConfigEntry(domain="comp", state=config_entries.ENTRY_STATE_LOADED)
entry.add_to_hass(hass)
async_setup_calls = 0
async def async_setup_entry(hass, _):
"""Mock setup entry."""
nonlocal async_setup_calls
async_setup_calls += 1
async_create_clientsession(hass)
return True
async_setup = AsyncMock(return_value=True)
async_unload_entry = AsyncMock(return_value=True)
mock_integration(
hass,
MockModule(
"comp",
async_setup=async_setup,
async_setup_entry=async_setup_entry,
async_unload_entry=async_unload_entry,
),
)
mock_entity_platform(hass, "config_flow.comp", None)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert async_setup_calls == 1
assert entry.state == config_entries.ENTRY_STATE_LOADED
original_close_listeners = hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 2
assert async_setup_calls == 2
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert (
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
== original_close_listeners
)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 3
assert async_setup_calls == 3
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert (
hass.bus.async_listeners()[EVENT_HOMEASSISTANT_CLOSE]
== original_close_listeners
)

View file

@ -288,7 +288,7 @@ def mock_aiohttp_client():
return session
with mock.patch(
"homeassistant.helpers.aiohttp_client.async_create_clientsession",
"homeassistant.helpers.aiohttp_client._async_create_clientsession",
side_effect=create_session,
):
yield mocker