Fix Kulersky and Zerproc config unloading (#47572)

This commit is contained in:
Emily Mills 2021-03-22 00:08:09 -05:00 committed by GitHub
parent f35641ae8e
commit 8557b856a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 59 additions and 30 deletions

View file

@ -4,7 +4,7 @@ import asyncio
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from .const import DOMAIN
from .const import DATA_ADDRESSES, DATA_DISCOVERY_SUBSCRIPTION, DOMAIN
PLATFORMS = ["light"]
@ -16,6 +16,11 @@ async def async_setup(hass: HomeAssistant, config: dict):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Set up Kuler Sky from a config entry."""
if DOMAIN not in hass.data:
hass.data[DOMAIN] = {}
if DATA_ADDRESSES not in hass.data[DOMAIN]:
hass.data[DOMAIN][DATA_ADDRESSES] = set()
for platform in PLATFORMS:
hass.async_create_task(
hass.config_entries.async_forward_entry_setup(entry, platform)
@ -26,7 +31,14 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Unload a config entry."""
unload_ok = all(
# Stop discovery
unregister_discovery = hass.data[DOMAIN].pop(DATA_DISCOVERY_SUBSCRIPTION, None)
if unregister_discovery:
unregister_discovery()
hass.data.pop(DOMAIN, None)
return all(
await asyncio.gather(
*[
hass.config_entries.async_forward_entry_unload(entry, platform)
@ -34,7 +46,3 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
]
)
)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok

View file

@ -1,2 +1,5 @@
"""Constants for the Kuler Sky integration."""
DOMAIN = "kulersky"
DATA_ADDRESSES = "addresses"
DATA_DISCOVERY_SUBSCRIPTION = "discovery_subscription"

View file

@ -18,13 +18,12 @@ from homeassistant.components.light import (
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import HomeAssistantType
import homeassistant.util.color as color_util
from .const import DOMAIN
from .const import DATA_ADDRESSES, DATA_DISCOVERY_SUBSCRIPTION, DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -39,10 +38,6 @@ async def async_setup_entry(
async_add_entities: Callable[[list[Entity], bool], None],
) -> None:
"""Set up Kuler sky light devices."""
if DOMAIN not in hass.data:
hass.data[DOMAIN] = {}
if "addresses" not in hass.data[DOMAIN]:
hass.data[DOMAIN]["addresses"] = set()
async def discover(*args):
"""Attempt to discover new lights."""
@ -52,12 +47,12 @@ async def async_setup_entry(
new_lights = [
light
for light in lights
if light.address not in hass.data[DOMAIN]["addresses"]
if light.address not in hass.data[DOMAIN][DATA_ADDRESSES]
]
new_entities = []
for light in new_lights:
hass.data[DOMAIN]["addresses"].add(light.address)
hass.data[DOMAIN][DATA_ADDRESSES].add(light.address)
new_entities.append(KulerskyLight(light))
async_add_entities(new_entities, update_before_add=True)
@ -66,12 +61,9 @@ async def async_setup_entry(
hass.async_create_task(discover())
# Perform recurring discovery of new devices
async_track_time_interval(hass, discover, DISCOVERY_INTERVAL)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Cleanup the Kuler sky integration."""
hass.data.pop(DOMAIN, None)
hass.data[DOMAIN][DATA_DISCOVERY_SUBSCRIPTION] = async_track_time_interval(
hass, discover, DISCOVERY_INTERVAL
)
class KulerskyLight(LightEntity):

View file

@ -4,7 +4,7 @@ import asyncio
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.core import HomeAssistant
from .const import DOMAIN
from .const import DATA_ADDRESSES, DATA_DISCOVERY_SUBSCRIPTION, DOMAIN
PLATFORMS = ["light"]
@ -22,8 +22,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Set up Zerproc from a config entry."""
if DOMAIN not in hass.data:
hass.data[DOMAIN] = {}
if "addresses" not in hass.data[DOMAIN]:
hass.data[DOMAIN]["addresses"] = set()
if DATA_ADDRESSES not in hass.data[DOMAIN]:
hass.data[DOMAIN][DATA_ADDRESSES] = set()
for platform in PLATFORMS:
hass.async_create_task(
@ -35,7 +35,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Unload a config entry."""
# Stop discovery
unregister_discovery = hass.data[DOMAIN].pop(DATA_DISCOVERY_SUBSCRIPTION, None)
if unregister_discovery:
unregister_discovery()
hass.data.pop(DOMAIN, None)
return all(
await asyncio.gather(
*[

View file

@ -1,2 +1,5 @@
"""Constants for the Zerproc integration."""
DOMAIN = "zerproc"
DATA_ADDRESSES = "addresses"
DATA_DISCOVERY_SUBSCRIPTION = "discovery_subscription"

View file

@ -22,7 +22,7 @@ from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import HomeAssistantType
import homeassistant.util.color as color_util
from .const import DOMAIN
from .const import DATA_ADDRESSES, DATA_DISCOVERY_SUBSCRIPTION, DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -37,12 +37,14 @@ async def discover_entities(hass: HomeAssistant) -> list[Entity]:
# Filter out already discovered lights
new_lights = [
light for light in lights if light.address not in hass.data[DOMAIN]["addresses"]
light
for light in lights
if light.address not in hass.data[DOMAIN][DATA_ADDRESSES]
]
entities = []
for light in new_lights:
hass.data[DOMAIN]["addresses"].add(light.address)
hass.data[DOMAIN][DATA_ADDRESSES].add(light.address)
entities.append(ZerprocLight(light))
return entities
@ -72,7 +74,9 @@ async def async_setup_entry(
hass.async_create_task(discover())
# Perform recurring discovery of new devices
async_track_time_interval(hass, discover, DISCOVERY_INTERVAL)
hass.data[DOMAIN][DATA_DISCOVERY_SUBSCRIPTION] = async_track_time_interval(
hass, discover, DISCOVERY_INTERVAL
)
class ZerprocLight(LightEntity):

View file

@ -5,7 +5,11 @@ import pykulersky
import pytest
from homeassistant import setup
from homeassistant.components.kulersky.light import DOMAIN
from homeassistant.components.kulersky.const import (
DATA_ADDRESSES,
DATA_DISCOVERY_SUBSCRIPTION,
DOMAIN,
)
from homeassistant.components.light import (
ATTR_BRIGHTNESS,
ATTR_COLOR_MODE,
@ -85,9 +89,13 @@ async def test_init(hass, mock_light):
async def test_remove_entry(hass, mock_light, mock_entry):
"""Test platform setup."""
assert hass.data[DOMAIN][DATA_ADDRESSES] == {"AA:BB:CC:11:22:33"}
assert DATA_DISCOVERY_SUBSCRIPTION in hass.data[DOMAIN]
await hass.config_entries.async_remove(mock_entry.entry_id)
assert mock_light.disconnect.called
assert DOMAIN not in hass.data
async def test_remove_entry_exceptions_caught(hass, mock_light, mock_entry):

View file

@ -17,7 +17,11 @@ from homeassistant.components.light import (
SUPPORT_BRIGHTNESS,
SUPPORT_COLOR,
)
from homeassistant.components.zerproc.light import DOMAIN
from homeassistant.components.zerproc.const import (
DATA_ADDRESSES,
DATA_DISCOVERY_SUBSCRIPTION,
DOMAIN,
)
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_FRIENDLY_NAME,
@ -146,7 +150,8 @@ async def test_discovery_exception(hass, mock_entry):
async def test_remove_entry(hass, mock_light, mock_entry):
"""Test platform setup."""
assert hass.data[DOMAIN]["addresses"] == {"AA:BB:CC:DD:EE:FF"}
assert hass.data[DOMAIN][DATA_ADDRESSES] == {"AA:BB:CC:DD:EE:FF"}
assert DATA_DISCOVERY_SUBSCRIPTION in hass.data[DOMAIN]
with patch.object(mock_light, "disconnect") as mock_disconnect:
await hass.config_entries.async_remove(mock_entry.entry_id)