From 40d7857f3b7a8b2b1522e4d1a7f59f7ac3617b06 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sat, 7 Apr 2018 23:04:50 -0400 Subject: [PATCH] Prepare entity component for config entries (#13730) * Prepare entity component for config entries * Return in time --- homeassistant/helpers/entity_component.py | 55 ++++++++++------------- homeassistant/helpers/entity_platform.py | 30 ++++++++++--- tests/common.py | 4 +- 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index f086437c10dc..6ff9b6f65713 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -40,16 +40,7 @@ class EntityComponent(object): self.config = None self._platforms = { - domain: EntityPlatform( - hass=hass, - logger=logger, - domain=domain, - platform_name=domain, - scan_interval=self.scan_interval, - parallel_updates=0, - entity_namespace=None, - async_entities_added_callback=self._async_update_group, - ) + domain: self._async_init_entity_platform(domain, None) } self.async_add_entities = self._platforms[domain].async_add_entities self.add_entities = self._platforms[domain].add_entities @@ -127,34 +118,19 @@ class EntityComponent(object): if platform is None: return - # Config > Platform > Component - scan_interval = ( - platform_config.get(CONF_SCAN_INTERVAL) or - getattr(platform, 'SCAN_INTERVAL', None) or self.scan_interval) - parallel_updates = getattr( - platform, 'PARALLEL_UPDATES', - int(not hasattr(platform, 'async_setup_platform'))) - + # Use config scan interval, fallback to platform if none set + scan_interval = platform_config.get( + CONF_SCAN_INTERVAL, getattr(platform, 'SCAN_INTERVAL', None)) entity_namespace = platform_config.get(CONF_ENTITY_NAMESPACE) key = (platform_type, scan_interval, entity_namespace) if key not in self._platforms: - entity_platform = self._platforms[key] = EntityPlatform( - hass=self.hass, - logger=self.logger, - domain=self.domain, - platform_name=platform_type, - scan_interval=scan_interval, - parallel_updates=parallel_updates, - entity_namespace=entity_namespace, - async_entities_added_callback=self._async_update_group, + self._platforms[key] = self._async_init_entity_platform( + platform_type, platform, scan_interval, entity_namespace ) - else: - entity_platform = self._platforms[key] - await entity_platform.async_setup( - platform, platform_config, discovery_info) + await self._platforms[key].async_setup(platform_config, discovery_info) @callback def _async_update_group(self): @@ -219,3 +195,20 @@ class EntityComponent(object): await self._async_reset() return conf + + def _async_init_entity_platform(self, platform_type, platform, + scan_interval=None, entity_namespace=None): + """Helper to initialize an entity platform.""" + if scan_interval is None: + scan_interval = self.scan_interval + + return EntityPlatform( + hass=self.hass, + logger=self.logger, + domain=self.domain, + platform_name=platform_type, + platform=platform, + scan_interval=scan_interval, + entity_namespace=entity_namespace, + async_entities_added_callback=self._async_update_group, + ) diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 501ab5057a36..3c6deaba94af 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -20,8 +20,8 @@ PLATFORM_NOT_READY_RETRIES = 10 class EntityPlatform(object): """Manage the entities for a single platform.""" - def __init__(self, *, hass, logger, domain, platform_name, scan_interval, - parallel_updates, entity_namespace, + def __init__(self, *, hass, logger, domain, platform_name, platform, + scan_interval, entity_namespace, async_entities_added_callback): """Initialize the entity platform. @@ -38,8 +38,8 @@ class EntityPlatform(object): self.logger = logger self.domain = domain self.platform_name = platform_name + self.platform = platform self.scan_interval = scan_interval - self.parallel_updates = None self.entity_namespace = entity_namespace self.async_entities_added_callback = async_entities_added_callback self.entities = {} @@ -47,13 +47,30 @@ class EntityPlatform(object): self._async_unsub_polling = None self._process_updates = asyncio.Lock(loop=hass.loop) + # Platform is None for the EntityComponent "catch-all" EntityPlatform + # which powers entity_component.add_entities + if platform is None: + self.parallel_updates = None + return + + # Async platforms do all updates in parallel by default + if hasattr(platform, 'async_setup_platform'): + default_parallel_updates = 0 + else: + default_parallel_updates = 1 + + parallel_updates = getattr(platform, 'PARALLEL_UPDATES', + default_parallel_updates) + if parallel_updates: self.parallel_updates = asyncio.Semaphore( parallel_updates, loop=hass.loop) + else: + self.parallel_updates = None - async def async_setup(self, platform, platform_config, discovery_info=None, - tries=0): + async def async_setup(self, platform_config, discovery_info=None, tries=0): """Setup the platform.""" + platform = self.platform logger = self.logger hass = self.hass full_name = '{}.{}'.format(self.domain, self.platform_name) @@ -98,8 +115,7 @@ class EntityPlatform(object): 'Platform %s not ready yet. Retrying in %d seconds.', self.platform_name, wait_time) async_track_point_in_time( - hass, self.async_setup( - platform, platform_config, discovery_info, tries), + hass, self.async_setup(platform_config, discovery_info, tries), dt_util.utcnow() + timedelta(seconds=wait_time)) except asyncio.TimeoutError: logger.error( diff --git a/tests/common.py b/tests/common.py index bc84b3493a8e..388898e70243 100644 --- a/tests/common.py +++ b/tests/common.py @@ -370,8 +370,8 @@ class MockEntityPlatform(entity_platform.EntityPlatform): logger=None, domain='test_domain', platform_name='test_platform', + platform=None, scan_interval=timedelta(seconds=15), - parallel_updates=0, entity_namespace=None, async_entities_added_callback=lambda: None ): @@ -381,8 +381,8 @@ class MockEntityPlatform(entity_platform.EntityPlatform): logger=logger, domain=domain, platform_name=platform_name, + platform=platform, scan_interval=scan_interval, - parallel_updates=parallel_updates, entity_namespace=entity_namespace, async_entities_added_callback=async_entities_added_callback, )