Foundation for users (#13968)

* Add initial user foundation to Home Assistant

* Address comments

* Address comments

* Allow non-ascii passwords

* One more utf-8 hmac compare digest

* Add new line
This commit is contained in:
Paulus Schoutsen 2018-05-01 12:20:41 -04:00 committed by Pascal Vizeli
parent b994c10d7f
commit cdd45e7878
22 changed files with 1774 additions and 59 deletions

505
homeassistant/auth.py Normal file
View file

@ -0,0 +1,505 @@
"""Provide an authentication layer for Home Assistant."""
import asyncio
import binascii
from collections import OrderedDict
from datetime import datetime, timedelta
import os
import importlib
import logging
import uuid
import attr
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.decorator import Registry
from homeassistant.util import dt as dt_util
_LOGGER = logging.getLogger(__name__)
AUTH_PROVIDERS = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema({
vol.Required(CONF_TYPE): str,
vol.Optional(CONF_NAME): str,
# Specify ID if you have two auth providers for same type.
vol.Optional(CONF_ID): str,
}, extra=vol.ALLOW_EXTRA)
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
DATA_REQS = 'auth_reqs_processed'
class AuthError(HomeAssistantError):
"""Generic authentication error."""
class InvalidUser(AuthError):
"""Raised when an invalid user has been specified."""
class InvalidPassword(AuthError):
"""Raised when an invalid password has been supplied."""
class UnknownError(AuthError):
"""When an unknown error occurs."""
def generate_secret(entropy=32):
"""Generate a secret.
Backport of secrets.token_hex from Python 3.6
Event loop friendly.
"""
return binascii.hexlify(os.urandom(entropy)).decode('ascii')
class AuthProvider:
"""Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider'
initialized = False
def __init__(self, store, config):
"""Initialize an auth provider."""
self.store = store
self.config = config
@property
def id(self): # pylint: disable=invalid-name
"""Return id of the auth provider.
Optional, can be None.
"""
return self.config.get(CONF_ID)
@property
def type(self):
"""Return type of the provider."""
return self.config[CONF_TYPE]
@property
def name(self):
"""Return the name of the auth provider."""
return self.config.get(CONF_NAME, self.DEFAULT_TITLE)
async def async_credentials(self):
"""Return all credentials of this provider."""
return await self.store.credentials_for_provider(self.type, self.id)
@callback
def async_create_credentials(self, data):
"""Create credentials."""
return Credentials(
auth_provider_type=self.type,
auth_provider_id=self.id,
data=data,
)
# Implement by extending class
async def async_initialize(self):
"""Initialize the auth provider.
Optional.
"""
async def async_credential_flow(self):
"""Return the data flow for logging in with auth provider."""
raise NotImplementedError
async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result."""
raise NotImplementedError
async def async_user_meta_for_credentials(self, credentials):
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
return {}
@attr.s(slots=True)
class User:
"""A user."""
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False)
name = attr.ib(type=str, default=None)
# For persisting and see if saved?
# store = attr.ib(type=AuthStore, default=None)
# List of credentials of a user.
credentials = attr.ib(type=list, default=attr.Factory(list))
# Tokens associated with a user.
refresh_tokens = attr.ib(type=dict, default=attr.Factory(dict))
def as_dict(self):
"""Convert user object to a dictionary."""
return {
'id': self.id,
'is_owner': self.is_owner,
'is_active': self.is_active,
'name': self.name,
}
@attr.s(slots=True)
class RefreshToken:
"""RefreshToken for a user to grant new access tokens."""
user = attr.ib(type=User)
client_id = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
access_token_expiration = attr.ib(type=timedelta,
default=ACCESS_TOKEN_EXPIRATION)
token = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
access_tokens = attr.ib(type=list, default=attr.Factory(list))
@attr.s(slots=True)
class AccessToken:
"""Access token to access the API.
These will only ever be stored in memory and not be persisted.
"""
refresh_token = attr.ib(type=RefreshToken)
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow))
token = attr.ib(type=str,
default=attr.Factory(generate_secret))
@property
def expires(self):
"""Return datetime when this token expires."""
return self.created_at + self.refresh_token.access_token_expiration
@attr.s(slots=True)
class Credentials:
"""Credentials for a user on an auth provider."""
auth_provider_type = attr.ib(type=str)
auth_provider_id = attr.ib(type=str)
# Allow the auth provider to store data to represent their auth.
data = attr.ib(type=dict)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
is_new = attr.ib(type=bool, default=True)
@attr.s(slots=True)
class Client:
"""Client that interacts with Home Assistant on behalf of a user."""
name = attr.ib(type=str)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
secret = attr.ib(type=str, default=attr.Factory(generate_secret))
async def load_auth_provider_module(hass, provider):
"""Load an auth provider."""
try:
module = importlib.import_module(
'homeassistant.auth_providers.{}'.format(provider))
except ImportError:
_LOGGER.warning('Unable to find auth provider %s', provider)
return None
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'):
return module
processed = hass.data.get(DATA_REQS)
if processed is None:
processed = hass.data[DATA_REQS] = set()
elif provider in processed:
return module
req_success = await requirements.async_process_requirements(
hass, 'auth provider {}'.format(provider), module.REQUIREMENTS)
if not req_success:
return None
return module
async def auth_manager_from_config(hass, provider_configs):
"""Initialize an auth manager from config."""
store = AuthStore(hass)
if provider_configs:
providers = await asyncio.gather(
*[_auth_provider_from_config(hass, store, config)
for config in provider_configs])
else:
providers = []
# So returned auth providers are in same order as config
provider_hash = OrderedDict()
for provider in providers:
if provider is None:
continue
key = (provider.type, provider.id)
if key in provider_hash:
_LOGGER.error(
'Found duplicate provider: %s. Please add unique IDs if you '
'want to have the same provider twice.', key)
continue
provider_hash[key] = provider
manager = AuthManager(hass, store, provider_hash)
return manager
async def _auth_provider_from_config(hass, store, config):
"""Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name)
if module is None:
return None
try:
config = module.CONFIG_SCHEMA(config)
except vol.Invalid as err:
_LOGGER.error('Invalid configuration for auth provider %s: %s',
provider_name, humanize_error(config, err))
return None
return AUTH_PROVIDERS[provider_name](store, config)
class AuthManager:
"""Manage the authentication for Home Assistant."""
def __init__(self, hass, store, providers):
"""Initialize the auth manager."""
self._store = store
self._providers = providers
self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow,
self._async_finish_login_flow)
self.access_tokens = {}
@property
def async_auth_providers(self):
"""Return a list of available auth providers."""
return self._providers.values()
async def async_get_user(self, user_id):
"""Retrieve a user."""
return await self._store.async_get_user(user_id)
async def async_get_or_create_user(self, credentials):
"""Get or create a user."""
return await self._store.async_get_or_create_user(
credentials, self._async_get_auth_provider(credentials))
async def async_link_user(self, user, credentials):
"""Link credentials to an existing user."""
await self._store.async_link_user(user, credentials)
async def async_remove_user(self, user):
"""Remove a user."""
await self._store.async_remove_user(user)
async def async_create_refresh_token(self, user, client_id):
"""Create a new refresh token for a user."""
return await self._store.async_create_refresh_token(user, client_id)
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
return await self._store.async_get_refresh_token(token)
@callback
def async_create_access_token(self, refresh_token):
"""Create a new access token."""
access_token = AccessToken(refresh_token)
self.access_tokens[access_token.token] = access_token
return access_token
@callback
def async_get_access_token(self, token):
"""Get an access token."""
return self.access_tokens.get(token)
async def async_create_client(self, name):
"""Create a new client."""
return await self._store.async_create_client(name)
async def async_get_client(self, client_id):
"""Get a client."""
return await self._store.async_get_client(client_id)
async def _async_create_login_flow(self, handler, *, source, data):
"""Create a login flow."""
auth_provider = self._providers[handler]
if not auth_provider.initialized:
auth_provider.initialized = True
await auth_provider.async_initialize()
return await auth_provider.async_credential_flow()
async def _async_finish_login_flow(self, result):
"""Result of a credential login flow."""
auth_provider = self._providers[result['handler']]
return await auth_provider.async_get_or_create_credentials(
result['data'])
@callback
def _async_get_auth_provider(self, credentials):
"""Helper to get auth provider from a set of credentials."""
auth_provider_key = (credentials.auth_provider_type,
credentials.auth_provider_id)
return self._providers[auth_provider_key]
class AuthStore:
"""Stores authentication info.
Any mutation to an object should happen inside the auth store.
The auth store is lazy. It won't load the data from disk until a method is
called that needs it.
"""
def __init__(self, hass):
"""Initialize the auth store."""
self.hass = hass
self.users = None
self.clients = None
self._load_lock = asyncio.Lock(loop=hass.loop)
async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id."""
if self.users is None:
await self.async_load()
return [
credentials
for user in self.users.values()
for credentials in user.credentials
if (credentials.auth_provider_type == provider_type and
credentials.auth_provider_id == provider_id)
]
async def async_get_user(self, user_id):
"""Retrieve a user."""
if self.users is None:
await self.async_load()
return self.users.get(user_id)
async def async_get_or_create_user(self, credentials, auth_provider):
"""Get or create a new user for given credentials.
If link_user is passed in, the credentials will be linked to the passed
in user if the credentials are new.
"""
if self.users is None:
await self.async_load()
# New credentials, store in user
if credentials.is_new:
info = await auth_provider.async_user_meta_for_credentials(
credentials)
# Make owner and activate user if it's the first user.
if self.users:
is_owner = False
is_active = False
else:
is_owner = True
is_active = True
new_user = User(
is_owner=is_owner,
is_active=is_active,
name=info.get('name'),
)
self.users[new_user.id] = new_user
await self.async_link_user(new_user, credentials)
return new_user
for user in self.users.values():
for creds in user.credentials:
if (creds.auth_provider_type == credentials.auth_provider_type
and creds.auth_provider_id ==
credentials.auth_provider_id):
return user
raise ValueError('We got credentials with ID but found no user')
async def async_link_user(self, user, credentials):
"""Add credentials to an existing user."""
user.credentials.append(credentials)
await self.async_save()
credentials.is_new = False
async def async_remove_user(self, user):
"""Remove a user."""
self.users.pop(user.id)
await self.async_save()
async def async_create_refresh_token(self, user, client_id):
"""Create a new token for a user."""
refresh_token = RefreshToken(user, client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
return refresh_token
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
if self.users is None:
await self.async_load()
for user in self.users.values():
refresh_token = user.refresh_tokens.get(token)
if refresh_token is not None:
return refresh_token
return None
async def async_create_client(self, name):
"""Create a new client."""
if self.clients is None:
await self.async_load()
client = Client(name)
self.clients[client.id] = client
await self.async_save()
return client
async def async_get_client(self, client_id):
"""Get a client."""
if self.clients is None:
await self.async_load()
return self.clients.get(client_id)
async def async_load(self):
"""Load the users."""
async with self._load_lock:
self.users = {}
self.clients = {}
async def async_save(self):
"""Save users."""
pass

View file

@ -0,0 +1 @@
"""Auth providers for Home Assistant."""

View file

@ -0,0 +1,116 @@
"""Example auth provider."""
from collections import OrderedDict
import hmac
import voluptuous as vol
from homeassistant import auth, data_entry_flow
from homeassistant.core import callback
USER_SCHEMA = vol.Schema({
vol.Required('username'): str,
vol.Required('password'): str,
vol.Optional('name'): str,
})
CONFIG_SCHEMA = auth.AUTH_PROVIDER_SCHEMA.extend({
vol.Required('users'): [USER_SCHEMA]
}, extra=vol.PREVENT_EXTRA)
@auth.AUTH_PROVIDERS.register('insecure_example')
class ExampleAuthProvider(auth.AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_credential_flow(self):
"""Return a flow to login."""
return LoginFlow(self)
@callback
def async_validate_login(self, username, password):
"""Helper to validate a username and password."""
user = None
# Compare all users to avoid timing attacks.
for usr in self.config['users']:
if hmac.compare_digest(username.encode('utf-8'),
usr['username'].encode('utf-8')):
user = usr
if user is None:
# Do one more compare to make timing the same as if user was found.
hmac.compare_digest(password.encode('utf-8'),
password.encode('utf-8'))
raise auth.InvalidUser
if not hmac.compare_digest(user['password'].encode('utf-8'),
password.encode('utf-8')):
raise auth.InvalidPassword
async def async_get_or_create_credentials(self, flow_result):
"""Get credentials based on the flow result."""
username = flow_result['username']
password = flow_result['password']
self.async_validate_login(username, password)
for credential in await self.async_credentials():
if credential.data['username'] == username:
return credential
# Create new credentials.
return self.async_create_credentials({
'username': username
})
async def async_user_meta_for_credentials(self, credentials):
"""Return extra user metadata for credentials.
Will be used to populate info when creating a new user.
"""
username = credentials.data['username']
for user in self.config['users']:
if user['username'] == username:
return {
'name': user.get('name')
}
return {}
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
def __init__(self, auth_provider):
"""Initialize the login flow."""
self._auth_provider = auth_provider
async def async_step_init(self, user_input=None):
"""Handle the step of the form."""
errors = {}
if user_input is not None:
try:
self._auth_provider.async_validate_login(
user_input['username'], user_input['password'])
except (auth.InvalidUser, auth.InvalidPassword):
errors['base'] = 'invalid_auth'
if not errors:
return self.async_create_entry(
title=self._auth_provider.name,
data=user_input
)
schema = OrderedDict()
schema['username'] = str
schema['password'] = str
return self.async_show_form(
step_id='init',
data_schema=vol.Schema(schema),
errors=errors,
)

View file

@ -0,0 +1,344 @@
"""Component to allow users to login and get tokens.
All requests will require passing in a valid client ID and secret via HTTP
Basic Auth.
# GET /auth/providers
Return a list of auth providers. Example:
[
{
"name": "Local",
"id": null,
"type": "local_provider",
}
]
# POST /auth/login_flow
Create a login flow. Will return the first step of the flow.
Pass in parameter 'handler' to specify the auth provider to use. Auth providers
are identified by type and id.
{
"handler": ["local_provider", null]
}
Return value will be a step in a data entry flow. See the docs for data entry
flow for details.
{
"data_schema": [
{"name": "username", "type": "string"},
{"name": "password", "type": "string"}
],
"errors": {},
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
"handler": ["insecure_example", null],
"step_id": "init",
"type": "form"
}
# POST /auth/login_flow/{flow_id}
Progress the flow. Most flows will be 1 page, but could optionally add extra
login challenges, like TFA. Once the flow has finished, the returned step will
have type "create_entry" and "result" key will contain an authorization code.
{
"flow_id": "8f7e42faab604bcab7ac43c44ca34d58",
"handler": ["insecure_example", null],
"result": "411ee2f916e648d691e937ae9344681e",
"source": "user",
"title": "Example",
"type": "create_entry",
"version": 1
}
# POST /auth/token
This is an OAuth2 endpoint for granting tokens. We currently support the grant
types "authorization_code" and "refresh_token". Because we follow the OAuth2
spec, data should be send in formatted as x-www-form-urlencoded. Examples will
be in JSON as it's more readable.
## Grant type authorization_code
Exchange the authorization code retrieved from the login flow for tokens.
{
"grant_type": "authorization_code",
"code": "411ee2f916e648d691e937ae9344681e"
}
Return value will be the access and refresh tokens. The access token will have
a limited expiration. New access tokens can be requested using the refresh
token.
{
"access_token": "ABCDEFGH",
"expires_in": 1800,
"refresh_token": "IJKLMNOPQRST",
"token_type": "Bearer"
}
## Grant type refresh_token
Request a new access token using a refresh token.
{
"grant_type": "refresh_token",
"refresh_token": "IJKLMNOPQRST"
}
Return value will be a new access token. The access token will have
a limited expiration.
{
"access_token": "ABCDEFGH",
"expires_in": 1800,
"token_type": "Bearer"
}
"""
import logging
import uuid
import aiohttp.web
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.core import callback
from homeassistant.helpers.data_entry_flow import (
FlowManagerIndexView, FlowManagerResourceView)
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
from .client import verify_client
DOMAIN = 'auth'
DEPENDENCIES = ['http']
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass, config):
"""Component to allow users to login."""
store_credentials, retrieve_credentials = _create_cred_store()
hass.http.register_view(AuthProvidersView)
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow))
hass.http.register_view(
LoginFlowResourceView(hass.auth.login_flow, store_credentials))
hass.http.register_view(GrantTokenView(retrieve_credentials))
hass.http.register_view(LinkUserView(retrieve_credentials))
return True
class AuthProvidersView(HomeAssistantView):
"""View to get available auth providers."""
url = '/auth/providers'
name = 'api:auth:providers'
requires_auth = False
@verify_client
async def get(self, request, client_id):
"""Get available auth providers."""
return self.json([{
'name': provider.name,
'id': provider.id,
'type': provider.type,
} for provider in request.app['hass'].auth.async_auth_providers])
class LoginFlowIndexView(FlowManagerIndexView):
"""View to create a config flow."""
url = '/auth/login_flow'
name = 'api:auth:login_flow'
requires_auth = False
async def get(self, request):
"""Do not allow index of flows in progress."""
return aiohttp.web.Response(status=405)
# pylint: disable=arguments-differ
@verify_client
async def post(self, request, client_id):
"""Create a new login flow."""
# pylint: disable=no-value-for-parameter
return await super().post(request)
class LoginFlowResourceView(FlowManagerResourceView):
"""View to interact with the flow manager."""
url = '/auth/login_flow/{flow_id}'
name = 'api:auth:login_flow:resource'
requires_auth = False
def __init__(self, flow_mgr, store_credentials):
"""Initialize the login flow resource view."""
super().__init__(flow_mgr)
self._store_credentials = store_credentials
# pylint: disable=arguments-differ
async def get(self, request):
"""Do not allow getting status of a flow in progress."""
return self.json_message('Invalid flow specified', 404)
# pylint: disable=arguments-differ
@verify_client
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(self, request, client_id, flow_id, data):
"""Handle progressing a login flow request."""
try:
result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
except vol.Invalid:
return self.json_message('User input malformed', 400)
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return self.json(self._prepare_result_json(result))
result.pop('data')
result['result'] = self._store_credentials(client_id, result['result'])
return self.json(result)
class GrantTokenView(HomeAssistantView):
"""View to grant tokens."""
url = '/auth/token'
name = 'api:auth:token'
requires_auth = False
def __init__(self, retrieve_credentials):
"""Initialize the grant token view."""
self._retrieve_credentials = retrieve_credentials
@verify_client
async def post(self, request, client_id):
"""Grant a token."""
hass = request.app['hass']
data = await request.post()
grant_type = data.get('grant_type')
if grant_type == 'authorization_code':
return await self._async_handle_auth_code(
hass, client_id, data)
elif grant_type == 'refresh_token':
return await self._async_handle_refresh_token(
hass, client_id, data)
return self.json({
'error': 'unsupported_grant_type',
}, status_code=400)
async def _async_handle_auth_code(self, hass, client_id, data):
"""Handle authorization code request."""
code = data.get('code')
if code is None:
return self.json({
'error': 'invalid_request',
}, status_code=400)
credentials = self._retrieve_credentials(client_id, code)
if credentials is None:
return self.json({
'error': 'invalid_request',
}, status_code=400)
user = await hass.auth.async_get_or_create_user(credentials)
refresh_token = await hass.auth.async_create_refresh_token(user,
client_id)
access_token = hass.auth.async_create_access_token(refresh_token)
return self.json({
'access_token': access_token.token,
'token_type': 'Bearer',
'refresh_token': refresh_token.token,
'expires_in':
int(refresh_token.access_token_expiration.total_seconds()),
})
async def _async_handle_refresh_token(self, hass, client_id, data):
"""Handle authorization code request."""
token = data.get('refresh_token')
if token is None:
return self.json({
'error': 'invalid_request',
}, status_code=400)
refresh_token = await hass.auth.async_get_refresh_token(token)
if refresh_token is None or refresh_token.client_id != client_id:
return self.json({
'error': 'invalid_grant',
}, status_code=400)
access_token = hass.auth.async_create_access_token(refresh_token)
return self.json({
'access_token': access_token.token,
'token_type': 'Bearer',
'expires_in':
int(refresh_token.access_token_expiration.total_seconds()),
})
class LinkUserView(HomeAssistantView):
"""View to link existing users to new credentials."""
url = '/auth/link_user'
name = 'api:auth:link_user'
def __init__(self, retrieve_credentials):
"""Initialize the link user view."""
self._retrieve_credentials = retrieve_credentials
@RequestDataValidator(vol.Schema({
'code': str,
'client_id': str,
}))
async def post(self, request, data):
"""Link a user."""
hass = request.app['hass']
user = request['hass_user']
credentials = self._retrieve_credentials(
data['client_id'], data['code'])
if credentials is None:
return self.json_message('Invalid code', status_code=400)
await hass.auth.async_link_user(user, credentials)
return self.json_message('User linked')
@callback
def _create_cred_store():
"""Create a credential store."""
temp_credentials = {}
@callback
def store_credentials(client_id, credentials):
"""Store credentials and return a code to retrieve it."""
code = uuid.uuid4().hex
temp_credentials[(client_id, code)] = credentials
return code
@callback
def retrieve_credentials(client_id, code):
"""Retrieve credentials."""
return temp_credentials.pop((client_id, code), None)
return store_credentials, retrieve_credentials

View file

@ -0,0 +1,63 @@
"""Helpers to resolve client ID/secret."""
import base64
from functools import wraps
import hmac
import aiohttp.hdrs
def verify_client(method):
"""Decorator to verify client id/secret on requests."""
@wraps(method)
async def wrapper(view, request, *args, **kwargs):
"""Verify client id/secret before doing request."""
client_id = await _verify_client(request)
if client_id is None:
return view.json({
'error': 'invalid_client',
}, status_code=401)
return await method(
view, request, *args, client_id=client_id, **kwargs)
return wrapper
async def _verify_client(request):
"""Method to verify the client id/secret in consistent time.
By using a consistent time for looking up client id and comparing the
secret, we prevent attacks by malicious actors trying different client ids
and are able to derive from the time it takes to process the request if
they guessed the client id correctly.
"""
if aiohttp.hdrs.AUTHORIZATION not in request.headers:
return None
auth_type, auth_value = \
request.headers.get(aiohttp.hdrs.AUTHORIZATION).split(' ', 1)
if auth_type != 'Basic':
return None
decoded = base64.b64decode(auth_value).decode('utf-8')
try:
client_id, client_secret = decoded.split(':', 1)
except ValueError:
# If no ':' in decoded
return None
client = await request.app['hass'].auth.async_get_client(client_id)
if client is None:
# Still do a compare so we run same time as if a client was found.
hmac.compare_digest(client_secret.encode('utf-8'),
client_secret.encode('utf-8'))
return None
if hmac.compare_digest(client_secret.encode('utf-8'),
client.secret.encode('utf-8')):
return client_id
return None

View file

@ -32,17 +32,19 @@ def setup_auth(app, trusted_networks, api_password):
if (HTTP_HEADER_HA_AUTH in request.headers and
hmac.compare_digest(
api_password, request.headers[HTTP_HEADER_HA_AUTH])):
api_password.encode('utf-8'),
request.headers[HTTP_HEADER_HA_AUTH].encode('utf-8'))):
# A valid auth header has been set
authenticated = True
elif (DATA_API_PASSWORD in request.query and
hmac.compare_digest(api_password,
request.query[DATA_API_PASSWORD])):
hmac.compare_digest(
api_password.encode('utf-8'),
request.query[DATA_API_PASSWORD].encode('utf-8'))):
authenticated = True
elif (hdrs.AUTHORIZATION in request.headers and
validate_authorization_header(api_password, request)):
await async_validate_auth_header(api_password, request)):
authenticated = True
elif _is_trusted_ip(request, trusted_networks):
@ -70,23 +72,38 @@ def _is_trusted_ip(request, trusted_networks):
def validate_password(request, api_password):
"""Test if password is valid."""
return hmac.compare_digest(
api_password, request.app['hass'].http.api_password)
api_password.encode('utf-8'),
request.app['hass'].http.api_password.encode('utf-8'))
def validate_authorization_header(api_password, request):
async def async_validate_auth_header(api_password, request):
"""Test an authorization header if valid password."""
if hdrs.AUTHORIZATION not in request.headers:
return False
auth_type, auth = request.headers.get(hdrs.AUTHORIZATION).split(' ', 1)
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION).split(' ', 1)
if auth_type != 'Basic':
if auth_type == 'Basic':
decoded = base64.b64decode(auth_val).decode('utf-8')
try:
username, password = decoded.split(':', 1)
except ValueError:
# If no ':' in decoded
return False
if username != 'homeassistant':
return False
return hmac.compare_digest(api_password.encode('utf-8'),
password.encode('utf-8'))
if auth_type != 'Bearer':
return False
decoded = base64.b64decode(auth).decode('utf-8')
username, password = decoded.split(':', 1)
if username != 'homeassistant':
hass = request.app['hass']
access_token = hass.auth.async_get_access_token(auth_val)
if access_token is None:
return False
return hmac.compare_digest(api_password, password)
request['hass_user'] = access_token.refresh_token.user
return True

View file

@ -12,13 +12,14 @@ from typing import Any, List, Tuple # NOQA
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import auth
from homeassistant.const import (
ATTR_FRIENDLY_NAME, ATTR_HIDDEN, ATTR_ASSUMED_STATE,
CONF_LATITUDE, CONF_LONGITUDE, CONF_NAME, CONF_PACKAGES, CONF_UNIT_SYSTEM,
CONF_TIME_ZONE, CONF_ELEVATION, CONF_UNIT_SYSTEM_METRIC,
CONF_UNIT_SYSTEM_IMPERIAL, CONF_TEMPERATURE_UNIT, TEMP_CELSIUS,
__version__, CONF_CUSTOMIZE, CONF_CUSTOMIZE_DOMAIN, CONF_CUSTOMIZE_GLOB,
CONF_WHITELIST_EXTERNAL_DIRS)
CONF_WHITELIST_EXTERNAL_DIRS, CONF_AUTH_PROVIDERS)
from homeassistant.core import callback, DOMAIN as CONF_CORE
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import get_component, get_platform
@ -157,6 +158,8 @@ CORE_CONFIG_SCHEMA = CUSTOMIZE_CONFIG_SCHEMA.extend({
# pylint: disable=no-value-for-parameter
vol.All(cv.ensure_list, [vol.IsDir()]),
vol.Optional(CONF_PACKAGES, default={}): PACKAGES_CONFIG_SCHEMA,
vol.Optional(CONF_AUTH_PROVIDERS):
vol.All(cv.ensure_list, [auth.AUTH_PROVIDER_SCHEMA])
})
@ -394,6 +397,12 @@ async def async_process_ha_core_config(hass, config):
This method is a coroutine.
"""
config = CORE_CONFIG_SCHEMA(config)
# Only load auth during startup.
if not hasattr(hass, 'auth'):
hass.auth = await auth.auth_manager_from_config(
hass, config.get(CONF_AUTH_PROVIDERS, []))
hac = hass.config
def set_time_zone(time_zone_str):

View file

@ -260,7 +260,7 @@ class ConfigEntries:
"""Initialize the entry manager."""
self.hass = hass
self.flow = data_entry_flow.FlowManager(
hass, self._async_create_flow, self._async_save_entry)
hass, self._async_create_flow, self._async_finish_flow)
self._hass_config = hass_config
self._entries = None
self._sched_save = None
@ -345,8 +345,8 @@ class ConfigEntries:
return await entry.async_unload(
self.hass, component=getattr(self.hass.components, component))
async def _async_save_entry(self, result):
"""Add an entry."""
async def _async_finish_flow(self, result):
"""Finish a config flow and add an entry."""
entry = ConfigEntry(
version=result['version'],
domain=result['handler'],

View file

@ -30,6 +30,7 @@ CONF_API_KEY = 'api_key'
CONF_API_VERSION = 'api_version'
CONF_AT = 'at'
CONF_AUTHENTICATION = 'authentication'
CONF_AUTH_PROVIDERS = 'auth_providers'
CONF_BASE = 'base'
CONF_BEFORE = 'before'
CONF_BELOW = 'below'

View file

@ -67,7 +67,7 @@ class FlowManager:
return await self._async_handle_step(flow, step, data)
async def async_configure(self, flow_id, user_input=None):
"""Start or continue a configuration flow."""
"""Continue a configuration flow."""
flow = self._progress.get(flow_id)
if flow is None:

View file

@ -7,40 +7,40 @@ from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
def _prepare_json(result):
"""Convert result for JSON."""
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy()
data.pop('result')
data.pop('data')
return data
elif result['type'] != data_entry_flow.RESULT_TYPE_FORM:
return result
import voluptuous_serialize
data = result.copy()
schema = data['data_schema']
if schema is None:
data['data_schema'] = []
else:
data['data_schema'] = voluptuous_serialize.convert(schema)
return data
class FlowManagerIndexView(HomeAssistantView):
"""View to create config flows."""
class _BaseFlowManagerView(HomeAssistantView):
"""Foundation for flow manager views."""
def __init__(self, flow_mgr):
"""Initialize the flow manager index view."""
self._flow_mgr = flow_mgr
async def get(self, request):
"""List flows that are in progress."""
return self.json(self._flow_mgr.async_progress())
# pylint: disable=no-self-use
def _prepare_result_json(self, result):
"""Convert result to JSON."""
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy()
data.pop('result')
data.pop('data')
return data
elif result['type'] != data_entry_flow.RESULT_TYPE_FORM:
return result
import voluptuous_serialize
data = result.copy()
schema = data['data_schema']
if schema is None:
data['data_schema'] = []
else:
data['data_schema'] = voluptuous_serialize.convert(schema)
return data
class FlowManagerIndexView(_BaseFlowManagerView):
"""View to create config flows."""
@RequestDataValidator(vol.Schema({
vol.Required('handler'): vol.Any(str, list),
@ -59,18 +59,14 @@ class FlowManagerIndexView(HomeAssistantView):
except data_entry_flow.UnknownStep:
return self.json_message('Handler does not support init', 400)
result = _prepare_json(result)
result = self._prepare_result_json(result)
return self.json(result)
class FlowManagerResourceView(HomeAssistantView):
class FlowManagerResourceView(_BaseFlowManagerView):
"""View to interact with the flow manager."""
def __init__(self, flow_mgr):
"""Initialize the flow manager resource view."""
self._flow_mgr = flow_mgr
async def get(self, request, flow_id):
"""Get the current state of a data_entry_flow."""
try:
@ -78,7 +74,7 @@ class FlowManagerResourceView(HomeAssistantView):
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
result = _prepare_json(result)
result = self._prepare_result_json(result)
return self.json(result)
@ -92,7 +88,7 @@ class FlowManagerResourceView(HomeAssistantView):
except vol.Invalid:
return self.json_message('User input malformed', 400)
result = _prepare_json(result)
result = self._prepare_result_json(result)
return self.json(result)

View file

@ -41,3 +41,7 @@ disable=
[EXCEPTIONS]
overgeneral-exceptions=Exception,HomeAssistantError
# For attrs
[typecheck]
ignored-classes=_CountingAttr

View file

@ -0,0 +1 @@
"""Tests for the auth providers."""

View file

@ -0,0 +1,89 @@
"""Tests for the insecure example auth provider."""
from unittest.mock import Mock
import uuid
import pytest
from homeassistant import auth
from homeassistant.auth_providers import insecure_example
from tests.common import mock_coro
@pytest.fixture
def store():
"""Mock store."""
return auth.AuthStore(Mock())
@pytest.fixture
def provider(store):
"""Mock provider."""
return insecure_example.ExampleAuthProvider(store, {
'type': 'insecure_example',
'users': [
{
'username': 'user-test',
'password': 'password-test',
},
{
'username': '🎉',
'password': '😎',
}
]
})
async def test_create_new_credential(provider):
"""Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({
'username': 'user-test',
'password': 'password-test',
})
assert credentials.is_new is True
async def test_match_existing_credentials(store, provider):
"""See if we match existing users."""
existing = auth.Credentials(
id=uuid.uuid4(),
auth_provider_type='insecure_example',
auth_provider_id=None,
data={
'username': 'user-test'
},
is_new=False,
)
store.credentials_for_provider = Mock(return_value=mock_coro([existing]))
credentials = await provider.async_get_or_create_credentials({
'username': 'user-test',
'password': 'password-test',
})
assert credentials is existing
async def test_verify_username(provider):
"""Test we raise if incorrect user specified."""
with pytest.raises(auth.InvalidUser):
await provider.async_get_or_create_credentials({
'username': 'non-existing-user',
'password': 'password-test',
})
async def test_verify_password(provider):
"""Test we raise if incorrect user specified."""
with pytest.raises(auth.InvalidPassword):
await provider.async_get_or_create_credentials({
'username': 'user-test',
'password': 'incorrect-password',
})
async def test_utf_8_username_password(provider):
"""Test that we create a new credential."""
credentials = await provider.async_get_or_create_credentials({
'username': '🎉',
'password': '😎',
})
assert credentials.is_new is True

View file

@ -10,7 +10,8 @@ import logging
import threading
from contextlib import contextmanager
from homeassistant import core as ha, loader, data_entry_flow, config_entries
from homeassistant import (
auth, core as ha, loader, data_entry_flow, config_entries)
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
from homeassistant.helpers import (
@ -113,6 +114,9 @@ def async_test_home_assistant(loop):
hass.config_entries = config_entries.ConfigEntries(hass, {})
hass.config_entries._entries = []
hass.config.async_load = Mock()
store = auth.AuthStore(hass)
hass.auth = auth.AuthManager(hass, store, {})
ensure_auth_manager_loaded(hass.auth)
INSTANCES.append(hass)
orig_async_add_job = hass.async_add_job
@ -303,6 +307,34 @@ def mock_registry(hass, mock_entries=None):
return registry
class MockUser(auth.User):
"""Mock a user in Home Assistant."""
def __init__(self, id='mock-id', is_owner=True, is_active=True,
name='Mock User'):
"""Initialize mock user."""
super().__init__(id, is_owner, is_active, name)
def add_to_hass(self, hass):
"""Test helper to add entry to hass."""
return self.add_to_auth_manager(hass.auth)
def add_to_auth_manager(self, auth_mgr):
"""Test helper to add entry to hass."""
auth_mgr._store.users[self.id] = self
return self
@ha.callback
def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded."""
store = auth_mgr._store
if store.clients is None:
store.clients = {}
if store.users is None:
store.users = {}
class MockModule(object):
"""Representation of a fake module."""

View file

@ -0,0 +1,38 @@
"""Tests for the auth component."""
from aiohttp.helpers import BasicAuth
from homeassistant import auth
from homeassistant.setup import async_setup_component
from tests.common import ensure_auth_manager_loaded
BASE_CONFIG = [{
'name': 'Example',
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
'name': 'Test Name'
}]
}]
CLIENT_ID = 'test-id'
CLIENT_SECRET = 'test-secret'
CLIENT_AUTH = BasicAuth(CLIENT_ID, CLIENT_SECRET)
async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
setup_api=False):
"""Helper to setup authentication and create a HTTP client."""
hass.auth = await auth.auth_manager_from_config(hass, provider_configs)
ensure_auth_manager_loaded(hass.auth)
await async_setup_component(hass, 'auth', {
'http': {
'api_password': 'bla'
}
})
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET)
hass.auth._store.clients[client.id] = client
if setup_api:
await async_setup_component(hass, 'api', {})
return await aiohttp_client(hass.http.app)

View file

@ -0,0 +1,70 @@
"""Tests for the client validator."""
from aiohttp.helpers import BasicAuth
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.components.auth.client import verify_client
from homeassistant.components.http.view import HomeAssistantView
from . import async_setup_auth
@pytest.fixture
def mock_view(hass):
"""Register a view that verifies client id/secret."""
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
clients = []
class ClientView(HomeAssistantView):
url = '/'
name = 'bla'
@verify_client
async def get(self, request, client_id):
"""Handle GET request."""
clients.append(client_id)
hass.http.register_view(ClientView)
return clients
async def test_verify_client(hass, aiohttp_client, mock_view):
"""Test that verify client can extract client auth from a request."""
http_client = await async_setup_auth(hass, aiohttp_client)
client = await hass.auth.async_create_client('Hello')
resp = await http_client.get('/', auth=BasicAuth(client.id, client.secret))
assert resp.status == 200
assert mock_view == [client.id]
async def test_verify_client_no_auth_header(hass, aiohttp_client, mock_view):
"""Test that verify client will decline unknown client id."""
http_client = await async_setup_auth(hass, aiohttp_client)
resp = await http_client.get('/')
assert resp.status == 401
assert mock_view == []
async def test_verify_client_invalid_client_id(hass, aiohttp_client,
mock_view):
"""Test that verify client will decline unknown client id."""
http_client = await async_setup_auth(hass, aiohttp_client)
client = await hass.auth.async_create_client('Hello')
resp = await http_client.get('/', auth=BasicAuth('invalid', client.secret))
assert resp.status == 401
assert mock_view == []
async def test_verify_client_invalid_client_secret(hass, aiohttp_client,
mock_view):
"""Test that verify client will decline incorrect client secret."""
http_client = await async_setup_auth(hass, aiohttp_client)
client = await hass.auth.async_create_client('Hello')
resp = await http_client.get('/', auth=BasicAuth(client.id, 'invalid'))
assert resp.status == 401
assert mock_view == []

View file

@ -0,0 +1,53 @@
"""Integration tests for the auth component."""
from . import async_setup_auth, CLIENT_AUTH
async def test_login_new_user_and_refresh_token(hass, aiohttp_client):
"""Test logging in with new user and refreshing tokens."""
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
resp = await client.post('/auth/login_flow', json={
'handler': ['insecure_example', None]
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'username': 'test-user',
'password': 'test-pass',
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
code = step['result']
# Exchange code for tokens
resp = await client.post('/auth/token', data={
'grant_type': 'authorization_code',
'code': code
}, auth=CLIENT_AUTH)
assert resp.status == 200
tokens = await resp.json()
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
# Use refresh token to get more tokens.
resp = await client.post('/auth/token', data={
'grant_type': 'refresh_token',
'refresh_token': tokens['refresh_token']
}, auth=CLIENT_AUTH)
assert resp.status == 200
tokens = await resp.json()
assert 'refresh_token' not in tokens
assert hass.auth.async_get_access_token(tokens['access_token']) is not None
# Test using access token to hit API.
resp = await client.get('/api/')
assert resp.status == 401
resp = await client.get('/api/', headers={
'authorization': 'Bearer {}'.format(tokens['access_token'])
})
assert resp.status == 200

View file

@ -0,0 +1,150 @@
"""Tests for the link user flow."""
from . import async_setup_auth, CLIENT_AUTH, CLIENT_ID
async def async_get_code(hass, aiohttp_client):
"""Helper for link user tests that returns authorization code."""
config = [{
'name': 'Example',
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
'name': 'Test Name'
}]
}, {
'name': 'Example',
'id': '2nd auth',
'type': 'insecure_example',
'users': [{
'username': '2nd-user',
'password': '2nd-pass',
'name': '2nd Name'
}]
}]
client = await async_setup_auth(hass, aiohttp_client, config)
resp = await client.post('/auth/login_flow', json={
'handler': ['insecure_example', None]
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'username': 'test-user',
'password': 'test-pass',
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
code = step['result']
# Exchange code for tokens
resp = await client.post('/auth/token', data={
'grant_type': 'authorization_code',
'code': code
}, auth=CLIENT_AUTH)
assert resp.status == 200
tokens = await resp.json()
access_token = hass.auth.async_get_access_token(tokens['access_token'])
assert access_token is not None
user = access_token.refresh_token.user
assert len(user.credentials) == 1
# Now authenticate with the 2nd flow
resp = await client.post('/auth/login_flow', json={
'handler': ['insecure_example', '2nd auth']
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'username': '2nd-user',
'password': '2nd-pass',
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
return {
'user': user,
'code': step['result'],
'client': client,
'tokens': tokens,
}
async def test_link_user(hass, aiohttp_client):
"""Test linking a user to new credentials."""
info = await async_get_code(hass, aiohttp_client)
client = info['client']
code = info['code']
tokens = info['tokens']
# Link user
resp = await client.post('/auth/link_user', json={
'client_id': CLIENT_ID,
'code': code
}, headers={
'authorization': 'Bearer {}'.format(tokens['access_token'])
})
assert resp.status == 200
assert len(info['user'].credentials) == 2
async def test_link_user_invalid_client_id(hass, aiohttp_client):
"""Test linking a user to new credentials."""
info = await async_get_code(hass, aiohttp_client)
client = info['client']
code = info['code']
tokens = info['tokens']
# Link user
resp = await client.post('/auth/link_user', json={
'client_id': 'invalid',
'code': code
}, headers={
'authorization': 'Bearer {}'.format(tokens['access_token'])
})
assert resp.status == 400
assert len(info['user'].credentials) == 1
async def test_link_user_invalid_code(hass, aiohttp_client):
"""Test linking a user to new credentials."""
info = await async_get_code(hass, aiohttp_client)
client = info['client']
tokens = info['tokens']
# Link user
resp = await client.post('/auth/link_user', json={
'client_id': CLIENT_ID,
'code': 'invalid'
}, headers={
'authorization': 'Bearer {}'.format(tokens['access_token'])
})
assert resp.status == 400
assert len(info['user'].credentials) == 1
async def test_link_user_invalid_auth(hass, aiohttp_client):
"""Test linking a user to new credentials."""
info = await async_get_code(hass, aiohttp_client)
client = info['client']
code = info['code']
# Link user
resp = await client.post('/auth/link_user', json={
'client_id': CLIENT_ID,
'code': code,
}, headers={'authorization': 'Bearer invalid'})
assert resp.status == 401
assert len(info['user'].credentials) == 1

View file

@ -0,0 +1,66 @@
"""Tests for the login flow."""
from aiohttp.helpers import BasicAuth
from . import async_setup_auth, CLIENT_AUTH
async def test_fetch_auth_providers(hass, aiohttp_client):
"""Test fetching auth providers."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.get('/auth/providers', auth=CLIENT_AUTH)
assert await resp.json() == [{
'name': 'Example',
'type': 'insecure_example',
'id': None
}]
async def test_fetch_auth_providers_require_valid_client(hass, aiohttp_client):
"""Test fetching auth providers."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.get('/auth/providers',
auth=BasicAuth('invalid', 'bla'))
assert resp.status == 401
async def test_cannot_get_flows_in_progress(hass, aiohttp_client):
"""Test we cannot get flows in progress."""
client = await async_setup_auth(hass, aiohttp_client, [])
resp = await client.get('/auth/login_flow')
assert resp.status == 405
async def test_invalid_username_password(hass, aiohttp_client):
"""Test we cannot get flows in progress."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.post('/auth/login_flow', json={
'handler': ['insecure_example', None]
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
# Incorrect username
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'username': 'wrong-user',
'password': 'test-pass',
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
assert step['step_id'] == 'init'
assert step['errors']['base'] == 'invalid_auth'
# Incorrect password
resp = await client.post(
'/auth/login_flow/{}'.format(step['flow_id']), json={
'username': 'test-user',
'password': 'wrong-pass',
}, auth=CLIENT_AUTH)
assert resp.status == 200
step = await resp.json()
assert step['step_id'] == 'init'
assert step['errors']['base'] == 'invalid_auth'

View file

@ -1,4 +1,6 @@
"""The tests for the Home Assistant HTTP component."""
import logging
from homeassistant.setup import async_setup_component
import homeassistant.components.http as http
@ -76,14 +78,13 @@ async def test_api_no_base_url(hass):
async def test_not_log_password(hass, aiohttp_client, caplog):
"""Test access with password doesn't get logged."""
result = await async_setup_component(hass, 'api', {
assert await async_setup_component(hass, 'api', {
'http': {
http.CONF_API_PASSWORD: 'some-pass'
}
})
assert result
client = await aiohttp_client(hass.http.app)
logging.getLogger('aiohttp.access').setLevel(logging.INFO)
resp = await client.get('/api/', params={
'api_password': 'some-pass'

159
tests/test_auth.py Normal file
View file

@ -0,0 +1,159 @@
"""Tests for the Home Assistant auth module."""
from unittest.mock import Mock
import pytest
from homeassistant import auth, data_entry_flow
from tests.common import MockUser, ensure_auth_manager_loaded
@pytest.fixture
def mock_hass():
"""Hass mock with minimum amount of data set to make it work with auth."""
hass = Mock()
hass.config.skip_pip = True
return hass
async def test_auth_manager_from_config_validates_config_and_id(mock_hass):
"""Test get auth providers."""
manager = await auth.auth_manager_from_config(mock_hass, [{
'name': 'Test Name',
'type': 'insecure_example',
'users': [],
}, {
'name': 'Invalid config because no users',
'type': 'insecure_example',
'id': 'invalid_config',
}, {
'name': 'Test Name 2',
'type': 'insecure_example',
'id': 'another',
'users': [],
}, {
'name': 'Wrong because duplicate ID',
'type': 'insecure_example',
'id': 'another',
'users': [],
}])
providers = [{
'name': provider.name,
'id': provider.id,
'type': provider.type,
} for provider in manager.async_auth_providers]
assert providers == [{
'name': 'Test Name',
'type': 'insecure_example',
'id': None,
}, {
'name': 'Test Name 2',
'type': 'insecure_example',
'id': 'another',
}]
async def test_create_new_user(mock_hass):
"""Test creating new user."""
manager = await auth.auth_manager_from_config(mock_hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
'name': 'Test Name'
}]
}])
step = await manager.login_flow.async_init(('insecure_example', None))
assert step['type'] == data_entry_flow.RESULT_TYPE_FORM
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'test-user',
'password': 'test-pass',
})
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
credentials = step['result']
user = await manager.async_get_or_create_user(credentials)
assert user is not None
assert user.is_owner is True
assert user.name == 'Test Name'
async def test_login_as_existing_user(mock_hass):
"""Test login as existing user."""
manager = await auth.auth_manager_from_config(mock_hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
'name': 'Test Name'
}]
}])
ensure_auth_manager_loaded(manager)
# Add fake user with credentials for example auth provider.
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
user.credentials.append(auth.Credentials(
id='mock-id',
auth_provider_type='insecure_example',
auth_provider_id=None,
data={'username': 'test-user'},
is_new=False,
))
step = await manager.login_flow.async_init(('insecure_example', None))
assert step['type'] == data_entry_flow.RESULT_TYPE_FORM
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'test-user',
'password': 'test-pass',
})
assert step['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
credentials = step['result']
user = await manager.async_get_or_create_user(credentials)
assert user is not None
assert user.id == 'mock-user'
assert user.is_owner is False
assert user.is_active is False
assert user.name == 'Paulus'
async def test_linking_user_to_two_auth_providers(mock_hass):
"""Test linking user to two auth providers."""
manager = await auth.auth_manager_from_config(mock_hass, [{
'type': 'insecure_example',
'users': [{
'username': 'test-user',
'password': 'test-pass',
}]
}, {
'type': 'insecure_example',
'id': 'another-provider',
'users': [{
'username': 'another-user',
'password': 'another-password',
}]
}])
step = await manager.login_flow.async_init(('insecure_example', None))
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'test-user',
'password': 'test-pass',
})
user = await manager.async_get_or_create_user(step['result'])
assert user is not None
step = await manager.login_flow.async_init(('insecure_example',
'another-provider'))
step = await manager.login_flow.async_configure(step['flow_id'], {
'username': 'another-user',
'password': 'another-password',
})
await manager.async_link_user(user, step['result'])
assert len(user.credentials) == 2