Reorg auth (#15443)

This commit is contained in:
Paulus Schoutsen 2018-07-13 11:43:08 +02:00 committed by GitHub
parent 23f1b49e55
commit b6ca03ce47
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 698 additions and 652 deletions

View file

@ -1,613 +0,0 @@
"""Provide an authentication layer for Home Assistant."""
import asyncio
import binascii
import importlib
import logging
import os
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta
import attr
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.core import callback
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
_LOGGER = logging.getLogger(__name__)
STORAGE_VERSION = 1
STORAGE_KEY = 'auth'
AUTH_PROVIDERS = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema({
vol.Required(CONF_TYPE): str,
vol.Optional(CONF_NAME): str,
# Specify ID if you have two auth providers for same type.
vol.Optional(CONF_ID): str,
}, extra=vol.ALLOW_EXTRA)
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
DATA_REQS = 'auth_reqs_processed'
def generate_secret(entropy: int = 32) -> str:
"""Generate a secret.
Backport of secrets.token_hex from Python 3.6
Event loop friendly.
"""
return binascii.hexlify(os.urandom(entropy)).decode('ascii')
class AuthProvider:
"""Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider'
initialized = False
def __init__(self, hass, store, config):
"""Initialize an auth provider."""
self.hass = hass
self.store = store
self.config = config
@property
def id(self): # pylint: disable=invalid-name
"""Return id of the auth provider.
Optional, can be None.
"""
return self.config.get(CONF_ID)
@property
def type(self):
"""Return type of the provider."""
return self.config[CONF_TYPE]
@property
def name(self):
"""Return the name of the auth provider."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
async def async_credentials(self):
"""Return all credentials of this provider."""
users = await self.store.async_get_users()
return [
credentials
for user in users
for credentials in user.credentials
if (credentials.auth_provider_type == self.type and
credentials.auth_provider_id == self.id)
]
@callback
def async_create_credentials(self, data):
"""Create credentials."""
return Credentials(
auth_provider_type=self.type,
auth_provider_id=self.id,
data=data,
)
# Implement by extending class
async def async_initialize(self):
"""Initialize the auth provider.
Optional.
"""
async def async_credential_flow(self):
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError
async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result."""
raise NotImplementedError
async def async_user_meta_for_credentials(self, credentials):
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
return {}
@attr.s(slots=True)
class User:
"""A user."""
name = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
system_generated = attr.ib(type=bool, default=False)
# List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
# Tokens associated with a user.
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
@attr.s(slots=True)
class RefreshToken:
"""RefreshToken for a user to grant new access tokens."""
user = attr.ib(type=User)
client_id = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
access_token_expiration = attr.ib(type=timedelta,
default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
@attr.s(slots=True)
class AccessToken:
"""Access token to access the API.
These will only ever be stored in memory and not be persisted.
"""
refresh_token = attr.ib(type=RefreshToken)
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
token = attr.ib(type=str,
default=attr.Factory(generate_secret))
@property
def expired(self):
"""Return if this token has expired."""
expires = self.created_at + self.refresh_token.access_token_expiration
return dt_util.utcnow() > expires
@attr.s(slots=True)
class Credentials:
"""Credentials for a user on an auth provider."""
auth_provider_type = attr.ib(type=str)
auth_provider_id = attr.ib(type=str)
# Allow the auth provider to store data to represent their auth.
data = attr.ib(type=dict)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_new = attr.ib(type=bool, default=True)
async def load_auth_provider_module(hass, provider):
"""Load an auth provider."""
try:
module = importlib.import_module(
'homeassistant.auth_providers.{}'.format(provider))
except ImportError:
_LOGGER.warning('Unable to find auth provider %s', provider)
return None
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
return module
processed = hass.data.get(DATA_REQS)
if processed is None:
processed = hass.data[DATA_REQS] = set()
elif provider in processed:
return module
req_success = await requirements.async_process_requirements(
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
if not req_success:
return None
return module
async def auth_manager_from_config(hass, provider_configs):
"""Initialize an auth manager from config."""
store = AuthStore(hass)
if provider_configs:
providers = await asyncio.gather(
*[_auth_provider_from_config(hass, store, config)
for config in provider_configs])
else:
providers = []
# So returned auth providers are in same order as config
provider_hash = OrderedDict()
for provider in providers:
if provider is None:
continue
key = (provider.type, provider.id)
if key in provider_hash:
_LOGGER.error(
'Found duplicate provider: %s. Please add unique IDs if you '
'want to have the same provider twice.', key)
continue
provider_hash[key] = provider
manager = AuthManager(hass, store, provider_hash)
return manager
async def _auth_provider_from_config(hass, store, config):
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name)
if module is None:
return None
try:
config = module.CONFIG_SCHEMA(config)
except vol.Invalid as err:
_LOGGER.error('Invalid configuration for auth provider %s: %s',
provider_name, humanize_error(config, err))
return None
return AUTH_PROVIDERS[provider_name](hass, store, config)
class AuthManager:
"""Manage the authentication for Home Assistant."""
def __init__(self, hass, store, providers):
"""Initialize the auth manager."""
self._store = store
self._providers = providers
self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow,
self._async_finish_login_flow)
self._access_tokens = {}
@property
def active(self):
"""Return if any auth providers are registered."""
return bool(self._providers)
@property
def support_legacy(self):
"""
Return if legacy_api_password auth providers are registered.
Should be removed when we removed legacy_api_password auth providers.
"""
for provider_type, _ in self._providers:
if provider_type == 'legacy_api_password':
return True
return False
@property
def async_auth_providers(self):
"""Return a list of available auth providers."""
return self._providers.values()
async def async_get_user(self, user_id):
"""Retrieve a user."""
return await self._store.async_get_user(user_id)
async def async_create_system_user(self, name):
"""Create a system user."""
return await self._store.async_create_user(
name=name,
system_generated=True,
is_active=True,
)
async def async_get_or_create_user(self, credentials):
"""Get or create a user."""
if not credentials.is_new:
for user in await self._store.async_get_users():
for creds in user.credentials:
if creds.id == credentials.id:
return user
raise ValueError('Unable to find the user.')
auth_provider = self._async_get_auth_provider(credentials)
info = await auth_provider.async_user_meta_for_credentials(
credentials)
kwargs = {
'credentials': credentials,
'name': info.get('name')
}
# Make owner and activate user if it's the first user.
if await self._store.async_get_users():
kwargs['is_owner'] = False
kwargs['is_active'] = False
else:
kwargs['is_owner'] = True
kwargs['is_active'] = True
return await self._store.async_create_user(**kwargs)
async def async_link_user(self, user, credentials):
"""Link credentials to an existing user."""
await self._store.async_link_user(user, credentials)
async def async_remove_user(self, user):
"""Remove a user."""
await self._store.async_remove_user(user)
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new refresh token for a user."""
if not user.is_active:
raise ValueError('User is not active')
if user.system_generated and client_id is not None:
raise ValueError(
'System generated users cannot have refresh tokens connected '
'to a client.')
if not user.system_generated and client_id is None:
raise ValueError('Client is required to generate a refresh token.')
return await self._store.async_create_refresh_token(user, client_id)
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
return await self._store.async_get_refresh_token(token)
@callback
def async_create_access_token(self, refresh_token):
"""Create a new access token."""
access_token = AccessToken(refresh_token=refresh_token)
self._access_tokens[access_token.token] = access_token
return access_token
@callback
def async_get_access_token(self, token):
"""Get an access token."""
tkn = self._access_tokens.get(token)
if tkn is None:
return None
if tkn.expired:
self._access_tokens.pop(token)
return None
return tkn
async def _async_create_login_flow(self, handler, *, source, data):
"""Create a login flow."""
auth_provider = self._providers[handler]
if not auth_provider.initialized:
auth_provider.initialized = True
await auth_provider.async_initialize()
return await auth_provider.async_credential_flow()
async def _async_finish_login_flow(self, result):
"""Result of a credential login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
auth_provider = self._providers[result['handler']]
return await auth_provider.async_get_or_create_credentials(
result['data'])
@callback
def _async_get_auth_provider(self, credentials):
"""Helper to get auth provider from a set of credentials."""
auth_provider_key = (credentials.auth_provider_type,
credentials.auth_provider_id)
return self._providers[auth_provider_key]
class AuthStore:
"""Stores authentication info.
Any mutation to an object should happen inside the auth store.
The auth store is lazy. It won't load the data from disk until a method is
called that needs it.
"""
def __init__(self, hass):
"""Initialize the auth store."""
self.hass = hass
self._users = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def async_get_users(self):
"""Retrieve all users."""
if self._users is None:
await self.async_load()
return list(self._users.values())
async def async_get_user(self, user_id):
"""Retrieve a user by id."""
if self._users is None:
await self.async_load()
return self._users.get(user_id)
async def async_create_user(self, name, is_owner=None, is_active=None,
system_generated=None, credentials=None):
"""Create a new user."""
if self._users is None:
await self.async_load()
kwargs = {
'name': name
}
if is_owner is not None:
kwargs['is_owner'] = is_owner
if is_active is not None:
kwargs['is_active'] = is_active
if system_generated is not None:
kwargs['system_generated'] = system_generated
new_user = User(**kwargs)
self._users[new_user.id] = new_user
if credentials is None:
await self.async_save()
return new_user
# Saving is done inside the link.
await self.async_link_user(new_user, credentials)
return new_user
async def async_link_user(self, user, credentials):
"""Add credentials to an existing user."""
user.credentials.append(credentials)
await self.async_save()
credentials.is_new = False
async def async_remove_user(self, user):
"""Remove a user."""
self._users.pop(user.id)
await self.async_save()
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new token for a user."""
refresh_token = RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
return refresh_token
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
if self._users is None:
await self.async_load()
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token)
if refresh_token is not None:
return refresh_token
return None
async def async_load(self):
"""Load the users."""
data = await self._store.async_load()
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self._users is not None:
return
if data is None:
self._users = {}
return
users = {
user_dict['id']: User(**user_dict) for user_dict in data['users']
}
for cred_dict in data['credentials']:
users[cred_dict['user_id']].credentials.append(Credentials(
id=cred_dict['id'],
is_new=False,
auth_provider_type=cred_dict['auth_provider_type'],
auth_provider_id=cred_dict['auth_provider_id'],
data=cred_dict['data'],
))
refresh_tokens = {}
for rt_dict in data['refresh_tokens']:
token = RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
client_id=rt_dict['client_id'],
created_at=dt_util.parse_datetime(rt_dict['created_at']),
access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']),
token=rt_dict['token'],
)
refresh_tokens[token.id] = token
users[rt_dict['user_id']].refresh_tokens[token.token] = token
for ac_dict in data['access_tokens']:
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
token = AccessToken(
refresh_token=refresh_token,
created_at=dt_util.parse_datetime(ac_dict['created_at']),
token=ac_dict['token'],
)
refresh_token.access_tokens.append(token)
self._users = users
async def async_save(self):
"""Save users."""
users = [
{
'id': user.id,
'is_owner': user.is_owner,
'is_active': user.is_active,
'name': user.name,
'system_generated': user.system_generated,
}
for user in self._users.values()
]
credentials = [
{
'id': credential.id,
'user_id': user.id,
'auth_provider_type': credential.auth_provider_type,
'auth_provider_id': credential.auth_provider_id,
'data': credential.data,
}
for user in self._users.values()
for credential in user.credentials
]
refresh_tokens = [
{
'id': refresh_token.id,
'user_id': user.id,
'client_id': refresh_token.client_id,
'created_at': refresh_token.created_at.isoformat(),
'access_token_expiration':
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
]
access_tokens = [
{
'id': user.id,
'refresh_token_id': refresh_token.id,
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
data = {
'users': users,
'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens,
}
await self._store.async_save(data, delay=1)

View file

@ -0,0 +1,191 @@
"""Provide an authentication layer for Home Assistant."""
import asyncio
import logging
from collections import OrderedDict
from homeassistant import data_entry_flow
from homeassistant.core import callback
from . import models
from . import auth_store
from .providers import auth_provider_from_config
_LOGGER = logging.getLogger(__name__)
async def auth_manager_from_config(hass, provider_configs):
"""Initialize an auth manager from config."""
store = auth_store.AuthStore(hass)
if provider_configs:
providers = await asyncio.gather(
*[auth_provider_from_config(hass, store, config)
for config in provider_configs])
else:
providers = []
# So returned auth providers are in same order as config
provider_hash = OrderedDict()
for provider in providers:
if provider is None:
continue
key = (provider.type, provider.id)
if key in provider_hash:
_LOGGER.error(
'Found duplicate provider: %s. Please add unique IDs if you '
'want to have the same provider twice.', key)
continue
provider_hash[key] = provider
manager = AuthManager(hass, store, provider_hash)
return manager
class AuthManager:
"""Manage the authentication for Home Assistant."""
def __init__(self, hass, store, providers):
"""Initialize the auth manager."""
self._store = store
self._providers = providers
self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow,
self._async_finish_login_flow)
self._access_tokens = {}
@property
def active(self):
"""Return if any auth providers are registered."""
return bool(self._providers)
@property
def support_legacy(self):
"""
Return if legacy_api_password auth providers are registered.
Should be removed when we removed legacy_api_password auth providers.
"""
for provider_type, _ in self._providers:
if provider_type == 'legacy_api_password':
return True
return False
@property
def async_auth_providers(self):
"""Return a list of available auth providers."""
return self._providers.values()
async def async_get_user(self, user_id):
"""Retrieve a user."""
return await self._store.async_get_user(user_id)
async def async_create_system_user(self, name):
"""Create a system user."""
return await self._store.async_create_user(
name=name,
system_generated=True,
is_active=True,
)
async def async_get_or_create_user(self, credentials):
"""Get or create a user."""
if not credentials.is_new:
for user in await self._store.async_get_users():
for creds in user.credentials:
if creds.id == credentials.id:
return user
raise ValueError('Unable to find the user.')
auth_provider = self._async_get_auth_provider(credentials)
info = await auth_provider.async_user_meta_for_credentials(
credentials)
kwargs = {
'credentials': credentials,
'name': info.get('name')
}
# Make owner and activate user if it's the first user.
if await self._store.async_get_users():
kwargs['is_owner'] = False
kwargs['is_active'] = False
else:
kwargs['is_owner'] = True
kwargs['is_active'] = True
return await self._store.async_create_user(**kwargs)
async def async_link_user(self, user, credentials):
"""Link credentials to an existing user."""
await self._store.async_link_user(user, credentials)
async def async_remove_user(self, user):
"""Remove a user."""
await self._store.async_remove_user(user)
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new refresh token for a user."""
if not user.is_active:
raise ValueError('User is not active')
if user.system_generated and client_id is not None:
raise ValueError(
'System generated users cannot have refresh tokens connected '
'to a client.')
if not user.system_generated and client_id is None:
raise ValueError('Client is required to generate a refresh token.')
return await self._store.async_create_refresh_token(user, client_id)
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
return await self._store.async_get_refresh_token(token)
@callback
def async_create_access_token(self, refresh_token):
"""Create a new access token."""
access_token = models.AccessToken(refresh_token=refresh_token)
self._access_tokens[access_token.token] = access_token
return access_token
@callback
def async_get_access_token(self, token):
"""Get an access token."""
tkn = self._access_tokens.get(token)
if tkn is None:
return None
if tkn.expired:
self._access_tokens.pop(token)
return None
return tkn
async def _async_create_login_flow(self, handler, *, source, data):
"""Create a login flow."""
auth_provider = self._providers[handler]
if not auth_provider.initialized:
auth_provider.initialized = True
await auth_provider.async_initialize()
return await auth_provider.async_credential_flow()
async def _async_finish_login_flow(self, result):
"""Result of a credential login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return None
auth_provider = self._providers[result['handler']]
return await auth_provider.async_get_or_create_credentials(
result['data'])
@callback
def _async_get_auth_provider(self, credentials):
"""Helper to get auth provider from a set of credentials."""
auth_provider_key = (credentials.auth_provider_type,
credentials.auth_provider_id)
return self._providers[auth_provider_key]

View file

@ -0,0 +1,213 @@
"""Storage for auth models."""
from datetime import timedelta
from homeassistant.util import dt as dt_util
from . import models
STORAGE_VERSION = 1
STORAGE_KEY = 'auth'
class AuthStore:
"""Stores authentication info.
Any mutation to an object should happen inside the auth store.
The auth store is lazy. It won't load the data from disk until a method is
called that needs it.
"""
def __init__(self, hass):
"""Initialize the auth store."""
self.hass = hass
self._users = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def async_get_users(self):
"""Retrieve all users."""
if self._users is None:
await self.async_load()
return list(self._users.values())
async def async_get_user(self, user_id):
"""Retrieve a user by id."""
if self._users is None:
await self.async_load()
return self._users.get(user_id)
async def async_create_user(self, name, is_owner=None, is_active=None,
system_generated=None, credentials=None):
"""Create a new user."""
if self._users is None:
await self.async_load()
kwargs = {
'name': name
}
if is_owner is not None:
kwargs['is_owner'] = is_owner
if is_active is not None:
kwargs['is_active'] = is_active
if system_generated is not None:
kwargs['system_generated'] = system_generated
new_user = models.User(**kwargs)
self._users[new_user.id] = new_user
if credentials is None:
await self.async_save()
return new_user
# Saving is done inside the link.
await self.async_link_user(new_user, credentials)
return new_user
async def async_link_user(self, user, credentials):
"""Add credentials to an existing user."""
user.credentials.append(credentials)
await self.async_save()
credentials.is_new = False
async def async_remove_user(self, user):
"""Remove a user."""
self._users.pop(user.id)
await self.async_save()
async def async_create_refresh_token(self, user, client_id=None):
"""Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
return refresh_token
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
if self._users is None:
await self.async_load()
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token)
if refresh_token is not None:
return refresh_token
return None
async def async_load(self):
"""Load the users."""
data = await self._store.async_load()
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self._users is not None:
return
if data is None:
self._users = {}
return
users = {
user_dict['id']: models.User(**user_dict)
for user_dict in data['users']
}
for cred_dict in data['credentials']:
users[cred_dict['user_id']].credentials.append(models.Credentials(
id=cred_dict['id'],
is_new=False,
auth_provider_type=cred_dict['auth_provider_type'],
auth_provider_id=cred_dict['auth_provider_id'],
data=cred_dict['data'],
))
refresh_tokens = {}
for rt_dict in data['refresh_tokens']:
token = models.RefreshToken(
id=rt_dict['id'],
user=users[rt_dict['user_id']],
client_id=rt_dict['client_id'],
created_at=dt_util.parse_datetime(rt_dict['created_at']),
access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']),
token=rt_dict['token'],
)
refresh_tokens[token.id] = token
users[rt_dict['user_id']].refresh_tokens[token.token] = token
for ac_dict in data['access_tokens']:
refresh_token = refresh_tokens[ac_dict['refresh_token_id']]
token = models.AccessToken(
refresh_token=refresh_token,
created_at=dt_util.parse_datetime(ac_dict['created_at']),
token=ac_dict['token'],
)
refresh_token.access_tokens.append(token)
self._users = users
async def async_save(self):
"""Save users."""
users = [
{
'id': user.id,
'is_owner': user.is_owner,
'is_active': user.is_active,
'name': user.name,
'system_generated': user.system_generated,
}
for user in self._users.values()
]
credentials = [
{
'id': credential.id,
'user_id': user.id,
'auth_provider_type': credential.auth_provider_type,
'auth_provider_id': credential.auth_provider_id,
'data': credential.data,
}
for user in self._users.values()
for credential in user.credentials
]
refresh_tokens = [
{
'id': refresh_token.id,
'user_id': user.id,
'client_id': refresh_token.client_id,
'created_at': refresh_token.created_at.isoformat(),
'access_token_expiration':
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
]
access_tokens = [
{
'id': user.id,
'refresh_token_id': refresh_token.id,
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
data = {
'users': users,
'credentials': credentials,
'access_tokens': access_tokens,
'refresh_tokens': refresh_tokens,
}
await self._store.async_save(data, delay=1)

View file

@ -0,0 +1,4 @@
"""Constants for the auth module."""
from datetime import timedelta
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)

View file

@ -0,0 +1,75 @@
"""Auth models."""
from datetime import datetime, timedelta
import uuid
import attr
from homeassistant.util import dt as dt_util
from .const import ACCESS_TOKEN_EXPIRATION
from .util import generate_secret
@attr.s(slots=True)
class User:
"""A user."""
name = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
system_generated = attr.ib(type=bool, default=False)
# List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
# Tokens associated with a user.
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict), cmp=False)
@attr.s(slots=True)
class RefreshToken:
"""RefreshToken for a user to grant new access tokens."""
user = attr.ib(type=User)
client_id = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
access_token_expiration = attr.ib(type=timedelta,
default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list), cmp=False)
@attr.s(slots=True)
class AccessToken:
"""Access token to access the API.
These will only ever be stored in memory and not be persisted.
"""
refresh_token = attr.ib(type=RefreshToken)
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
token = attr.ib(type=str,
default=attr.Factory(generate_secret))
@property
def expired(self):
"""Return if this token has expired."""
expires = self.created_at + self.refresh_token.access_token_expiration
return dt_util.utcnow() > expires
@attr.s(slots=True)
class Credentials:
"""Credentials for a user on an auth provider."""
auth_provider_type = attr.ib(type=str)
auth_provider_id = attr.ib(type=str)
# Allow the auth provider to store data to represent their auth.
data = attr.ib(type=dict)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_new = attr.ib(type=bool, default=True)

View file

@ -0,0 +1,147 @@
"""Auth providers for Home Assistant."""
import importlib
import logging
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import requirements
from homeassistant.core import callback
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util.decorator import Registry
from homeassistant.auth.models import Credentials
_LOGGER = logging.getLogger(__name__)
DATA_REQS = 'auth_prov_reqs_processed'
AUTH_PROVIDERS = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema({
vol.Required(CONF_TYPE): str,
vol.Optional(CONF_NAME): str,
# Specify ID if you have two auth providers for same type.
vol.Optional(CONF_ID): str,
}, extra=vol.ALLOW_EXTRA)
async def auth_provider_from_config(hass, store, config):
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name)
if module is None:
return None
try:
config = module.CONFIG_SCHEMA(config)
except vol.Invalid as err:
_LOGGER.error('Invalid configuration for auth provider %s: %s',
provider_name, humanize_error(config, err))
return None
return AUTH_PROVIDERS[provider_name](hass, store, config)
async def load_auth_provider_module(hass, provider):
"""Load an auth provider."""
try:
module = importlib.import_module(
'homeassistant.auth.providers.{}'.format(provider))
except ImportError:
_LOGGER.warning('Unable to find auth provider %s', provider)
return None
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
return module
processed = hass.data.get(DATA_REQS)
if processed is None:
processed = hass.data[DATA_REQS] = set()
elif provider in processed:
return module
req_success = await requirements.async_process_requirements(
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
if not req_success:
return None
processed.add(provider)
return module
class AuthProvider:
"""Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider'
initialized = False
def __init__(self, hass, store, config):
"""Initialize an auth provider."""
self.hass = hass
self.store = store
self.config = config
@property
def id(self): # pylint: disable=invalid-name
"""Return id of the auth provider.
Optional, can be None.
"""
return self.config.get(CONF_ID)
@property
def type(self):
"""Return type of the provider."""
return self.config[CONF_TYPE]
@property
def name(self):
"""Return the name of the auth provider."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
async def async_credentials(self):
"""Return all credentials of this provider."""
users = await self.store.async_get_users()
return [
credentials
for user in users
for credentials in user.credentials
if (credentials.auth_provider_type == self.type and
credentials.auth_provider_id == self.id)
]
@callback
def async_create_credentials(self, data):
"""Create credentials."""
return Credentials(
auth_provider_type=self.type,
auth_provider_id=self.id,
data=data,
)
# Implement by extending class
async def async_initialize(self):
"""Initialize the auth provider.
Optional.
"""
async def async_credential_flow(self):
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError
async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result."""
raise NotImplementedError
async def async_user_meta_for_credentials(self, credentials):
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
return {}

View file

@ -6,14 +6,17 @@ import hmac
import voluptuous as vol
from homeassistant import auth, data_entry_flow
from homeassistant import data_entry_flow
from homeassistant.exceptions import HomeAssistantError
from homeassistant.auth.util import generate_secret
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant'
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
}, extra=vol.PREVENT_EXTRA)
@ -43,7 +46,7 @@ class Data:
if data is None:
data = {
'salt': auth.generate_secret(),
'salt': generate_secret(),
'users': []
}
@ -112,8 +115,8 @@ class Data:
await self._store.async_save(self._data)
@auth.AUTH_PROVIDERS.register('homeassistant')
class HassAuthProvider(auth.AuthProvider):
@AUTH_PROVIDERS.register('homeassistant')
class HassAuthProvider(AuthProvider):
"""Auth provider based on a local storage of users in HASS config dir."""
DEFAULT_TITLE = 'Home Assistant Local'

View file

@ -5,9 +5,11 @@ import hmac
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant import auth, data_entry_flow
from homeassistant import data_entry_flow
from homeassistant.core import callback
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
USER_SCHEMA = vol.Schema({
vol.Required('username'): str,
@ -16,7 +18,7 @@ USER_SCHEMA = vol.Schema({
})
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
vol.Required('users'): [USER_SCHEMA]
}, extra=vol.PREVENT_EXTRA)
@ -25,8 +27,8 @@ class InvalidAuthError(HomeAssistantError):
"""Raised when submitting invalid authentication."""
@auth.AUTH_PROVIDERS.register('insecure_example')
class ExampleAuthProvider(auth.AuthProvider):
@AUTH_PROVIDERS.register('insecure_example')
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_credential_flow(self):

View file

@ -9,15 +9,18 @@ import hmac
import voluptuous as vol
from homeassistant.exceptions import HomeAssistantError
from homeassistant import auth, data_entry_flow
from homeassistant import data_entry_flow
from homeassistant.core import callback
from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS
USER_SCHEMA = vol.Schema({
vol.Required('username'): str,
})
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({
}, extra=vol.PREVENT_EXTRA)
LEGACY_USER = 'homeassistant'
@ -27,8 +30,8 @@ class InvalidAuthError(HomeAssistantError):
"""Raised when submitting invalid authentication."""
@auth.AUTH_PROVIDERS.register('legacy_api_password')
class LegacyApiPasswordAuthProvider(auth.AuthProvider):
@AUTH_PROVIDERS.register('legacy_api_password')
class LegacyApiPasswordAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
DEFAULT_TITLE = 'Legacy API Password'

View file

@ -0,0 +1,13 @@
"""Auth utils."""
import binascii
import os
def generate_secret(entropy: int = 32) -> str:
"""Generate a secret.
Backport of secrets.token_hex from Python 3.6
Event loop friendly.
"""
return binascii.hexlify(os.urandom(entropy)).decode('ascii')

View file

@ -1 +0,0 @@
"""Auth providers for Home Assistant."""

View file

@ -10,7 +10,7 @@ import logging
from aiohttp import web
import voluptuous as vol
from homeassistant.auth import generate_secret
from homeassistant.auth.util import generate_secret
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import CONF_API_KEY, EVENT_HOMEASSISTANT_STOP, URL_API
import homeassistant.helpers.config_validation as cv

View file

@ -13,6 +13,7 @@ import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import auth
from homeassistant.auth import providers as auth_providers
from homeassistant.const import (
ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ASSUMED_STATE,
CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM,
@ -159,7 +160,7 @@ CORE_CONFIG_SCHEMA = CUSTOMIZE_CONFIG_SCHEMA.extend({
vol.All(cv.ensure_list, [vol.IsDir()]),
vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA,
vol.Optional(CONF_AUTH_PROVIDERS):
vol.All(cv.ensure_list, [auth.AUTH_PROVIDER_SCHEMA])
vol.All(cv.ensure_list, [auth_providers.AUTH_PROVIDER_SCHEMA])
})

View file

@ -5,7 +5,7 @@ import os
from homeassistant.core import HomeAssistant
from homeassistant.config import get_default_config_dir
from homeassistant.auth_providers import homeassistant as hass_auth
from homeassistant.auth.providers import homeassistant as hass_auth
def run(args):

1
tests/auth/__init__.py Normal file
View file

@ -0,0 +1 @@
"""Tests for the Home Assistant auth module."""

View file

@ -2,7 +2,7 @@
import pytest
from homeassistant import data_entry_flow
from homeassistant.auth_providers import homeassistant as hass_auth
from homeassistant.auth.providers import homeassistant as hass_auth
@pytest.fixture

View file

@ -4,8 +4,8 @@ import uuid
import pytest
from homeassistant import auth
from homeassistant.auth_providers import insecure_example
from homeassistant.auth import auth_store, models as auth_models
from homeassistant.auth.providers import insecure_example
from tests.common import mock_coro
@ -13,7 +13,7 @@ from tests.common import mock_coro
@pytest.fixture
def store(hass):
"""Mock store."""
return auth.AuthStore(hass)
return auth_store.AuthStore(hass)
@pytest.fixture
@ -45,7 +45,7 @@ async def test_create_new_credential(provider):
async def test_match_existing_credentials(store, provider):
"""See if we match existing users."""
existing = auth.Credentials(
existing = auth_models.Credentials(
id=uuid.uuid4(),
auth_provider_type='insecure_example',
auth_provider_id=None,

View file

@ -4,13 +4,14 @@ from unittest.mock import Mock
import pytest
from homeassistant import auth
from homeassistant.auth_providers import legacy_api_password
from homeassistant.auth import auth_store
from homeassistant.auth.providers import legacy_api_password
@pytest.fixture
def store(hass):
"""Mock store."""
return auth.AuthStore(hass)
return auth_store.AuthStore(hass)
@pytest.fixture

View file

@ -5,6 +5,8 @@ from unittest.mock import Mock, patch
import pytest
from homeassistant import auth, data_entry_flow
from homeassistant.auth import (
models as auth_models, auth_store, const as auth_const)
from homeassistant.util import dt as dt_util
from tests.common import (
MockUser, ensure_auth_manager_loaded, flush_store, CLIENT_ID)
@ -101,7 +103,7 @@ async def test_login_as_existing_user(mock_hass):
is_active=False,
name='Not user',
).add_to_auth_manager(manager)
user.credentials.append(auth.Credentials(
user.credentials.append(auth_models.Credentials(
id='mock-id2',
auth_provider_type='insecure_example',
auth_provider_id=None,
@ -116,7 +118,7 @@ async def test_login_as_existing_user(mock_hass):
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
user.credentials.append(auth.Credentials(
user.credentials.append(auth_models.Credentials(
id='mock-id',
auth_provider_type='insecure_example',
auth_provider_id=None,
@ -203,7 +205,7 @@ async def test_saving_loading(hass, hass_storage):
await flush_store(manager._store._store)
store2 = auth.AuthStore(hass)
store2 = auth_store.AuthStore(hass)
users = await store2.async_get_users()
assert len(users) == 1
assert users[0] == user
@ -211,23 +213,25 @@ async def test_saving_loading(hass, hass_storage):
def test_access_token_expired():
"""Test that the expired property on access tokens work."""
refresh_token = auth.RefreshToken(
refresh_token = auth_models.RefreshToken(
user=None,
client_id='bla'
)
access_token = auth.AccessToken(
access_token = auth_models.AccessToken(
refresh_token=refresh_token
)
assert access_token.expired is False
with patch('homeassistant.auth.dt_util.utcnow',
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
with patch('homeassistant.util.dt.utcnow',
return_value=dt_util.utcnow() +
auth_const.ACCESS_TOKEN_EXPIRATION):
assert access_token.expired is True
almost_exp = dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION - timedelta(1)
with patch('homeassistant.auth.dt_util.utcnow', return_value=almost_exp):
almost_exp = \
dt_util.utcnow() + auth_const.ACCESS_TOKEN_EXPIRATION - timedelta(1)
with patch('homeassistant.util.dt.utcnow', return_value=almost_exp):
assert access_token.expired is False
@ -242,8 +246,9 @@ async def test_cannot_retrieve_expired_access_token(hass):
access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token
with patch('homeassistant.auth.dt_util.utcnow',
return_value=dt_util.utcnow() + auth.ACCESS_TOKEN_EXPIRATION):
with patch('homeassistant.util.dt.utcnow',
return_value=dt_util.utcnow() +
auth_const.ACCESS_TOKEN_EXPIRATION):
assert manager.async_get_access_token(access_token.token) is None
# Even with unpatched time, it should have been removed from manager

View file

@ -12,6 +12,7 @@ import threading
from contextlib import contextmanager
from homeassistant import auth, core as ha, data_entry_flow, config_entries
from homeassistant.auth import models as auth_models, auth_store
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
from homeassistant.helpers import (
@ -114,7 +115,7 @@ def async_test_home_assistant(loop):
"""Return a Home Assistant object pointing at test config dir."""
hass = ha.HomeAssistant(loop)
hass.config.async_load = Mock()
store = auth.AuthStore(hass)
store = auth_store.AuthStore(hass)
hass.auth = auth.AuthManager(hass, store, {})
ensure_auth_manager_loaded(hass.auth)
INSTANCES.append(hass)
@ -308,7 +309,7 @@ def mock_registry(hass, mock_entries=None):
return registry
class MockUser(auth.User):
class MockUser(auth_models.User):
"""Mock a user in Home Assistant."""
def __init__(self, id='mock-id', is_owner=True, is_active=True,

View file

@ -7,7 +7,7 @@ import pytest
from aiohttp import BasicAuth, web
from aiohttp.web_exceptions import HTTPUnauthorized
from homeassistant.auth import AccessToken, RefreshToken
from homeassistant.auth.models import AccessToken, RefreshToken
from homeassistant.components.http.auth import setup_auth
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.real_ip import setup_real_ip

View file

@ -4,7 +4,7 @@ from unittest.mock import Mock, patch
import pytest
from homeassistant.scripts import auth as script_auth
from homeassistant.auth_providers import homeassistant as hass_auth
from homeassistant.auth.providers import homeassistant as hass_auth
@pytest.fixture