Only create front-end client_id once (#15214)

* Only create frontend client_id once

* Check user and client_id before create refresh token

* Lint

* Follow code review comment

* Minor clenaup

* Update doc string
This commit is contained in:
Jason Hu 2018-07-01 10:36:50 -07:00 committed by Paulus Schoutsen
parent dffe36761d
commit a64a66dd62
5 changed files with 121 additions and 51 deletions

View file

@ -1,23 +1,22 @@
"""Provide an authentication layer for Home Assistant."""
import asyncio
import binascii
from collections import OrderedDict
from datetime import datetime, timedelta
import os
import importlib
import logging
import os
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta
import attr
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements
from homeassistant.core import callback
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
from homeassistant.util.decorator import Registry
from homeassistant.core import callback
from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry
_LOGGER = logging.getLogger(__name__)
@ -349,6 +348,16 @@ class AuthManager:
return await self._store.async_create_client(
name, redirect_uris, no_secret)
async def async_get_or_create_client(self, name, *, redirect_uris=None,
no_secret=False):
"""Find a client, if not exists, create a new one."""
for client in await self._store.async_get_clients():
if client.name == name:
return client
return await self._store.async_create_client(
name, redirect_uris, no_secret)
async def async_get_client(self, client_id):
"""Get a client."""
return await self._store.async_get_client(client_id)
@ -392,29 +401,36 @@ class AuthStore:
def __init__(self, hass):
"""Initialize the auth store."""
self.hass = hass
self.users = None
self.clients = None
self._users = None
self._clients = None
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
async def credentials_for_provider(self, provider_type, provider_id):
"""Return credentials for specific auth provider type and id."""
if self.users is None:
if self._users is None:
await self.async_load()
return [
credentials
for user in self.users.values()
for user in self._users.values()
for credentials in user.credentials
if (credentials.auth_provider_type == provider_type and
credentials.auth_provider_id == provider_id)
]
async def async_get_user(self, user_id):
"""Retrieve a user."""
if self.users is None:
async def async_get_users(self):
"""Retrieve all users."""
if self._users is None:
await self.async_load()
return self.users.get(user_id)
return list(self._users.values())
async def async_get_user(self, user_id):
"""Retrieve a user."""
if self._users is None:
await self.async_load()
return self._users.get(user_id)
async def async_get_or_create_user(self, credentials, auth_provider):
"""Get or create a new user for given credentials.
@ -422,7 +438,7 @@ class AuthStore:
If link_user is passed in, the credentials will be linked to the passed
in user if the credentials are new.
"""
if self.users is None:
if self._users is None:
await self.async_load()
# New credentials, store in user
@ -430,7 +446,7 @@ class AuthStore:
info = await auth_provider.async_user_meta_for_credentials(
credentials)
# Make owner and activate user if it's the first user.
if self.users:
if self._users:
is_owner = False
is_active = False
else:
@ -442,11 +458,11 @@ class AuthStore:
is_active=is_active,
name=info.get('name'),
)
self.users[new_user.id] = new_user
self._users[new_user.id] = new_user
await self.async_link_user(new_user, credentials)
return new_user
for user in self.users.values():
for user in self._users.values():
for creds in user.credentials:
if (creds.auth_provider_type == credentials.auth_provider_type
and creds.auth_provider_id ==
@ -463,11 +479,19 @@ class AuthStore:
async def async_remove_user(self, user):
"""Remove a user."""
self.users.pop(user.id)
self._users.pop(user.id)
await self.async_save()
async def async_create_refresh_token(self, user, client_id):
"""Create a new token for a user."""
local_user = await self.async_get_user(user.id)
if local_user is None:
raise ValueError('Invalid user')
local_client = await self.async_get_client(client_id)
if local_client is None:
raise ValueError('Invalid client_id')
refresh_token = RefreshToken(user, client_id)
user.refresh_tokens[refresh_token.token] = refresh_token
await self.async_save()
@ -475,10 +499,10 @@ class AuthStore:
async def async_get_refresh_token(self, token):
"""Get refresh token by token."""
if self.users is None:
if self._users is None:
await self.async_load()
for user in self.users.values():
for user in self._users.values():
refresh_token = user.refresh_tokens.get(token)
if refresh_token is not None:
return refresh_token
@ -487,7 +511,7 @@ class AuthStore:
async def async_create_client(self, name, redirect_uris, no_secret):
"""Create a new client."""
if self.clients is None:
if self._clients is None:
await self.async_load()
kwargs = {
@ -499,16 +523,23 @@ class AuthStore:
kwargs['secret'] = None
client = Client(**kwargs)
self.clients[client.id] = client
self._clients[client.id] = client
await self.async_save()
return client
async def async_get_client(self, client_id):
"""Get a client."""
if self.clients is None:
async def async_get_clients(self):
"""Return all clients."""
if self._clients is None:
await self.async_load()
return self.clients.get(client_id)
return list(self._clients.values())
async def async_get_client(self, client_id):
"""Get a client."""
if self._clients is None:
await self.async_load()
return self._clients.get(client_id)
async def async_load(self):
"""Load the users."""
@ -516,12 +547,12 @@ class AuthStore:
# Make sure that we're not overriding data if 2 loads happened at the
# same time
if self.users is not None:
if self._users is not None:
return
if data is None:
self.users = {}
self.clients = {}
self._users = {}
self._clients = {}
return
users = {
@ -565,8 +596,8 @@ class AuthStore:
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
}
self.users = users
self.clients = clients
self._users = users
self._clients = clients
async def async_save(self):
"""Save users."""
@ -577,7 +608,7 @@ class AuthStore:
'is_active': user.is_active,
'name': user.name,
}
for user in self.users.values()
for user in self._users.values()
]
credentials = [
@ -588,7 +619,7 @@ class AuthStore:
'auth_provider_id': credential.auth_provider_id,
'data': credential.data,
}
for user in self.users.values()
for user in self._users.values()
for credential in user.credentials
]
@ -602,7 +633,7 @@ class AuthStore:
refresh_token.access_token_expiration.total_seconds(),
'token': refresh_token.token,
}
for user in self.users.values()
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
]
@ -613,7 +644,7 @@ class AuthStore:
'created_at': access_token.created_at.isoformat(),
'token': access_token.token,
}
for user in self.users.values()
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()
for access_token in refresh_token.access_tokens
]
@ -625,7 +656,7 @@ class AuthStore:
'secret': client.secret,
'redirect_uris': client.redirect_uris,
}
for client in self.clients.values()
for client in self._clients.values()
]
data = {

View file

@ -201,7 +201,7 @@ def add_manifest_json_key(key, val):
async def async_setup(hass, config):
"""Set up the serving of the frontend."""
if hass.auth.active:
client = await hass.auth.async_create_client(
client = await hass.auth.async_get_or_create_client(
'Home Assistant Frontend',
redirect_uris=['/'],
no_secret=True,

View file

@ -321,7 +321,7 @@ class MockUser(auth.User):
def add_to_auth_manager(self, auth_mgr):
"""Test helper to add entry to hass."""
ensure_auth_manager_loaded(auth_mgr)
auth_mgr._store.users[self.id] = self
auth_mgr._store._users[self.id] = self
return self
@ -329,10 +329,10 @@ class MockUser(auth.User):
def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded."""
store = auth_mgr._store
if store.clients is None:
store.clients = {}
if store.users is None:
store.users = {}
if store._clients is None:
store._clients = {}
if store._users is None:
store._users = {}
class MockModule(object):

View file

@ -34,7 +34,7 @@ async def async_setup_auth(hass, aiohttp_client, provider_configs=BASE_CONFIG,
})
client = auth.Client('Test Client', CLIENT_ID, CLIENT_SECRET,
redirect_uris=[CLIENT_REDIRECT_URI])
hass.auth._store.clients[client.id] = client
hass.auth._store._clients[client.id] = client
if setup_api:
await async_setup_component(hass, 'api', {})
return await aiohttp_client(hass.http.app)

View file

@ -191,12 +191,13 @@ async def test_saving_loading(hass, hass_storage):
await flush_store(manager._store._store)
store2 = auth.AuthStore(hass)
await store2.async_load()
assert len(store2.users) == 1
assert store2.users[user.id] == user
users = await store2.async_get_users()
assert len(users) == 1
assert users[0] == user
assert len(store2.clients) == 1
assert store2.clients[client.id] == client
clients = await store2.async_get_clients()
assert len(clients) == 1
assert clients[0] == client
def test_access_token_expired():
@ -224,15 +225,18 @@ def test_access_token_expired():
async def test_cannot_retrieve_expired_access_token(hass):
"""Test that we cannot retrieve expired access tokens."""
manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
refresh_token = await manager.async_create_refresh_token(user, 'bla')
access_token = manager.async_create_access_token(refresh_token)
refresh_token = await manager.async_create_refresh_token(user, client.id)
assert refresh_token.user.id is user.id
assert refresh_token.client_id is client.id
access_token = manager.async_create_access_token(refresh_token)
assert manager.async_get_access_token(access_token.token) is access_token
with patch('homeassistant.auth.dt_util.utcnow',
@ -241,3 +245,38 @@ async def test_cannot_retrieve_expired_access_token(hass):
# Even with unpatched time, it should have been removed from manager
assert manager.async_get_access_token(access_token.token) is None
async def test_get_or_create_client(hass):
"""Test that get_or_create_client works."""
manager = await auth.auth_manager_from_config(hass, [])
client1 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client1.name is 'Test Client'
client2 = await manager.async_get_or_create_client(
'Test Client', redirect_uris=['https://test.com/1'])
assert client2.id is client1.id
async def test_cannot_create_refresh_token_with_invalide_client_id(hass):
"""Test that we cannot create refresh token with invalid client id."""
manager = await auth.auth_manager_from_config(hass, [])
user = MockUser(
id='mock-user',
is_owner=False,
is_active=False,
name='Paulus',
).add_to_auth_manager(manager)
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, 'bla')
async def test_cannot_create_refresh_token_with_invalide_user(hass):
"""Test that we cannot create refresh token with invalid client id."""
manager = await auth.auth_manager_from_config(hass, [])
client = await manager.async_create_client('test')
user = MockUser(id='invalid-user')
with pytest.raises(ValueError):
await manager.async_create_refresh_token(user, client.id)