Prepare entity component for config entries (#13730)

* Prepare entity component for config entries

* Return in time
This commit is contained in:
Paulus Schoutsen 2018-04-07 23:04:50 -04:00 committed by GitHub
parent 81b1d08d35
commit 40d7857f3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 40 deletions

View file

@ -40,16 +40,7 @@ class EntityComponent(object):
self.config = None self.config = None
self._platforms = { self._platforms = {
domain: EntityPlatform( domain: self._async_init_entity_platform(domain, None)
hass=hass,
logger=logger,
domain=domain,
platform_name=domain,
scan_interval=self.scan_interval,
parallel_updates=0,
entity_namespace=None,
async_entities_added_callback=self._async_update_group,
)
} }
self.async_add_entities = self._platforms[domain].async_add_entities self.async_add_entities = self._platforms[domain].async_add_entities
self.add_entities = self._platforms[domain].add_entities self.add_entities = self._platforms[domain].add_entities
@ -127,34 +118,19 @@ class EntityComponent(object):
if platform is None: if platform is None:
return return
# Config > Platform > Component # Use config scan interval, fallback to platform if none set
scan_interval = ( scan_interval = platform_config.get(
platform_config.get(CONF_SCAN_INTERVAL) or CONF_SCAN_INTERVAL, getattr(platform, 'SCAN_INTERVAL', None))
getattr(platform, 'SCAN_INTERVAL', None) or self.scan_interval)
parallel_updates = getattr(
platform, 'PARALLEL_UPDATES',
int(not hasattr(platform, 'async_setup_platform')))
entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE) entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE)
key = (platform_type, scan_interval, entity_namespace) key = (platform_type, scan_interval, entity_namespace)
if key not in self._platforms: if key not in self._platforms:
entity_platform = self._platforms[key] = EntityPlatform( self._platforms[key] = self._async_init_entity_platform(
hass=self.hass, platform_type, platform, scan_interval, entity_namespace
logger=self.logger,
domain=self.domain,
platform_name=platform_type,
scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace,
async_entities_added_callback=self._async_update_group,
) )
else:
entity_platform = self._platforms[key]
await entity_platform.async_setup( await self._platforms[key].async_setup(platform_config, discovery_info)
platform, platform_config, discovery_info)
@callback @callback
def _async_update_group(self): def _async_update_group(self):
@ -219,3 +195,20 @@ class EntityComponent(object):
await self._async_reset() await self._async_reset()
return conf return conf
def _async_init_entity_platform(self, platform_type, platform,
scan_interval=None, entity_namespace=None):
"""Helper to initialize an entity platform."""
if scan_interval is None:
scan_interval = self.scan_interval
return EntityPlatform(
hass=self.hass,
logger=self.logger,
domain=self.domain,
platform_name=platform_type,
platform=platform,
scan_interval=scan_interval,
entity_namespace=entity_namespace,
async_entities_added_callback=self._async_update_group,
)

View file

@ -20,8 +20,8 @@ PLATFORM_NOT_READY_RETRIES = 10
class EntityPlatform(object): class EntityPlatform(object):
"""Manage the entities for a single platform.""" """Manage the entities for a single platform."""
def __init__(self, *, hass, logger, domain, platform_name, scan_interval, def __init__(self, *, hass, logger, domain, platform_name, platform,
parallel_updates, entity_namespace, scan_interval, entity_namespace,
async_entities_added_callback): async_entities_added_callback):
"""Initialize the entity platform. """Initialize the entity platform.
@ -38,8 +38,8 @@ class EntityPlatform(object):
self.logger = logger self.logger = logger
self.domain = domain self.domain = domain
self.platform_name = platform_name self.platform_name = platform_name
self.platform = platform
self.scan_interval = scan_interval self.scan_interval = scan_interval
self.parallel_updates = None
self.entity_namespace = entity_namespace self.entity_namespace = entity_namespace
self.async_entities_added_callback = async_entities_added_callback self.async_entities_added_callback = async_entities_added_callback
self.entities = {} self.entities = {}
@ -47,13 +47,30 @@ class EntityPlatform(object):
self._async_unsub_polling = None self._async_unsub_polling = None
self._process_updates = asyncio.Lock(loop=hass.loop) self._process_updates = asyncio.Lock(loop=hass.loop)
# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
if platform is None:
self.parallel_updates = None
return
# Async platforms do all updates in parallel by default
if hasattr(platform, 'async_setup_platform'):
default_parallel_updates = 0
else:
default_parallel_updates = 1
parallel_updates = getattr(platform, 'PARALLEL_UPDATES',
default_parallel_updates)
if parallel_updates: if parallel_updates:
self.parallel_updates = asyncio.Semaphore( self.parallel_updates = asyncio.Semaphore(
parallel_updates, loop=hass.loop) parallel_updates, loop=hass.loop)
else:
self.parallel_updates = None
async def async_setup(self, platform, platform_config, discovery_info=None, async def async_setup(self, platform_config, discovery_info=None, tries=0):
tries=0):
"""Setup the platform.""" """Setup the platform."""
platform = self.platform
logger = self.logger logger = self.logger
hass = self.hass hass = self.hass
full_name = '{}.{}'.format(self.domain, self.platform_name) full_name = '{}.{}'.format(self.domain, self.platform_name)
@ -98,8 +115,7 @@ class EntityPlatform(object):
'Platform %s not ready yet. Retrying in %d seconds.', 'Platform %s not ready yet. Retrying in %d seconds.',
self.platform_name, wait_time) self.platform_name, wait_time)
async_track_point_in_time( async_track_point_in_time(
hass, self.async_setup( hass, self.async_setup(platform_config, discovery_info, tries),
platform, platform_config, discovery_info, tries),
dt_util.utcnow() + timedelta(seconds=wait_time)) dt_util.utcnow() + timedelta(seconds=wait_time))
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error( logger.error(

View file

@ -370,8 +370,8 @@ class MockEntityPlatform(entity_platform.EntityPlatform):
logger=None, logger=None,
domain='test_domain', domain='test_domain',
platform_name='test_platform', platform_name='test_platform',
platform=None,
scan_interval=timedelta(seconds=15), scan_interval=timedelta(seconds=15),
parallel_updates=0,
entity_namespace=None, entity_namespace=None,
async_entities_added_callback=lambda: None async_entities_added_callback=lambda: None
): ):
@ -381,8 +381,8 @@ class MockEntityPlatform(entity_platform.EntityPlatform):
logger=logger, logger=logger,
domain=domain, domain=domain,
platform_name=platform_name, platform_name=platform_name,
platform=platform,
scan_interval=scan_interval, scan_interval=scan_interval,
parallel_updates=parallel_updates,
entity_namespace=entity_namespace, entity_namespace=entity_namespace,
async_entities_added_callback=async_entities_added_callback, async_entities_added_callback=async_entities_added_callback,
) )