mirror of
https://github.com/home-assistant/core
synced 2024-10-05 20:07:58 +00:00
Add system generated users (#15291)
* Add system generated users * Fix typing
This commit is contained in:
parent
a6e9dc81aa
commit
cb129bd207
|
@ -79,7 +79,14 @@ class AuthProvider:
|
||||||
|
|
||||||
async def async_credentials(self):
|
async def async_credentials(self):
|
||||||
"""Return all credentials of this provider."""
|
"""Return all credentials of this provider."""
|
||||||
return await self.store.credentials_for_provider(self.type, self.id)
|
users = await self.store.async_get_users()
|
||||||
|
return [
|
||||||
|
credentials
|
||||||
|
for user in users
|
||||||
|
for credentials in user.credentials
|
||||||
|
if (credentials.auth_provider_type == self.type and
|
||||||
|
credentials.auth_provider_id == self.id)
|
||||||
|
]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_create_credentials(self, data):
|
def async_create_credentials(self, data):
|
||||||
|
@ -118,10 +125,11 @@ class AuthProvider:
|
||||||
class User:
|
class User:
|
||||||
"""A user."""
|
"""A user."""
|
||||||
|
|
||||||
|
name = attr.ib(type=str)
|
||||||
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex))
|
||||||
is_owner = attr.ib(type=bool, default=False)
|
is_owner = attr.ib(type=bool, default=False)
|
||||||
is_active = attr.ib(type=bool, default=False)
|
is_active = attr.ib(type=bool, default=False)
|
||||||
name = attr.ib(type=str, default=None)
|
system_generated = attr.ib(type=bool, default=False)
|
||||||
|
|
||||||
# List of credentials of a user.
|
# List of credentials of a user.
|
||||||
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
credentials = attr.ib(type=list, default=attr.Factory(list), cmp=False)
|
||||||
|
@ -300,10 +308,45 @@ class AuthManager:
|
||||||
"""Retrieve a user."""
|
"""Retrieve a user."""
|
||||||
return await self._store.async_get_user(user_id)
|
return await self._store.async_get_user(user_id)
|
||||||
|
|
||||||
|
async def async_create_system_user(self, name):
|
||||||
|
"""Create a system user."""
|
||||||
|
return await self._store.async_create_user(
|
||||||
|
name=name,
|
||||||
|
system_generated=True,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def async_get_or_create_user(self, credentials):
|
async def async_get_or_create_user(self, credentials):
|
||||||
"""Get or create a user."""
|
"""Get or create a user."""
|
||||||
return await self._store.async_get_or_create_user(
|
if not credentials.is_new:
|
||||||
credentials, self._async_get_auth_provider(credentials))
|
for user in await self._store.async_get_users():
|
||||||
|
for creds in user.credentials:
|
||||||
|
if (creds.auth_provider_type ==
|
||||||
|
credentials.auth_provider_type
|
||||||
|
and creds.auth_provider_id ==
|
||||||
|
credentials.auth_provider_id):
|
||||||
|
return user
|
||||||
|
|
||||||
|
raise ValueError('Unable to find the user.')
|
||||||
|
|
||||||
|
auth_provider = self._async_get_auth_provider(credentials)
|
||||||
|
info = await auth_provider.async_user_meta_for_credentials(
|
||||||
|
credentials)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
'credentials': credentials,
|
||||||
|
'name': info.get('name')
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make owner and activate user if it's the first user.
|
||||||
|
if await self._store.async_get_users():
|
||||||
|
kwargs['is_owner'] = False
|
||||||
|
kwargs['is_active'] = False
|
||||||
|
else:
|
||||||
|
kwargs['is_owner'] = True
|
||||||
|
kwargs['is_active'] = True
|
||||||
|
|
||||||
|
return await self._store.async_create_user(**kwargs)
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
async def async_link_user(self, user, credentials):
|
||||||
"""Link credentials to an existing user."""
|
"""Link credentials to an existing user."""
|
||||||
|
@ -313,9 +356,20 @@ class AuthManager:
|
||||||
"""Remove a user."""
|
"""Remove a user."""
|
||||||
await self._store.async_remove_user(user)
|
await self._store.async_remove_user(user)
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id):
|
async def async_create_refresh_token(self, user, client=None):
|
||||||
"""Create a new refresh token for a user."""
|
"""Create a new refresh token for a user."""
|
||||||
return await self._store.async_create_refresh_token(user, client_id)
|
if not user.is_active:
|
||||||
|
raise ValueError('User is not active')
|
||||||
|
|
||||||
|
if user.system_generated and client is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'System generated users cannot have refresh tokens connected '
|
||||||
|
'to a client.')
|
||||||
|
|
||||||
|
if not user.system_generated and client is None:
|
||||||
|
raise ValueError('Client is required to generate a refresh token.')
|
||||||
|
|
||||||
|
return await self._store.async_create_refresh_token(user, client)
|
||||||
|
|
||||||
async def async_get_refresh_token(self, token):
|
async def async_get_refresh_token(self, token):
|
||||||
"""Get refresh token by token."""
|
"""Get refresh token by token."""
|
||||||
|
@ -324,7 +378,7 @@ class AuthManager:
|
||||||
@callback
|
@callback
|
||||||
def async_create_access_token(self, refresh_token):
|
def async_create_access_token(self, refresh_token):
|
||||||
"""Create a new access token."""
|
"""Create a new access token."""
|
||||||
access_token = AccessToken(refresh_token)
|
access_token = AccessToken(refresh_token=refresh_token)
|
||||||
self._access_tokens[access_token.token] = access_token
|
self._access_tokens[access_token.token] = access_token
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
|
@ -405,19 +459,6 @@ class AuthStore:
|
||||||
self._clients = None
|
self._clients = None
|
||||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
|
||||||
async def credentials_for_provider(self, provider_type, provider_id):
|
|
||||||
"""Return credentials for specific auth provider type and id."""
|
|
||||||
if self._users is None:
|
|
||||||
await self.async_load()
|
|
||||||
|
|
||||||
return [
|
|
||||||
credentials
|
|
||||||
for user in self._users.values()
|
|
||||||
for credentials in user.credentials
|
|
||||||
if (credentials.auth_provider_type == provider_type and
|
|
||||||
credentials.auth_provider_id == provider_id)
|
|
||||||
]
|
|
||||||
|
|
||||||
async def async_get_users(self):
|
async def async_get_users(self):
|
||||||
"""Retrieve all users."""
|
"""Retrieve all users."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
|
@ -426,50 +467,42 @@ class AuthStore:
|
||||||
return list(self._users.values())
|
return list(self._users.values())
|
||||||
|
|
||||||
async def async_get_user(self, user_id):
|
async def async_get_user(self, user_id):
|
||||||
"""Retrieve a user."""
|
"""Retrieve a user by id."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
return self._users.get(user_id)
|
return self._users.get(user_id)
|
||||||
|
|
||||||
async def async_get_or_create_user(self, credentials, auth_provider):
|
async def async_create_user(self, name, is_owner=None, is_active=None,
|
||||||
"""Get or create a new user for given credentials.
|
system_generated=None, credentials=None):
|
||||||
|
"""Create a new user."""
|
||||||
If link_user is passed in, the credentials will be linked to the passed
|
|
||||||
in user if the credentials are new.
|
|
||||||
"""
|
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
await self.async_load()
|
await self.async_load()
|
||||||
|
|
||||||
# New credentials, store in user
|
kwargs = {
|
||||||
if credentials.is_new:
|
'name': name
|
||||||
info = await auth_provider.async_user_meta_for_credentials(
|
}
|
||||||
credentials)
|
|
||||||
# Make owner and activate user if it's the first user.
|
|
||||||
if self._users:
|
|
||||||
is_owner = False
|
|
||||||
is_active = False
|
|
||||||
else:
|
|
||||||
is_owner = True
|
|
||||||
is_active = True
|
|
||||||
|
|
||||||
new_user = User(
|
if is_owner is not None:
|
||||||
is_owner=is_owner,
|
kwargs['is_owner'] = is_owner
|
||||||
is_active=is_active,
|
|
||||||
name=info.get('name'),
|
if is_active is not None:
|
||||||
)
|
kwargs['is_active'] = is_active
|
||||||
self._users[new_user.id] = new_user
|
|
||||||
await self.async_link_user(new_user, credentials)
|
if system_generated is not None:
|
||||||
|
kwargs['system_generated'] = system_generated
|
||||||
|
|
||||||
|
new_user = User(**kwargs)
|
||||||
|
|
||||||
|
self._users[new_user.id] = new_user
|
||||||
|
|
||||||
|
if credentials is None:
|
||||||
|
await self.async_save()
|
||||||
return new_user
|
return new_user
|
||||||
|
|
||||||
for user in self._users.values():
|
# Saving is done inside the link.
|
||||||
for creds in user.credentials:
|
await self.async_link_user(new_user, credentials)
|
||||||
if (creds.auth_provider_type == credentials.auth_provider_type
|
return new_user
|
||||||
and creds.auth_provider_id ==
|
|
||||||
credentials.auth_provider_id):
|
|
||||||
return user
|
|
||||||
|
|
||||||
raise ValueError('We got credentials with ID but found no user')
|
|
||||||
|
|
||||||
async def async_link_user(self, user, credentials):
|
async def async_link_user(self, user, credentials):
|
||||||
"""Add credentials to an existing user."""
|
"""Add credentials to an existing user."""
|
||||||
|
@ -482,17 +515,10 @@ class AuthStore:
|
||||||
self._users.pop(user.id)
|
self._users.pop(user.id)
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
|
|
||||||
async def async_create_refresh_token(self, user, client_id):
|
async def async_create_refresh_token(self, user, client=None):
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
local_user = await self.async_get_user(user.id)
|
client_id = client.id if client is not None else None
|
||||||
if local_user is None:
|
refresh_token = RefreshToken(user=user, client_id=client_id)
|
||||||
raise ValueError('Invalid user')
|
|
||||||
|
|
||||||
local_client = await self.async_get_client(client_id)
|
|
||||||
if local_client is None:
|
|
||||||
raise ValueError('Invalid client_id')
|
|
||||||
|
|
||||||
refresh_token = RefreshToken(user, client_id)
|
|
||||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||||
await self.async_save()
|
await self.async_save()
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
@ -607,6 +633,7 @@ class AuthStore:
|
||||||
'is_owner': user.is_owner,
|
'is_owner': user.is_owner,
|
||||||
'is_active': user.is_active,
|
'is_active': user.is_active,
|
||||||
'name': user.name,
|
'name': user.name,
|
||||||
|
'system_generated': user.system_generated,
|
||||||
}
|
}
|
||||||
for user in self._users.values()
|
for user in self._users.values()
|
||||||
]
|
]
|
||||||
|
|
|
@ -236,18 +236,16 @@ class GrantTokenView(HomeAssistantView):
|
||||||
grant_type = data.get('grant_type')
|
grant_type = data.get('grant_type')
|
||||||
|
|
||||||
if grant_type == 'authorization_code':
|
if grant_type == 'authorization_code':
|
||||||
return await self._async_handle_auth_code(
|
return await self._async_handle_auth_code(hass, client, data)
|
||||||
hass, client.id, data)
|
|
||||||
|
|
||||||
elif grant_type == 'refresh_token':
|
elif grant_type == 'refresh_token':
|
||||||
return await self._async_handle_refresh_token(
|
return await self._async_handle_refresh_token(hass, client, data)
|
||||||
hass, client.id, data)
|
|
||||||
|
|
||||||
return self.json({
|
return self.json({
|
||||||
'error': 'unsupported_grant_type',
|
'error': 'unsupported_grant_type',
|
||||||
}, status_code=400)
|
}, status_code=400)
|
||||||
|
|
||||||
async def _async_handle_auth_code(self, hass, client_id, data):
|
async def _async_handle_auth_code(self, hass, client, data):
|
||||||
"""Handle authorization code request."""
|
"""Handle authorization code request."""
|
||||||
code = data.get('code')
|
code = data.get('code')
|
||||||
|
|
||||||
|
@ -256,7 +254,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
'error': 'invalid_request',
|
'error': 'invalid_request',
|
||||||
}, status_code=400)
|
}, status_code=400)
|
||||||
|
|
||||||
credentials = self._retrieve_credentials(client_id, code)
|
credentials = self._retrieve_credentials(client.id, code)
|
||||||
|
|
||||||
if credentials is None:
|
if credentials is None:
|
||||||
return self.json({
|
return self.json({
|
||||||
|
@ -265,7 +263,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
|
|
||||||
user = await hass.auth.async_get_or_create_user(credentials)
|
user = await hass.auth.async_get_or_create_user(credentials)
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(user,
|
refresh_token = await hass.auth.async_create_refresh_token(user,
|
||||||
client_id)
|
client)
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
return self.json({
|
return self.json({
|
||||||
|
@ -276,7 +274,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
int(refresh_token.access_token_expiration.total_seconds()),
|
int(refresh_token.access_token_expiration.total_seconds()),
|
||||||
})
|
})
|
||||||
|
|
||||||
async def _async_handle_refresh_token(self, hass, client_id, data):
|
async def _async_handle_refresh_token(self, hass, client, data):
|
||||||
"""Handle authorization code request."""
|
"""Handle authorization code request."""
|
||||||
token = data.get('refresh_token')
|
token = data.get('refresh_token')
|
||||||
|
|
||||||
|
@ -287,7 +285,7 @@ class GrantTokenView(HomeAssistantView):
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_get_refresh_token(token)
|
refresh_token = await hass.auth.async_get_refresh_token(token)
|
||||||
|
|
||||||
if refresh_token is None or refresh_token.client_id != client_id:
|
if refresh_token is None or refresh_token.client_id != client.id:
|
||||||
return self.json({
|
return self.json({
|
||||||
'error': 'invalid_grant',
|
'error': 'invalid_grant',
|
||||||
}, status_code=400)
|
}, status_code=400)
|
||||||
|
|
|
@ -54,7 +54,7 @@ async def test_match_existing_credentials(store, provider):
|
||||||
},
|
},
|
||||||
is_new=False,
|
is_new=False,
|
||||||
)
|
)
|
||||||
store.credentials_for_provider = Mock(return_value=mock_coro([existing]))
|
provider.async_credentials = Mock(return_value=mock_coro([existing]))
|
||||||
credentials = await provider.async_get_or_create_credentials({
|
credentials = await provider.async_get_or_create_credentials({
|
||||||
'username': 'user-test',
|
'username': 'user-test',
|
||||||
'password': 'password-test',
|
'password': 'password-test',
|
||||||
|
|
|
@ -21,6 +21,14 @@ def provider(hass, store):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(hass, store, provider):
|
||||||
|
"""Mock manager."""
|
||||||
|
return auth.AuthManager(hass, store, {
|
||||||
|
(provider.type, provider.id): provider
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
async def test_create_new_credential(provider):
|
async def test_create_new_credential(provider):
|
||||||
"""Test that we create a new credential."""
|
"""Test that we create a new credential."""
|
||||||
credentials = await provider.async_get_or_create_credentials({})
|
credentials = await provider.async_get_or_create_credentials({})
|
||||||
|
@ -28,13 +36,13 @@ async def test_create_new_credential(provider):
|
||||||
assert credentials.is_new is True
|
assert credentials.is_new is True
|
||||||
|
|
||||||
|
|
||||||
async def test_only_one_credentials(store, provider):
|
async def test_only_one_credentials(manager, provider):
|
||||||
"""Call create twice will return same credential."""
|
"""Call create twice will return same credential."""
|
||||||
credentials = await provider.async_get_or_create_credentials({})
|
credentials = await provider.async_get_or_create_credentials({})
|
||||||
await store.async_get_or_create_user(credentials, provider)
|
await manager.async_get_or_create_user(credentials)
|
||||||
credentials2 = await provider.async_get_or_create_credentials({})
|
credentials2 = await provider.async_get_or_create_credentials({})
|
||||||
assert credentials2.data["username"] is legacy_api_password.LEGACY_USER
|
assert credentials2.data["username"] == legacy_api_password.LEGACY_USER
|
||||||
assert credentials2.id is credentials.id
|
assert credentials2.id == credentials.id
|
||||||
assert credentials2.is_new is False
|
assert credentials2.is_new is False
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -312,7 +312,8 @@ class MockUser(auth.User):
|
||||||
def __init__(self, id='mock-id', is_owner=True, is_active=True,
|
def __init__(self, id='mock-id', is_owner=True, is_active=True,
|
||||||
name='Mock User'):
|
name='Mock User'):
|
||||||
"""Initialize mock user."""
|
"""Initialize mock user."""
|
||||||
super().__init__(id, is_owner, is_active, name)
|
super().__init__(
|
||||||
|
id=id, is_owner=is_owner, is_active=is_active, name=name)
|
||||||
|
|
||||||
def add_to_hass(self, hass):
|
def add_to_hass(self, hass):
|
||||||
"""Test helper to add entry to hass."""
|
"""Test helper to add entry to hass."""
|
||||||
|
|
|
@ -34,5 +34,5 @@ def hass_access_token(hass):
|
||||||
no_secret=True,
|
no_secret=True,
|
||||||
))
|
))
|
||||||
refresh_token = hass.loop.run_until_complete(
|
refresh_token = hass.loop.run_until_complete(
|
||||||
hass.auth.async_create_refresh_token(user, client.id))
|
hass.auth.async_create_refresh_token(user, client))
|
||||||
yield hass.auth.async_create_access_token(refresh_token)
|
yield hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
|
@ -184,7 +184,7 @@ async def test_saving_loading(hass, hass_storage):
|
||||||
client = await manager.async_create_client(
|
client = await manager.async_create_client(
|
||||||
'test', redirect_uris=['https://example.com'])
|
'test', redirect_uris=['https://example.com'])
|
||||||
|
|
||||||
refresh_token = await manager.async_create_refresh_token(user, client.id)
|
refresh_token = await manager.async_create_refresh_token(user, client)
|
||||||
|
|
||||||
manager.async_create_access_token(refresh_token)
|
manager.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
|
@ -226,13 +226,8 @@ async def test_cannot_retrieve_expired_access_token(hass):
|
||||||
"""Test that we cannot retrieve expired access tokens."""
|
"""Test that we cannot retrieve expired access tokens."""
|
||||||
manager = await auth.auth_manager_from_config(hass, [])
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
client = await manager.async_create_client('test')
|
client = await manager.async_create_client('test')
|
||||||
user = MockUser(
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
id='mock-user',
|
refresh_token = await manager.async_create_refresh_token(user, client)
|
||||||
is_owner=False,
|
|
||||||
is_active=False,
|
|
||||||
name='Paulus',
|
|
||||||
).add_to_auth_manager(manager)
|
|
||||||
refresh_token = await manager.async_create_refresh_token(user, client.id)
|
|
||||||
assert refresh_token.user.id is user.id
|
assert refresh_token.user.id is user.id
|
||||||
assert refresh_token.client_id is client.id
|
assert refresh_token.client_id is client.id
|
||||||
|
|
||||||
|
@ -260,23 +255,41 @@ async def test_get_or_create_client(hass):
|
||||||
assert client2.id is client1.id
|
assert client2.id is client1.id
|
||||||
|
|
||||||
|
|
||||||
async def test_cannot_create_refresh_token_with_invalide_client_id(hass):
|
async def test_generating_system_user(hass):
|
||||||
"""Test that we cannot create refresh token with invalid client id."""
|
"""Test that we can add a system user."""
|
||||||
manager = await auth.auth_manager_from_config(hass, [])
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
user = MockUser(
|
user = await manager.async_create_system_user('Hass.io')
|
||||||
id='mock-user',
|
token = await manager.async_create_refresh_token(user)
|
||||||
is_owner=False,
|
assert user.system_generated
|
||||||
is_active=False,
|
assert token is not None
|
||||||
name='Paulus',
|
assert token.client_id is None
|
||||||
).add_to_auth_manager(manager)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await manager.async_create_refresh_token(user, 'bla')
|
|
||||||
|
|
||||||
|
|
||||||
async def test_cannot_create_refresh_token_with_invalide_user(hass):
|
async def test_refresh_token_requires_client_for_user(hass):
|
||||||
"""Test that we cannot create refresh token with invalid client id."""
|
"""Test that we can add a system user."""
|
||||||
manager = await auth.auth_manager_from_config(hass, [])
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
client = await manager.async_create_client('test')
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
user = MockUser(id='invalid-user')
|
assert user.system_generated is False
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await manager.async_create_refresh_token(user, client.id)
|
await manager.async_create_refresh_token(user)
|
||||||
|
|
||||||
|
client = await manager.async_get_or_create_client('Test client')
|
||||||
|
token = await manager.async_create_refresh_token(user, client)
|
||||||
|
assert token is not None
|
||||||
|
assert token.client_id == client.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_not_requires_client_for_system_user(hass):
|
||||||
|
"""Test that we can add a system user."""
|
||||||
|
manager = await auth.auth_manager_from_config(hass, [])
|
||||||
|
user = await manager.async_create_system_user('Hass.io')
|
||||||
|
assert user.system_generated is True
|
||||||
|
client = await manager.async_get_or_create_client('Test client')
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await manager.async_create_refresh_token(user, client)
|
||||||
|
|
||||||
|
token = await manager.async_create_refresh_token(user)
|
||||||
|
assert token is not None
|
||||||
|
assert token.client_id is None
|
||||||
|
|
Loading…
Reference in a new issue