Entity layer cleanup (#12237)

* Simplify entity update

* Split entity platform from entity component

* Decouple entity platform from entity component

* Always include unit of measurement again

* Lint

* Fix test
This commit is contained in:
Paulus Schoutsen 2018-02-08 03:16:51 -08:00 committed by Pascal Vizeli
parent 8523933605
commit 5601fbdc7a
7 changed files with 905 additions and 857 deletions

View file

@ -152,7 +152,7 @@ class Entity(object):
@property
def assumed_state(self) -> bool:
"""Return True if unable to access real state of the entity."""
return None
return False
@property
def force_update(self) -> bool:
@ -221,21 +221,41 @@ class Entity(object):
if device_attr is not None:
attr.update(device_attr)
self._attr_setter('unit_of_measurement', str, ATTR_UNIT_OF_MEASUREMENT,
attr)
unit_of_measurement = self.unit_of_measurement
if unit_of_measurement is not None:
attr[ATTR_UNIT_OF_MEASUREMENT] = unit_of_measurement
self._attr_setter('name', str, ATTR_FRIENDLY_NAME, attr)
self._attr_setter('icon', str, ATTR_ICON, attr)
self._attr_setter('entity_picture', str, ATTR_ENTITY_PICTURE, attr)
self._attr_setter('hidden', bool, ATTR_HIDDEN, attr)
self._attr_setter('assumed_state', bool, ATTR_ASSUMED_STATE, attr)
self._attr_setter('supported_features', int, ATTR_SUPPORTED_FEATURES,
attr)
self._attr_setter('device_class', str, ATTR_DEVICE_CLASS, attr)
name = self.name
if name is not None:
attr[ATTR_FRIENDLY_NAME] = name
icon = self.icon
if icon is not None:
attr[ATTR_ICON] = icon
entity_picture = self.entity_picture
if entity_picture is not None:
attr[ATTR_ENTITY_PICTURE] = entity_picture
hidden = self.hidden
if hidden:
attr[ATTR_HIDDEN] = hidden
assumed_state = self.assumed_state
if assumed_state:
attr[ATTR_ASSUMED_STATE] = assumed_state
supported_features = self.supported_features
if supported_features is not None:
attr[ATTR_SUPPORTED_FEATURES] = supported_features
device_class = self.device_class
if device_class is not None:
attr[ATTR_DEVICE_CLASS] = str(device_class)
end = timer()
if not self._slow_reported and end - start > 0.4:
if end - start > 0.4 and not self._slow_reported:
self._slow_reported = True
_LOGGER.warning("Updating state for %s (%s) took %.3f seconds. "
"Please report platform to the developers at "
@ -246,10 +266,6 @@ class Entity(object):
if DATA_CUSTOMIZE in self.hass.data:
attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id))
# Remove hidden property if false so it won't show up.
if not attr.get(ATTR_HIDDEN, True):
attr.pop(ATTR_HIDDEN)
# Convert temperature if we detect one
try:
unit_of_measure = attr.get(ATTR_UNIT_OF_MEASUREMENT)
@ -321,21 +337,6 @@ class Entity(object):
else:
self.hass.states.async_remove(self.entity_id)
def _attr_setter(self, name, typ, attr, attrs):
"""Populate attributes based on properties."""
if attr in attrs:
return
value = getattr(self, name)
if value is None:
return
try:
attrs[attr] = typ(value)
except (TypeError, ValueError):
pass
def __eq__(self, other):
"""Return the comparison."""
if not isinstance(other, self.__class__):

View file

@ -6,25 +6,15 @@ from itertools import chain
from homeassistant import config as conf_util
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE,
DEVICE_DEFAULT_NAME)
from homeassistant.core import callback, valid_entity_id, split_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
ATTR_ENTITY_ID, CONF_SCAN_INTERVAL, CONF_ENTITY_NAMESPACE)
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.event import (
async_track_time_interval, async_track_point_in_time)
from homeassistant.helpers.service import extract_entity_ids
from homeassistant.util import slugify
from homeassistant.util.async import (
run_callback_threadsafe, run_coroutine_threadsafe)
import homeassistant.util.dt as dt_util
from .entity_registry import EntityRegistry
from .entity_platform import EntityPlatform
DEFAULT_SCAN_INTERVAL = timedelta(seconds=15)
SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10
DATA_REGISTRY = 'entity_registry'
class EntityComponent(object):
@ -43,16 +33,23 @@ class EntityComponent(object):
"""Initialize an entity component."""
self.logger = logger
self.hass = hass
self.domain = domain
self.entity_id_format = domain + '.{}'
self.scan_interval = scan_interval
self.group_name = group_name
self.config = None
self._platforms = {
'core': EntityPlatform(self, domain, self.scan_interval, 0, None),
'core': EntityPlatform(
hass=hass,
logger=logger,
domain=domain,
platform_name='core',
scan_interval=self.scan_interval,
parallel_updates=0,
entity_namespace=None,
async_entities_added_callback=self._async_update_group,
)
}
self.async_add_entities = self._platforms['core'].async_add_entities
self.add_entities = self._platforms['core'].add_entities
@ -107,17 +104,6 @@ class EntityComponent(object):
discovery.async_listen_platform(
self.hass, self.domain, component_platform_discovered)
def extract_from_service(self, service, expand_group=True):
"""Extract all known entities from a service call.
Will return all entities if no entities specified in call.
Will return an empty list if entities specified but unknown.
"""
return run_callback_threadsafe(
self.hass.loop, self.async_extract_from_service, service,
expand_group
).result()
@callback
def async_extract_from_service(self, service, expand_group=True):
"""Extract all known and available entities from a service call.
@ -136,11 +122,8 @@ class EntityComponent(object):
@asyncio.coroutine
def _async_setup_platform(self, platform_type, platform_config,
discovery_info=None, tries=0):
"""Set up a platform for this component.
This method must be run in the event loop.
"""
discovery_info=None):
"""Set up a platform for this component."""
platform = yield from async_prepare_setup_platform(
self.hass, self.config, self.domain, platform_type)
@ -161,59 +144,23 @@ class EntityComponent(object):
if key not in self._platforms:
entity_platform = self._platforms[key] = EntityPlatform(
self, platform_type, scan_interval, parallel_updates,
entity_namespace)
hass=self.hass,
logger=self.logger,
domain=self.domain,
platform_name=platform_type,
scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace,
async_entities_added_callback=self._async_update_group,
)
else:
entity_platform = self._platforms[key]
self.logger.info("Setting up %s.%s", self.domain, platform_type)
warn_task = self.hass.loop.call_later(
SLOW_SETUP_WARNING, self.logger.warning,
"Setup of platform %s is taking over %s seconds.", platform_type,
SLOW_SETUP_WARNING)
try:
if getattr(platform, 'async_setup_platform', None):
task = platform.async_setup_platform(
self.hass, platform_config,
entity_platform.async_schedule_add_entities, discovery_info
)
else:
# This should not be replaced with hass.async_add_job because
# we don't want to track this task in case it blocks startup.
task = self.hass.loop.run_in_executor(
None, platform.setup_platform, self.hass, platform_config,
entity_platform.schedule_add_entities, discovery_info
)
yield from asyncio.wait_for(
asyncio.shield(task, loop=self.hass.loop),
SLOW_SETUP_MAX_WAIT, loop=self.hass.loop)
yield from entity_platform.async_block_entities_done()
self.hass.config.components.add(
'{}.{}'.format(self.domain, platform_type))
except PlatformNotReady:
tries += 1
wait_time = min(tries, 6) * 30
self.logger.warning(
'Platform %s not ready yet. Retrying in %d seconds.',
platform_type, wait_time)
async_track_point_in_time(
self.hass, self._async_setup_platform(
platform_type, platform_config, discovery_info, tries),
dt_util.utcnow() + timedelta(seconds=wait_time))
except asyncio.TimeoutError:
self.logger.error(
"Setup of platform %s is taking longer than %s seconds."
" Startup will proceed without waiting any longer.",
platform_type, SLOW_SETUP_MAX_WAIT)
except Exception: # pylint: disable=broad-except
self.logger.exception(
"Error while setting up platform %s", platform_type)
finally:
warn_task.cancel()
yield from entity_platform.async_setup(
platform, platform_config, discovery_info)
@callback
def async_update_group(self):
def _async_update_group(self):
"""Set up and/or update component group.
This method must be run in the event loop.
@ -230,12 +177,8 @@ class EntityComponent(object):
visible=False, entity_ids=ids
)
def reset(self):
"""Remove entities and reset the entity component to initial values."""
run_coroutine_threadsafe(self.async_reset(), self.hass.loop).result()
@asyncio.coroutine
def async_reset(self):
def _async_reset(self):
"""Remove entities and reset the entity component to initial values.
This method must be run in the event loop.
@ -261,11 +204,6 @@ class EntityComponent(object):
if entity_id in platform.entities:
yield from platform.async_remove_entity(entity_id)
def prepare_reload(self):
"""Prepare reloading this entity component."""
return run_coroutine_threadsafe(
self.async_prepare_reload(), loop=self.hass.loop).result()
@asyncio.coroutine
def async_prepare_reload(self):
"""Prepare reloading this entity component.
@ -285,239 +223,5 @@ class EntityComponent(object):
if conf is None:
return None
yield from self.async_reset()
yield from self._async_reset()
return conf
class EntityPlatform(object):
"""Manage the entities for a single platform."""
def __init__(self, component, platform, scan_interval, parallel_updates,
entity_namespace):
"""Initialize the entity platform."""
self.component = component
self.platform = platform
self.scan_interval = scan_interval
self.parallel_updates = None
self.entity_namespace = entity_namespace
self.entities = {}
self._tasks = []
self._async_unsub_polling = None
self._process_updates = asyncio.Lock(loop=component.hass.loop)
if parallel_updates:
self.parallel_updates = asyncio.Semaphore(
parallel_updates, loop=component.hass.loop)
@asyncio.coroutine
def async_block_entities_done(self):
"""Wait until all entities add to hass."""
if self._tasks:
pending = [task for task in self._tasks if not task.done()]
self._tasks.clear()
if pending:
yield from asyncio.wait(pending, loop=self.component.hass.loop)
def schedule_add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform."""
run_callback_threadsafe(
self.component.hass.loop,
self.async_schedule_add_entities, list(new_entities),
update_before_add
).result()
@callback
def async_schedule_add_entities(self, new_entities,
update_before_add=False):
"""Add entities for a single platform async."""
self._tasks.append(self.component.hass.async_add_job(
self.async_add_entities(
new_entities, update_before_add=update_before_add)
))
def add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform."""
# That avoid deadlocks
if update_before_add:
self.component.logger.warning(
"Call 'add_entities' with update_before_add=True "
"only inside tests or you can run into a deadlock!")
run_coroutine_threadsafe(
self.async_add_entities(list(new_entities), update_before_add),
self.component.hass.loop).result()
@asyncio.coroutine
def async_add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform async.
This method must be run in the event loop.
"""
# handle empty list from component/platform
if not new_entities:
return
hass = self.component.hass
component_entities = set(entity.entity_id for entity
in self.component.entities)
registry = hass.data.get(DATA_REGISTRY)
if registry is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
yield from registry.async_ensure_loaded()
tasks = [
self._async_add_entity(entity, update_before_add,
component_entities, registry)
for entity in new_entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
self.component.async_update_group()
if self._async_unsub_polling is not None or \
not any(entity.should_poll for entity
in self.entities.values()):
return
self._async_unsub_polling = async_track_time_interval(
self.component.hass, self._update_entity_states, self.scan_interval
)
@asyncio.coroutine
def _async_add_entity(self, entity, update_before_add, component_entities,
registry):
"""Helper method to add an entity to the platform."""
if entity is None:
raise ValueError('Entity cannot be None')
entity.hass = self.component.hass
entity.platform = self
entity.parallel_updates = self.parallel_updates
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.component.logger.exception(
"%s: Error on device update!", self.platform)
return
suggested_object_id = None
# Get entity_id from unique ID registration
if entity.unique_id is not None:
if entity.entity_id is not None:
suggested_object_id = split_entity_id(entity.entity_id)[1]
else:
suggested_object_id = entity.name
entry = registry.async_get_or_create(
self.component.domain, self.platform, entity.unique_id,
suggested_object_id=suggested_object_id)
entity.entity_id = entry.entity_id
# We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID
elif (entity.entity_id is not None and
registry.async_is_registered(entity.entity_id)):
# If entity already registered, convert entity id to suggestion
suggested_object_id = split_entity_id(entity.entity_id)[1]
entity.entity_id = None
# Generate entity ID
if entity.entity_id is None:
suggested_object_id = \
suggested_object_id or entity.name or DEVICE_DEFAULT_NAME
if self.entity_namespace is not None:
suggested_object_id = '{} {}'.format(self.entity_namespace,
suggested_object_id)
entity.entity_id = registry.async_generate_entity_id(
self.component.domain, suggested_object_id)
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
elif entity.entity_id in component_entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
component_entities.add(entity.entity_id)
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
@asyncio.coroutine
def async_reset(self):
"""Remove all entities and reset data.
This method must be run in the event loop.
"""
if not self.entities:
return
tasks = [self._async_remove_entity(entity_id)
for entity_id in self.entities]
yield from asyncio.wait(tasks, loop=self.component.hass.loop)
if self._async_unsub_polling is not None:
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
yield from self._async_remove_entity(entity_id)
# Clean up polling job if no longer needed
if (self._async_unsub_polling is not None and
not any(entity.should_poll for entity
in self.entities.values())):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def _async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
entity = self.entities.pop(entity_id)
if hasattr(entity, 'async_will_remove_from_hass'):
yield from entity.async_will_remove_from_hass()
self.component.hass.states.async_remove(entity_id)
@asyncio.coroutine
def _update_entity_states(self, now):
"""Update the states of all the polling entities.
To protect from flooding the executor, we will update async entities
in parallel and other entities sequential.
This method must be run in the event loop.
"""
if self._process_updates.locked():
self.component.logger.warning(
"Updating %s %s took longer than the scheduled update "
"interval %s", self.platform, self.component.domain,
self.scan_interval)
return
with (yield from self._process_updates):
tasks = []
for entity in self.entities.values():
if not entity.should_poll:
continue
tasks.append(entity.async_update_ha_state(True))
if tasks:
yield from asyncio.wait(tasks, loop=self.component.hass.loop)

View file

@ -0,0 +1,317 @@
"""Class to manage the entities for a single platform."""
import asyncio
from datetime import timedelta
from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.core import callback, valid_entity_id, split_entity_id
from homeassistant.exceptions import HomeAssistantError, PlatformNotReady
from homeassistant.util.async import (
run_callback_threadsafe, run_coroutine_threadsafe)
import homeassistant.util.dt as dt_util
from .event import async_track_time_interval, async_track_point_in_time
from .entity_registry import EntityRegistry
SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10
DATA_REGISTRY = 'entity_registry'
class EntityPlatform(object):
"""Manage the entities for a single platform."""
def __init__(self, *, hass, logger, domain, platform_name, scan_interval,
parallel_updates, entity_namespace,
async_entities_added_callback):
"""Initialize the entity platform.
hass: HomeAssistant
logger: Logger
domain: str
platform_name: str
scan_interval: timedelta
parallel_updates: int
entity_namespace: str
async_entities_added_callback: @callback method
"""
self.hass = hass
self.logger = logger
self.domain = domain
self.platform_name = platform_name
self.scan_interval = scan_interval
self.parallel_updates = None
self.entity_namespace = entity_namespace
self.async_entities_added_callback = async_entities_added_callback
self.entities = {}
self._tasks = []
self._async_unsub_polling = None
self._process_updates = asyncio.Lock(loop=hass.loop)
if parallel_updates:
self.parallel_updates = asyncio.Semaphore(
parallel_updates, loop=hass.loop)
@asyncio.coroutine
def async_setup(self, platform, platform_config, discovery_info=None,
tries=0):
"""Setup the platform."""
logger = self.logger
hass = self.hass
full_name = '{}.{}'.format(self.domain, self.platform_name)
logger.info("Setting up %s", full_name)
warn_task = hass.loop.call_later(
SLOW_SETUP_WARNING, logger.warning,
"Setup of platform %s is taking over %s seconds.",
self.platform_name, SLOW_SETUP_WARNING)
try:
if getattr(platform, 'async_setup_platform', None):
task = platform.async_setup_platform(
hass, platform_config,
self._async_schedule_add_entities, discovery_info
)
else:
# This should not be replaced with hass.async_add_job because
# we don't want to track this task in case it blocks startup.
task = hass.loop.run_in_executor(
None, platform.setup_platform, hass, platform_config,
self._schedule_add_entities, discovery_info
)
yield from asyncio.wait_for(
asyncio.shield(task, loop=hass.loop),
SLOW_SETUP_MAX_WAIT, loop=hass.loop)
# Block till all entities are done
if self._tasks:
pending = [task for task in self._tasks if not task.done()]
self._tasks.clear()
if pending:
yield from asyncio.wait(
pending, loop=self.hass.loop)
hass.config.components.add(full_name)
except PlatformNotReady:
tries += 1
wait_time = min(tries, 6) * 30
logger.warning(
'Platform %s not ready yet. Retrying in %d seconds.',
self.platform_name, wait_time)
async_track_point_in_time(
hass, self.async_setup(
platform, platform_config, discovery_info, tries),
dt_util.utcnow() + timedelta(seconds=wait_time))
except asyncio.TimeoutError:
logger.error(
"Setup of platform %s is taking longer than %s seconds."
" Startup will proceed without waiting any longer.",
self.platform_name, SLOW_SETUP_MAX_WAIT)
except Exception: # pylint: disable=broad-except
logger.exception(
"Error while setting up platform %s", self.platform_name)
finally:
warn_task.cancel()
def _schedule_add_entities(self, new_entities, update_before_add=False):
"""Synchronously schedule adding entities for a single platform."""
run_callback_threadsafe(
self.hass.loop,
self._async_schedule_add_entities, list(new_entities),
update_before_add
).result()
@callback
def _async_schedule_add_entities(self, new_entities,
update_before_add=False):
"""Schedule adding entities for a single platform async."""
self._tasks.append(self.hass.async_add_job(
self.async_add_entities(
new_entities, update_before_add=update_before_add)
))
def add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform."""
# That avoid deadlocks
if update_before_add:
self.logger.warning(
"Call 'add_entities' with update_before_add=True "
"only inside tests or you can run into a deadlock!")
run_coroutine_threadsafe(
self.async_add_entities(list(new_entities), update_before_add),
self.hass.loop).result()
@asyncio.coroutine
def async_add_entities(self, new_entities, update_before_add=False):
"""Add entities for a single platform async.
This method must be run in the event loop.
"""
# handle empty list from component/platform
if not new_entities:
return
hass = self.hass
component_entities = set(hass.states.async_entity_ids(self.domain))
registry = hass.data.get(DATA_REGISTRY)
if registry is None:
registry = hass.data[DATA_REGISTRY] = EntityRegistry(hass)
yield from registry.async_ensure_loaded()
tasks = [
self._async_add_entity(entity, update_before_add,
component_entities, registry)
for entity in new_entities]
yield from asyncio.wait(tasks, loop=self.hass.loop)
self.async_entities_added_callback()
if self._async_unsub_polling is not None or \
not any(entity.should_poll for entity
in self.entities.values()):
return
self._async_unsub_polling = async_track_time_interval(
self.hass, self._update_entity_states, self.scan_interval
)
@asyncio.coroutine
def _async_add_entity(self, entity, update_before_add, component_entities,
registry):
"""Helper method to add an entity to the platform."""
if entity is None:
raise ValueError('Entity cannot be None')
entity.hass = self.hass
entity.platform = self
entity.parallel_updates = self.parallel_updates
# Update properties before we generate the entity_id
if update_before_add:
try:
yield from entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.logger.exception(
"%s: Error on device update!", self.platform_name)
return
suggested_object_id = None
# Get entity_id from unique ID registration
if entity.unique_id is not None:
if entity.entity_id is not None:
suggested_object_id = split_entity_id(entity.entity_id)[1]
else:
suggested_object_id = entity.name
entry = registry.async_get_or_create(
self.domain, self.platform_name, entity.unique_id,
suggested_object_id=suggested_object_id)
entity.entity_id = entry.entity_id
# We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID
elif (entity.entity_id is not None and
registry.async_is_registered(entity.entity_id)):
# If entity already registered, convert entity id to suggestion
suggested_object_id = split_entity_id(entity.entity_id)[1]
entity.entity_id = None
# Generate entity ID
if entity.entity_id is None:
suggested_object_id = \
suggested_object_id or entity.name or DEVICE_DEFAULT_NAME
if self.entity_namespace is not None:
suggested_object_id = '{} {}'.format(self.entity_namespace,
suggested_object_id)
entity.entity_id = registry.async_generate_entity_id(
self.domain, suggested_object_id)
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
raise HomeAssistantError(
'Invalid entity id: {}'.format(entity.entity_id))
elif entity.entity_id in component_entities:
raise HomeAssistantError(
'Entity id already exists: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
component_entities.add(entity.entity_id)
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
@asyncio.coroutine
def async_reset(self):
"""Remove all entities and reset data.
This method must be run in the event loop.
"""
if not self.entities:
return
tasks = [self._async_remove_entity(entity_id)
for entity_id in self.entities]
yield from asyncio.wait(tasks, loop=self.hass.loop)
if self._async_unsub_polling is not None:
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
yield from self._async_remove_entity(entity_id)
# Clean up polling job if no longer needed
if (self._async_unsub_polling is not None and
not any(entity.should_poll for entity
in self.entities.values())):
self._async_unsub_polling()
self._async_unsub_polling = None
@asyncio.coroutine
def _async_remove_entity(self, entity_id):
"""Remove entity id from platform."""
entity = self.entities.pop(entity_id)
if hasattr(entity, 'async_will_remove_from_hass'):
yield from entity.async_will_remove_from_hass()
self.hass.states.async_remove(entity_id)
@asyncio.coroutine
def _update_entity_states(self, now):
"""Update the states of all the polling entities.
To protect from flooding the executor, we will update async entities
in parallel and other entities sequential.
This method must be run in the event loop.
"""
if self._process_updates.locked():
self.logger.warning(
"Updating %s %s took longer than the scheduled update "
"interval %s", self.platform_name, self.domain,
self.scan_interval)
return
with (yield from self._process_updates):
tasks = []
for entity in self.entities.values():
if not entity.should_poll:
continue
tasks.append(entity.async_update_ha_state(True))
if tasks:
yield from asyncio.wait(tasks, loop=self.hass.loop)

View file

@ -14,7 +14,9 @@ from aiohttp import web
from homeassistant import core as ha, loader
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
from homeassistant.helpers import intent, dispatcher, entity, restore_state
from homeassistant.helpers import (
intent, dispatcher, entity, restore_state, entity_registry,
entity_platform)
from homeassistant.util.unit_system import METRIC_SYSTEM
import homeassistant.util.dt as date_util
import homeassistant.util.yaml as yaml
@ -22,7 +24,6 @@ from homeassistant.const import (
STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED,
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE)
from homeassistant.helpers import entity_component, entity_registry
from homeassistant.components import mqtt, recorder
from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import (
@ -320,7 +321,7 @@ def mock_registry(hass):
"""Mock the Entity Registry."""
registry = entity_registry.EntityRegistry(hass)
registry.entities = {}
hass.data[entity_component.DATA_REGISTRY] = registry
hass.data[entity_platform.DATA_REGISTRY] = registry
return registry
@ -585,3 +586,40 @@ class MockDependency:
func(*args, **kwargs)
return run_mocked
class MockEntity(entity.Entity):
"""Mock Entity class."""
def __init__(self, **values):
"""Initialize an entity."""
self._values = values
if 'entity_id' in values:
self.entity_id = values['entity_id']
@property
def name(self):
"""Return the name of the entity."""
return self._handle('name')
@property
def should_poll(self):
"""Return the ste of the polling."""
return self._handle('should_poll')
@property
def unique_id(self):
"""Return the unique ID of the entity."""
return self._handle('unique_id')
@property
def available(self):
"""Return True if entity is available."""
return self._handle('available')
def _handle(self, attr):
"""Helper for the attributes."""
if attr in self._values:
return self._values[attr]
return getattr(super(), attr)

View file

@ -4,67 +4,27 @@ import asyncio
from collections import OrderedDict
import logging
import unittest
from unittest.mock import patch, Mock, MagicMock
from unittest.mock import patch, Mock
from datetime import timedelta
import homeassistant.core as ha
import homeassistant.loader as loader
from homeassistant.exceptions import PlatformNotReady
from homeassistant.components import group
from homeassistant.helpers.entity import Entity, generate_entity_id
from homeassistant.helpers.entity_component import (
EntityComponent, DEFAULT_SCAN_INTERVAL, SLOW_SETUP_WARNING)
from homeassistant.helpers import entity_component
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import setup_component
from homeassistant.helpers import discovery
import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, MockModule, fire_time_changed,
mock_coro, async_fire_time_changed, mock_registry)
get_test_home_assistant, MockPlatform, MockModule, mock_coro,
async_fire_time_changed, MockEntity)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
class EntityTest(Entity):
"""Test for the Entity component."""
def __init__(self, **values):
"""Initialize an entity."""
self._values = values
if 'entity_id' in values:
self.entity_id = values['entity_id']
@property
def name(self):
"""Return the name of the entity."""
return self._handle('name')
@property
def should_poll(self):
"""Return the ste of the polling."""
return self._handle('should_poll')
@property
def unique_id(self):
"""Return the unique ID of the entity."""
return self._handle('unique_id')
@property
def available(self):
"""Return True if entity is available."""
return self._handle('available')
def _handle(self, attr):
"""Helper for the attributes."""
if attr in self._values:
return self._values[attr]
return getattr(super(), attr)
class TestHelpersEntityComponent(unittest.TestCase):
"""Test homeassistant.helpers.entity_component module."""
@ -85,7 +45,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
# No group after setup
assert len(self.hass.states.entity_ids()) == 0
component.add_entities([EntityTest()])
component.add_entities([MockEntity()])
self.hass.block_till_done()
# group exists
@ -98,7 +58,7 @@ class TestHelpersEntityComponent(unittest.TestCase):
('test_domain.unnamed_device',)
# group extended
component.add_entities([EntityTest(name='goodbye')])
component.add_entities([MockEntity(name='goodbye')])
self.hass.block_till_done()
assert len(self.hass.states.entity_ids()) == 3
@ -108,151 +68,6 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert group.attributes.get('entity_id') == \
('test_domain.goodbye', 'test_domain.unnamed_device')
def test_polling_only_updates_entities_it_should_poll(self):
"""Test the polling of only updated entities."""
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
no_poll_ent = EntityTest(should_poll=False)
no_poll_ent.async_update = Mock()
poll_ent = EntityTest(should_poll=True)
poll_ent.async_update = Mock()
component.add_entities([no_poll_ent, poll_ent])
no_poll_ent.async_update.reset_mock()
poll_ent.async_update.reset_mock()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done()
assert not no_poll_ent.async_update.called
assert poll_ent.async_update.called
def test_polling_updates_entities_with_exception(self):
"""Test the updated entities that not break with an exception."""
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
update_ok = []
update_err = []
def update_mock():
"""Mock normal update."""
update_ok.append(None)
def update_mock_err():
"""Mock error update."""
update_err.append(None)
raise AssertionError("Fake error update")
ent1 = EntityTest(should_poll=True)
ent1.update = update_mock_err
ent2 = EntityTest(should_poll=True)
ent2.update = update_mock
ent3 = EntityTest(should_poll=True)
ent3.update = update_mock
ent4 = EntityTest(should_poll=True)
ent4.update = update_mock
component.add_entities([ent1, ent2, ent3, ent4])
update_ok.clear()
update_err.clear()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done()
assert len(update_ok) == 3
assert len(update_err) == 1
def test_update_state_adds_entities(self):
"""Test if updating poll entities cause an entity to be added works."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent1 = EntityTest()
ent2 = EntityTest(should_poll=True)
component.add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids())
ent2.update = lambda *_: component.add_entities([ent1])
fire_time_changed(
self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL
)
self.hass.block_till_done()
assert 2 == len(self.hass.states.entity_ids())
def test_update_state_adds_entities_with_update_before_add_true(self):
"""Test if call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = EntityTest()
ent.update = Mock(spec_set=True)
component.add_entities([ent], True)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert ent.update.called
def test_update_state_adds_entities_with_update_before_add_false(self):
"""Test if not call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = EntityTest()
ent.update = Mock(spec_set=True)
component.add_entities([ent], False)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert not ent.update.called
def test_extract_from_service_returns_all_if_no_entity_id(self):
"""Test the extraction of everything from service."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service')
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.extract_from_service(call))
def test_extract_from_service_filter_out_non_existing_entities(self):
"""Test the extraction of non existing entities from service."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2'),
])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['test_domain.test_2', 'test_domain.non_exist']
})
assert ['test_domain.test_2'] == \
[ent.entity_id for ent in component.extract_from_service(call)]
def test_extract_from_service_no_group_expand(self):
"""Test not expanding a group."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
test_group = group.Group.create_group(
self.hass, 'test_group', ['light.Ceiling', 'light.Kitchen'])
component.add_entities([test_group])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['group.test_group']
})
extracted = component.extract_from_service(call, expand_group=False)
self.assertEqual([test_group], extracted)
def test_setup_loads_platforms(self):
"""Test the loading of the platforms."""
component_setup = Mock(return_value=True)
@ -320,13 +135,13 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert ('platform_test', {}, {'msg': 'discovery_info'}) == \
mock_setup.call_args[0]
@patch('homeassistant.helpers.entity_component.'
@patch('homeassistant.helpers.entity_platform.'
'async_track_time_interval')
def test_set_scan_interval_via_config(self, mock_track):
"""Test the setting of the scan interval via configuration."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([EntityTest(should_poll=True)])
add_devices([MockEntity(should_poll=True)])
loader.set_component('test_domain.platform',
MockPlatform(platform_setup))
@ -344,38 +159,13 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
@patch('homeassistant.helpers.entity_component.'
'async_track_time_interval')
def test_set_scan_interval_via_platform(self, mock_track):
"""Test the setting of the scan interval via platform."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([EntityTest(should_poll=True)])
platform = MockPlatform(platform_setup)
platform.SCAN_INTERVAL = timedelta(seconds=30)
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
}
})
self.hass.block_till_done()
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
def test_set_entity_namespace_via_config(self):
"""Test setting an entity namespace."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([
EntityTest(name='beer'),
EntityTest(name=None),
MockEntity(name='beer'),
MockEntity(name=None),
])
platform = MockPlatform(platform_setup)
@ -396,83 +186,16 @@ class TestHelpersEntityComponent(unittest.TestCase):
assert sorted(self.hass.states.entity_ids()) == \
['test_domain.yummy_beer', 'test_domain.yummy_unnamed_device']
def test_adding_entities_with_generator_and_thread_callback(self):
"""Test generator in add_entities that calls thread method.
We should make sure we resolve the generator to a list before passing
it into an async context.
"""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
def create_entity(number):
"""Create entity helper."""
entity = EntityTest()
entity.entity_id = generate_entity_id(component.entity_id_format,
'Number', hass=self.hass)
return entity
component.add_entities(create_entity(i) for i in range(2))
@asyncio.coroutine
def test_platform_warn_slow_setup(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
with patch.object(hass.loop, 'call_later', MagicMock()) \
as mock_call:
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
assert mock_call.called
timeout, logger_method = mock_call.mock_calls[0][1][:2]
assert timeout == SLOW_SETUP_WARNING
assert logger_method == _LOGGER.warning
assert mock_call().cancel.called
@asyncio.coroutine
def test_platform_error_slow_setup(hass, caplog):
"""Don't block startup more than SLOW_SETUP_MAX_WAIT."""
with patch.object(entity_component, 'SLOW_SETUP_MAX_WAIT', 0):
called = []
@asyncio.coroutine
def setup_platform(*args):
called.append(1)
yield from asyncio.sleep(1, loop=hass.loop)
platform = MockPlatform(async_setup_platform=setup_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
loader.set_component('test_domain.test_platform', platform)
yield from component.async_setup({
DOMAIN: {
'platform': 'test_platform',
}
})
assert len(called) == 1
assert 'test_domain.test_platform' not in hass.config.components
assert 'test_platform is taking longer than 0 seconds' in caplog.text
@asyncio.coroutine
def test_extract_from_service_available_device(hass):
"""Test the extraction of entity from service and device is available."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(name='test_1'),
EntityTest(name='test_2', available=False),
EntityTest(name='test_3'),
EntityTest(name='test_4', available=False),
MockEntity(name='test_1'),
MockEntity(name='test_2', available=False),
MockEntity(name='test_3'),
MockEntity(name='test_4', available=False),
])
call_1 = ha.ServiceCall('test', 'service')
@ -490,26 +213,6 @@ def test_extract_from_service_available_device(hass):
component.async_extract_from_service(call_2))
@asyncio.coroutine
def test_updated_state_used_for_entity_id(hass):
"""Test that first update results used for entity ID generation."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
class EntityTestNameFetcher(EntityTest):
"""Mock entity that fetches a friendly name."""
@asyncio.coroutine
def async_update(self):
"""Mock update that assigns a name."""
self._values['name'] = "Living Room"
yield from component.async_add_entities([EntityTestNameFetcher()], True)
entity_ids = hass.states.async_entity_ids()
assert 1 == len(entity_ids)
assert entity_ids[0] == "test_domain.living_room"
@asyncio.coroutine
def test_platform_not_ready(hass):
"""Test that we retry when platform not ready."""
@ -555,188 +258,50 @@ def test_platform_not_ready(hass):
@asyncio.coroutine
def test_parallel_updates_async_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
loader.set_component('test_domain.platform', platform)
def test_extract_from_service_returns_all_if_no_entity_id(hass):
"""Test the extraction of everything from service."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_add_entities([
MockEntity(name='test_1'),
MockEntity(name='test_2'),
])
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
call = ha.ServiceCall('test', 'service')
assert ['test_domain.test_1', 'test_domain.test_2'] == \
sorted(ent.entity_id for ent in
component.async_extract_from_service(call))
@asyncio.coroutine
def test_extract_from_service_filter_out_non_existing_entities(hass):
"""Test the extraction of non existing entities from service."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
MockEntity(name='test_1'),
MockEntity(name='test_2'),
])
call = ha.ServiceCall('test', 'service', {
'entity_id': ['test_domain.test_2', 'test_domain.non_exist']
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
assert ['test_domain.test_2'] == \
[ent.entity_id for ent
in component.async_extract_from_service(call)]
@asyncio.coroutine
def test_parallel_updates_async_platform_with_constant(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
platform.PARALLEL_UPDATES = 1
loader.set_component('test_domain.platform', platform)
def test_extract_from_service_no_group_expand(hass):
"""Test not expanding a group."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
test_group = yield from group.Group.async_create_group(
hass, 'test_group', ['light.Ceiling', 'light.Kitchen'])
yield from component.async_add_entities([test_group])
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
call = ha.ServiceCall('test', 'service', {
'entity_id': ['group.test_group']
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is not None
@asyncio.coroutine
def test_parallel_updates_sync_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is not None
@asyncio.coroutine
def test_raise_error_on_update(hass):
"""Test the add entity if they raise an error on update."""
updates = []
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = EntityTest(name='test_1')
entity2 = EntityTest(name='test_2')
def _raise():
"""Helper to raise an exception."""
raise AssertionError
entity1.update = _raise
entity2.update = lambda: updates.append(1)
yield from component.async_add_entities([entity1, entity2], True)
assert len(updates) == 1
assert 1 in updates
@asyncio.coroutine
def test_async_remove_with_platform(hass):
"""Remove an entity from a platform."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = EntityTest(name='test_1')
yield from component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
yield from entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0
@asyncio.coroutine
def test_not_adding_duplicate_entities_with_unique_id(hass):
"""Test for not adding duplicate entities."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(name='test1', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1
yield from component.async_add_entities([
EntityTest(name='test2', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1
@asyncio.coroutine
def test_using_prescribed_entity_id(hass):
"""Test for using predefined entity ID."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(name='bla', entity_id='hello.world')])
assert 'hello.world' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_using_prescribed_entity_id_with_unique_id(hass):
"""Test for ammending predefined entity ID because currently exists."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(entity_id='test_domain.world')])
yield from component.async_add_entities([
EntityTest(entity_id='test_domain.world', unique_id='bla')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_using_prescribed_entity_id_which_is_registered(hass):
"""Test not allowing predefined entity ID that already registered."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass)
# Register test_domain.world
registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world')
# This entity_id will be rewritten
yield from component.async_add_entities([
EntityTest(entity_id='test_domain.world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_name_which_conflict_with_registered(hass):
"""Test not generating conflicting entity ID based on name."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass)
# Register test_domain.world
registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world')
yield from component.async_add_entities([
EntityTest(name='world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_entity_with_name_and_entity_id_getting_registered(hass):
"""Ensure that entity ID is used for registration."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
EntityTest(unique_id='1234', name='bla',
entity_id='test_domain.world')])
assert 'test_domain.world' in hass.states.async_entity_ids()
extracted = component.async_extract_from_service(call, expand_group=False)
assert extracted == [test_group]

View file

@ -0,0 +1,435 @@
"""Tests for the EntityPlatform helper."""
import asyncio
import logging
import unittest
from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta
import homeassistant.loader as loader
from homeassistant.helpers.entity import generate_entity_id
from homeassistant.helpers.entity_component import (
EntityComponent, DEFAULT_SCAN_INTERVAL)
from homeassistant.helpers import entity_platform
import homeassistant.util.dt as dt_util
from tests.common import (
get_test_home_assistant, MockPlatform, fire_time_changed, mock_registry,
MockEntity)
_LOGGER = logging.getLogger(__name__)
DOMAIN = "test_domain"
class TestHelpersEntityPlatform(unittest.TestCase):
"""Test homeassistant.helpers.entity_component module."""
def setUp(self): # pylint: disable=invalid-name
"""Initialize a test Home Assistant instance."""
self.hass = get_test_home_assistant()
def tearDown(self): # pylint: disable=invalid-name
"""Clean up the test Home Assistant instance."""
self.hass.stop()
def test_polling_only_updates_entities_it_should_poll(self):
"""Test the polling of only updated entities."""
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
no_poll_ent = MockEntity(should_poll=False)
no_poll_ent.async_update = Mock()
poll_ent = MockEntity(should_poll=True)
poll_ent.async_update = Mock()
component.add_entities([no_poll_ent, poll_ent])
no_poll_ent.async_update.reset_mock()
poll_ent.async_update.reset_mock()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done()
assert not no_poll_ent.async_update.called
assert poll_ent.async_update.called
def test_polling_updates_entities_with_exception(self):
"""Test the updated entities that not break with an exception."""
component = EntityComponent(
_LOGGER, DOMAIN, self.hass, timedelta(seconds=20))
update_ok = []
update_err = []
def update_mock():
"""Mock normal update."""
update_ok.append(None)
def update_mock_err():
"""Mock error update."""
update_err.append(None)
raise AssertionError("Fake error update")
ent1 = MockEntity(should_poll=True)
ent1.update = update_mock_err
ent2 = MockEntity(should_poll=True)
ent2.update = update_mock
ent3 = MockEntity(should_poll=True)
ent3.update = update_mock
ent4 = MockEntity(should_poll=True)
ent4.update = update_mock
component.add_entities([ent1, ent2, ent3, ent4])
update_ok.clear()
update_err.clear()
fire_time_changed(self.hass, dt_util.utcnow() + timedelta(seconds=20))
self.hass.block_till_done()
assert len(update_ok) == 3
assert len(update_err) == 1
def test_update_state_adds_entities(self):
"""Test if updating poll entities cause an entity to be added works."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent1 = MockEntity()
ent2 = MockEntity(should_poll=True)
component.add_entities([ent2])
assert 1 == len(self.hass.states.entity_ids())
ent2.update = lambda *_: component.add_entities([ent1])
fire_time_changed(
self.hass, dt_util.utcnow() + DEFAULT_SCAN_INTERVAL
)
self.hass.block_till_done()
assert 2 == len(self.hass.states.entity_ids())
def test_update_state_adds_entities_with_update_before_add_true(self):
"""Test if call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = MockEntity()
ent.update = Mock(spec_set=True)
component.add_entities([ent], True)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert ent.update.called
def test_update_state_adds_entities_with_update_before_add_false(self):
"""Test if not call update before add to state machine."""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
ent = MockEntity()
ent.update = Mock(spec_set=True)
component.add_entities([ent], False)
self.hass.block_till_done()
assert 1 == len(self.hass.states.entity_ids())
assert not ent.update.called
@patch('homeassistant.helpers.entity_platform.'
'async_track_time_interval')
def test_set_scan_interval_via_platform(self, mock_track):
"""Test the setting of the scan interval via platform."""
def platform_setup(hass, config, add_devices, discovery_info=None):
"""Test the platform setup."""
add_devices([MockEntity(should_poll=True)])
platform = MockPlatform(platform_setup)
platform.SCAN_INTERVAL = timedelta(seconds=30)
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
component.setup({
DOMAIN: {
'platform': 'platform',
}
})
self.hass.block_till_done()
assert mock_track.called
assert timedelta(seconds=30) == mock_track.call_args[0][2]
def test_adding_entities_with_generator_and_thread_callback(self):
"""Test generator in add_entities that calls thread method.
We should make sure we resolve the generator to a list before passing
it into an async context.
"""
component = EntityComponent(_LOGGER, DOMAIN, self.hass)
def create_entity(number):
"""Create entity helper."""
entity = MockEntity()
entity.entity_id = generate_entity_id(DOMAIN + '.{}',
'Number', hass=self.hass)
return entity
component.add_entities(create_entity(i) for i in range(2))
@asyncio.coroutine
def test_platform_warn_slow_setup(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
with patch.object(hass.loop, 'call_later', MagicMock()) \
as mock_call:
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
assert mock_call.called
timeout, logger_method = mock_call.mock_calls[0][1][:2]
assert timeout == entity_platform.SLOW_SETUP_WARNING
assert logger_method == _LOGGER.warning
assert mock_call().cancel.called
@asyncio.coroutine
def test_platform_error_slow_setup(hass, caplog):
"""Don't block startup more than SLOW_SETUP_MAX_WAIT."""
with patch.object(entity_platform, 'SLOW_SETUP_MAX_WAIT', 0):
called = []
@asyncio.coroutine
def setup_platform(*args):
called.append(1)
yield from asyncio.sleep(1, loop=hass.loop)
platform = MockPlatform(async_setup_platform=setup_platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
loader.set_component('test_domain.test_platform', platform)
yield from component.async_setup({
DOMAIN: {
'platform': 'test_platform',
}
})
assert len(called) == 1
assert 'test_domain.test_platform' not in hass.config.components
assert 'test_platform is taking longer than 0 seconds' in caplog.text
@asyncio.coroutine
def test_updated_state_used_for_entity_id(hass):
"""Test that first update results used for entity ID generation."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
class MockEntityNameFetcher(MockEntity):
"""Mock entity that fetches a friendly name."""
@asyncio.coroutine
def async_update(self):
"""Mock update that assigns a name."""
self._values['name'] = "Living Room"
yield from component.async_add_entities([MockEntityNameFetcher()], True)
entity_ids = hass.states.async_entity_ids()
assert 1 == len(entity_ids)
assert entity_ids[0] == "test_domain.living_room"
@asyncio.coroutine
def test_parallel_updates_async_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
@asyncio.coroutine
def test_parallel_updates_async_platform_with_constant(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
platform.PARALLEL_UPDATES = 1
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is not None
@asyncio.coroutine
def test_parallel_updates_sync_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform()
loader.set_component('test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is not None
@asyncio.coroutine
def test_raise_error_on_update(hass):
"""Test the add entity if they raise an error on update."""
updates = []
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = MockEntity(name='test_1')
entity2 = MockEntity(name='test_2')
def _raise():
"""Helper to raise an exception."""
raise AssertionError
entity1.update = _raise
entity2.update = lambda: updates.append(1)
yield from component.async_add_entities([entity1, entity2], True)
assert len(updates) == 1
assert 1 in updates
@asyncio.coroutine
def test_async_remove_with_platform(hass):
"""Remove an entity from a platform."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
entity1 = MockEntity(name='test_1')
yield from component.async_add_entities([entity1])
assert len(hass.states.async_entity_ids()) == 1
yield from entity1.async_remove()
assert len(hass.states.async_entity_ids()) == 0
@asyncio.coroutine
def test_not_adding_duplicate_entities_with_unique_id(hass):
"""Test for not adding duplicate entities."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
MockEntity(name='test1', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1
yield from component.async_add_entities([
MockEntity(name='test2', unique_id='not_very_unique')])
assert len(hass.states.async_entity_ids()) == 1
@asyncio.coroutine
def test_using_prescribed_entity_id(hass):
"""Test for using predefined entity ID."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
MockEntity(name='bla', entity_id='hello.world')])
assert 'hello.world' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_using_prescribed_entity_id_with_unique_id(hass):
"""Test for ammending predefined entity ID because currently exists."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
MockEntity(entity_id='test_domain.world')])
yield from component.async_add_entities([
MockEntity(entity_id='test_domain.world', unique_id='bla')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_using_prescribed_entity_id_which_is_registered(hass):
"""Test not allowing predefined entity ID that already registered."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass)
# Register test_domain.world
registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world')
# This entity_id will be rewritten
yield from component.async_add_entities([
MockEntity(entity_id='test_domain.world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_name_which_conflict_with_registered(hass):
"""Test not generating conflicting entity ID based on name."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
registry = mock_registry(hass)
# Register test_domain.world
registry.async_get_or_create(
DOMAIN, 'test', '1234', suggested_object_id='world')
yield from component.async_add_entities([
MockEntity(name='world')])
assert 'test_domain.world_2' in hass.states.async_entity_ids()
@asyncio.coroutine
def test_entity_with_name_and_entity_id_getting_registered(hass):
"""Ensure that entity ID is used for registration."""
component = EntityComponent(_LOGGER, DOMAIN, hass)
yield from component.async_add_entities([
MockEntity(unique_id='1234', name='bla',
entity_id='test_domain.world')])
assert 'test_domain.world' in hass.states.async_entity_ids()

View file

@ -255,18 +255,6 @@ class TestConfig(unittest.TestCase):
return self.hass.states.get('test.test')
def test_entity_customization_false(self):
"""Test entity customization through configuration."""
config = {CONF_LATITUDE: 50,
CONF_LONGITUDE: 50,
CONF_NAME: 'Test',
CONF_CUSTOMIZE: {
'test.test': {'hidden': False}}}
state = self._compute_state(config)
assert 'hidden' not in state.attributes
def test_entity_customization(self):
"""Test entity customization through configuration."""
config = {CONF_LATITUDE: 50,