Hold a lock to prevent concurrent setup of config entries (#116482)

This commit is contained in:
J. Nick Koston 2024-04-30 18:47:12 -05:00 committed by GitHub
parent 3c7cbf5794
commit 6cf1c5c1f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 129 additions and 18 deletions

View file

@ -295,7 +295,7 @@ class ConfigEntry(Generic[_DataT]):
update_listeners: list[UpdateListenerType]
_async_cancel_retry_setup: Callable[[], Any] | None
_on_unload: list[Callable[[], Coroutine[Any, Any, None] | None]] | None
reload_lock: asyncio.Lock
setup_lock: asyncio.Lock
_reauth_lock: asyncio.Lock
_reconfigure_lock: asyncio.Lock
_tasks: set[asyncio.Future[Any]]
@ -403,7 +403,7 @@ class ConfigEntry(Generic[_DataT]):
_setter(self, "_on_unload", None)
# Reload lock to prevent conflicting reloads
_setter(self, "reload_lock", asyncio.Lock())
_setter(self, "setup_lock", asyncio.Lock())
# Reauth lock to prevent concurrent reauth flows
_setter(self, "_reauth_lock", asyncio.Lock())
# Reconfigure lock to prevent concurrent reconfigure flows
@ -702,19 +702,17 @@ class ConfigEntry(Generic[_DataT]):
# has started so we do not block shutdown
if not hass.is_stopping:
hass.async_create_background_task(
self._async_setup_retry(hass),
self.async_setup_locked(hass),
f"config entry retry {self.domain} {self.title}",
eager_start=True,
)
async def _async_setup_retry(self, hass: HomeAssistant) -> None:
"""Retry setup.
We hold the reload lock during setup retry to ensure
that nothing can reload the entry while we are retrying.
"""
async with self.reload_lock:
await self.async_setup(hass)
async def async_setup_locked(
self, hass: HomeAssistant, integration: loader.Integration | None = None
) -> None:
"""Set up while holding the setup lock."""
async with self.setup_lock:
await self.async_setup(hass, integration=integration)
@callback
def async_shutdown(self) -> None:
@ -1794,7 +1792,15 @@ class ConfigEntries:
# attempts.
entry.async_cancel_retry_setup()
async with entry.reload_lock:
if entry.domain not in self.hass.config.components:
# If the component is not loaded, just load it as
# the config entry will be loaded as well. We need
# to do this before holding the lock to avoid a
# deadlock.
await async_setup_component(self.hass, entry.domain, self._hass_config)
return entry.state is ConfigEntryState.LOADED
async with entry.setup_lock:
unload_result = await self.async_unload(entry_id)
if not unload_result or entry.disabled_by:

View file

@ -449,7 +449,7 @@ async def _async_setup_component(
await asyncio.gather(
*(
create_eager_task(
entry.async_setup(hass, integration=integration),
entry.async_setup_locked(hass, integration=integration),
name=f"config entry setup {entry.title} {entry.domain} {entry.entry_id}",
)
for entry in entries

View file

@ -324,6 +324,7 @@ async def test_user_flow_already_configured_host_changed_reloads_entry(
state=ConfigEntryState.LOADED,
)
mock_config_entry.add_to_hass(hass)
hass.config.components.add(DOMAIN)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
@ -640,6 +641,7 @@ async def test_zeroconf_flow_already_configured_host_changed_reloads_entry(
state=ConfigEntryState.LOADED,
)
mock_config_entry.add_to_hass(hass)
hass.config.components.add(DOMAIN)
result = await hass.config_entries.flow.async_init(
DOMAIN,
@ -769,6 +771,7 @@ async def test_reauth_flow_success(
state=ConfigEntryState.LOADED,
)
mock_config_entry.add_to_hass(hass)
hass.config.components.add(DOMAIN)
mock_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()

View file

@ -251,6 +251,7 @@ async def test_reload_entry(hass: HomeAssistant, client) -> None:
domain="kitchen_sink", state=core_ce.ConfigEntryState.LOADED
)
entry.add_to_hass(hass)
hass.config.components.add("kitchen_sink")
resp = await client.post(
f"/api/config/config_entries/entry/{entry.entry_id}/reload"
)
@ -298,6 +299,7 @@ async def test_reload_entry_in_failed_state(
"""Test reloading an entry via the API that has already failed to unload."""
entry = MockConfigEntry(domain="demo", state=core_ce.ConfigEntryState.FAILED_UNLOAD)
entry.add_to_hass(hass)
hass.config.components.add("demo")
resp = await client.post(
f"/api/config/config_entries/entry/{entry.entry_id}/reload"
)
@ -326,6 +328,7 @@ async def test_reload_entry_in_setup_retry(
entry = MockConfigEntry(domain="comp", state=core_ce.ConfigEntryState.SETUP_RETRY)
entry.supports_unload = True
entry.add_to_hass(hass)
hass.config.components.add("comp")
with patch.dict(HANDLERS, {"comp": ConfigFlow, "test": ConfigFlow}):
resp = await client.post(
@ -1109,6 +1112,7 @@ async def test_update_prefrences(
domain="kitchen_sink", state=core_ce.ConfigEntryState.LOADED
)
entry.add_to_hass(hass)
hass.config.components.add("kitchen_sink")
assert entry.pref_disable_new_entities is False
assert entry.pref_disable_polling is False
@ -1209,6 +1213,7 @@ async def test_disable_entry(
)
entry.add_to_hass(hass)
assert entry.disabled_by is None
hass.config.components.add("kitchen_sink")
# Disable
await ws_client.send_json(

View file

@ -1873,6 +1873,7 @@ async def test_reload_entry_with_restored_subscriptions(
# Setup the MQTT entry
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
entry.add_to_hass(hass)
hass.config.components.add(mqtt.DOMAIN)
mqtt_client_mock.connect.return_value = 0
with patch("homeassistant.config.load_yaml_config_file", return_value={}):
await entry.async_setup(hass)

View file

@ -279,6 +279,7 @@ async def test_form_valid_reauth(
) -> None:
"""Test that we can handle a valid reauth."""
mock_config_entry.mock_state(hass, ConfigEntryState.LOADED)
hass.config.components.add(DOMAIN)
mock_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
@ -328,6 +329,7 @@ async def test_form_valid_reauth_with_mfa(
},
)
mock_config_entry.mock_state(hass, ConfigEntryState.LOADED)
hass.config.components.add(DOMAIN)
mock_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()

View file

@ -825,7 +825,7 @@ async def test_as_dict(snapshot: SnapshotAssertion) -> None:
"error_reason_translation_placeholders",
"_async_cancel_retry_setup",
"_on_unload",
"reload_lock",
"setup_lock",
"_reauth_lock",
"_tasks",
"_background_tasks",
@ -1632,7 +1632,6 @@ async def test_entry_reload_succeed(
mock_platform(hass, "comp.config_flow", None)
assert await manager.async_reload(entry.entry_id)
assert len(async_unload_entry.mock_calls) == 1
assert len(async_setup.mock_calls) == 1
assert len(async_setup_entry.mock_calls) == 1
assert entry.state is config_entries.ConfigEntryState.LOADED
@ -1707,6 +1706,8 @@ async def test_entry_reload_error(
),
)
hass.config.components.add("comp")
with pytest.raises(config_entries.OperationNotAllowed, match=str(state)):
assert await manager.async_reload(entry.entry_id)
@ -1738,8 +1739,11 @@ async def test_entry_disable_succeed(
),
)
mock_platform(hass, "comp.config_flow", None)
hass.config.components.add("comp")
# Disable
assert len(async_setup.mock_calls) == 0
assert len(async_setup_entry.mock_calls) == 0
assert await manager.async_set_disabled_by(
entry.entry_id, config_entries.ConfigEntryDisabler.USER
)
@ -1751,7 +1755,7 @@ async def test_entry_disable_succeed(
# Enable
assert await manager.async_set_disabled_by(entry.entry_id, None)
assert len(async_unload_entry.mock_calls) == 1
assert len(async_setup.mock_calls) == 1
assert len(async_setup.mock_calls) == 0
assert len(async_setup_entry.mock_calls) == 1
assert entry.state is config_entries.ConfigEntryState.LOADED
@ -1775,6 +1779,7 @@ async def test_entry_disable_without_reload_support(
),
)
mock_platform(hass, "comp.config_flow", None)
hass.config.components.add("comp")
# Disable
assert not await manager.async_set_disabled_by(
@ -1951,7 +1956,7 @@ async def test_reload_entry_entity_registry_works(
)
await hass.async_block_till_done()
assert len(mock_unload_entry.mock_calls) == 2
assert len(mock_unload_entry.mock_calls) == 1
async def test_unique_id_persisted(
@ -3392,6 +3397,7 @@ async def test_entry_reload_calls_on_unload_listeners(
),
)
mock_platform(hass, "comp.config_flow", None)
hass.config.components.add("comp")
mock_unload_callback = Mock()
@ -3944,8 +3950,9 @@ async def test_deprecated_disabled_by_str_set(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test deprecated str set disabled_by enumizes and logs a warning."""
entry = MockConfigEntry()
entry = MockConfigEntry(domain="comp")
entry.add_to_manager(manager)
hass.config.components.add("comp")
assert await manager.async_set_disabled_by(
entry.entry_id, config_entries.ConfigEntryDisabler.USER.value
)
@ -3963,6 +3970,47 @@ async def test_entry_reload_concurrency(
async_setup = AsyncMock(return_value=True)
loaded = 1
async def _async_setup_entry(*args, **kwargs):
await asyncio.sleep(0)
nonlocal loaded
loaded += 1
return loaded == 1
async def _async_unload_entry(*args, **kwargs):
await asyncio.sleep(0)
nonlocal loaded
loaded -= 1
return loaded == 0
mock_integration(
hass,
MockModule(
"comp",
async_setup=async_setup,
async_setup_entry=_async_setup_entry,
async_unload_entry=_async_unload_entry,
),
)
mock_platform(hass, "comp.config_flow", None)
hass.config.components.add("comp")
tasks = [
asyncio.create_task(manager.async_reload(entry.entry_id)) for _ in range(15)
]
await asyncio.gather(*tasks)
assert entry.state is config_entries.ConfigEntryState.LOADED
assert loaded == 1
async def test_entry_reload_concurrency_not_setup_setup(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
"""Test multiple reload calls do not cause a reload race."""
entry = MockConfigEntry(domain="comp", state=config_entries.ConfigEntryState.LOADED)
entry.add_to_hass(hass)
async_setup = AsyncMock(return_value=True)
loaded = 0
async def _async_setup_entry(*args, **kwargs):
await asyncio.sleep(0)
nonlocal loaded
@ -4074,6 +4122,7 @@ async def test_disallow_entry_reload_with_setup_in_progress(
domain="comp", state=config_entries.ConfigEntryState.SETUP_IN_PROGRESS
)
entry.add_to_hass(hass)
hass.config.components.add("comp")
with pytest.raises(
config_entries.OperationNotAllowed,
@ -5016,3 +5065,48 @@ async def test_updating_non_added_entry_raises(hass: HomeAssistant) -> None:
with pytest.raises(config_entries.UnknownEntry, match=entry.entry_id):
hass.config_entries.async_update_entry(entry, unique_id="new_id")
async def test_reload_during_setup(hass: HomeAssistant) -> None:
"""Test reload during setup waits."""
entry = MockConfigEntry(domain="comp", data={"value": "initial"})
entry.add_to_hass(hass)
setup_start_future = hass.loop.create_future()
setup_finish_future = hass.loop.create_future()
in_setup = False
setup_calls = 0
async def mock_async_setup_entry(hass, entry):
"""Mock setting up an entry."""
nonlocal in_setup
nonlocal setup_calls
setup_calls += 1
assert not in_setup
in_setup = True
setup_start_future.set_result(None)
await setup_finish_future
in_setup = False
return True
mock_integration(
hass,
MockModule(
"comp",
async_setup_entry=mock_async_setup_entry,
async_unload_entry=AsyncMock(return_value=True),
),
)
mock_platform(hass, "comp.config_flow", None)
setup_task = hass.async_create_task(async_setup_component(hass, "comp", {}))
await setup_start_future # ensure we are in the setup
reload_task = hass.async_create_task(
hass.config_entries.async_reload(entry.entry_id)
)
await asyncio.sleep(0)
setup_finish_future.set_result(None)
await setup_task
await reload_task
assert setup_calls == 2