Add permission checks to Rest API (#18639)

* Add permission checks to Rest API

* Clean up unnecessary method

* Remove all the tuple stuff from entity check

* Simplify perms

* Correct param name for owner permission

* Hass.io make/update user to be admin

* Types
This commit is contained in:
Paulus Schoutsen 2018-11-25 18:04:48 +01:00 committed by Pascal Vizeli
parent f387cdec59
commit 8b8629a5f4
15 changed files with 282 additions and 145 deletions

View file

@ -132,13 +132,15 @@ class AuthManager:
return None
async def async_create_system_user(self, name: str) -> models.User:
async def async_create_system_user(
self, name: str,
group_ids: Optional[List[str]] = None) -> models.User:
"""Create a system user."""
user = await self._store.async_create_user(
name=name,
system_generated=True,
is_active=True,
group_ids=[],
group_ids=group_ids or [],
)
self.hass.bus.async_fire(EVENT_USER_ADDED, {
@ -217,6 +219,17 @@ class AuthManager:
'user_id': user.id
})
async def async_update_user(self, user: models.User,
name: Optional[str] = None,
group_ids: Optional[List[str]] = None) -> None:
"""Update a user."""
kwargs = {} # type: Dict[str,Any]
if name is not None:
kwargs['name'] = name
if group_ids is not None:
kwargs['group_ids'] = group_ids
await self._store.async_update_user(user, **kwargs)
async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
await self._store.async_activate_user(user)

View file

@ -133,6 +133,33 @@ class AuthStore:
self._users.pop(user.id)
self._async_schedule_save()
async def async_update_user(
self, user: models.User, name: Optional[str] = None,
is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None) -> None:
"""Update a user."""
assert self._groups is not None
if group_ids is not None:
groups = []
for grid in group_ids:
group = self._groups.get(grid)
if group is None:
raise ValueError("Invalid group specified.")
groups.append(group)
user.groups = groups
user.invalidate_permission_cache()
for attr_name, value in (
('name', name),
('is_active', is_active),
):
if value is not None:
setattr(user, attr_name, value)
self._async_schedule_save()
async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
user.is_active = True

View file

@ -8,6 +8,7 @@ import attr
from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl
from .const import GROUP_ID_ADMIN
from .util import generate_secret
TOKEN_TYPE_NORMAL = 'normal'
@ -48,7 +49,7 @@ class User:
) # type: Dict[str, RefreshToken]
_permissions = attr.ib(
type=perm_mdl.PolicyPermissions,
type=Optional[perm_mdl.PolicyPermissions],
init=False,
cmp=False,
default=None,
@ -69,6 +70,19 @@ class User:
return self._permissions
@property
def is_admin(self) -> bool:
"""Return if user is part of the admin group."""
if self.is_owner:
return True
return self.is_active and any(
gr.id == GROUP_ID_ADMIN for gr in self.groups)
def invalidate_permission_cache(self) -> None:
"""Invalidate permission cache."""
self._permissions = None
@attr.s(slots=True)
class RefreshToken:

View file

@ -5,10 +5,8 @@ from typing import ( # noqa: F401
import voluptuous as vol
from homeassistant.core import State
from .const import CAT_ENTITIES
from .types import CategoryType, PolicyType
from .types import PolicyType
from .entities import ENTITY_POLICY_SCHEMA, compile_entities
from .merge import merge_policies # noqa
@ -22,13 +20,20 @@ _LOGGER = logging.getLogger(__name__)
class AbstractPermissions:
"""Default permissions class."""
def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
_cached_entity_func = None
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
raise NotImplementedError
def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
raise NotImplementedError
def check_entity(self, entity_id: str, key: str) -> bool:
"""Check if we can access entity."""
entity_func = self._cached_entity_func
if entity_func is None:
entity_func = self._cached_entity_func = self._entity_func()
return entity_func(entity_id, key)
class PolicyPermissions(AbstractPermissions):
@ -37,34 +42,10 @@ class PolicyPermissions(AbstractPermissions):
def __init__(self, policy: PolicyType) -> None:
"""Initialize the permission class."""
self._policy = policy
self._compiled = {} # type: Dict[str, Callable[..., bool]]
def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
return func(entity_id, (key,))
def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
func = self._policy_func(CAT_ENTITIES, compile_entities)
keys = ('read',)
return [entity for entity in states if func(entity.entity_id, keys)]
def _policy_func(self, category: str,
compile_func: Callable[[CategoryType], Callable]) \
-> Callable[..., bool]:
"""Get a policy function."""
func = self._compiled.get(category)
if func:
return func
func = self._compiled[category] = compile_func(
self._policy.get(category))
_LOGGER.debug("Compiled %s func: %s", category, func)
return func
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return compile_entities(self._policy.get(CAT_ENTITIES))
def __eq__(self, other: Any) -> bool:
"""Equals check."""
@ -78,13 +59,9 @@ class _OwnerPermissions(AbstractPermissions):
# pylint: disable=no-self-use
def check_entity(self, entity_id: str, key: str) -> bool:
"""Test if we can access entity."""
return True
def filter_states(self, states: List[State]) -> List[State]:
"""Filter a list of states for what the user is allowed to see."""
return states
def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access."""
return lambda entity_id, key: True
OwnerPermissions = _OwnerPermissions() # pylint: disable=invalid-name

View file

@ -28,28 +28,28 @@ ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({
}))
def _entity_allowed(schema: ValueType, keys: Tuple[str]) \
def _entity_allowed(schema: ValueType, key: str) \
-> Union[bool, None]:
"""Test if an entity is allowed based on the keys."""
if schema is None or isinstance(schema, bool):
return schema
assert isinstance(schema, dict)
return schema.get(keys[0])
return schema.get(key)
def compile_entities(policy: CategoryType) \
-> Callable[[str, Tuple[str]], bool]:
-> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy."""
# None, Empty Dict, False
if not policy:
def apply_policy_deny_all(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
return apply_policy_deny_all
if policy is True:
def apply_policy_allow_all(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all."""
return True
@ -61,7 +61,7 @@ def compile_entities(policy: CategoryType) \
entity_ids = policy.get(ENTITY_ENTITY_IDS)
all_entities = policy.get(SUBCAT_ALL)
funcs = [] # type: List[Callable[[str, Tuple[str]], Union[None, bool]]]
funcs = [] # type: List[Callable[[str, str], Union[None, bool]]]
# The order of these functions matter. The more precise are at the top.
# If a function returns None, they cannot handle it.
@ -70,23 +70,23 @@ def compile_entities(policy: CategoryType) \
# Setting entity_ids to a boolean is final decision for permissions
# So return right away.
if isinstance(entity_ids, bool):
def allowed_entity_id_bool(entity_id: str, keys: Tuple[str]) -> bool:
def allowed_entity_id_bool(entity_id: str, key: str) -> bool:
"""Test if allowed entity_id."""
return entity_ids # type: ignore
return allowed_entity_id_bool
if entity_ids is not None:
def allowed_entity_id_dict(entity_id: str, keys: Tuple[str]) \
def allowed_entity_id_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed entity_id."""
return _entity_allowed(
entity_ids.get(entity_id), keys) # type: ignore
entity_ids.get(entity_id), key) # type: ignore
funcs.append(allowed_entity_id_dict)
if isinstance(domains, bool):
def allowed_domain_bool(entity_id: str, keys: Tuple[str]) \
def allowed_domain_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return domains
@ -94,31 +94,31 @@ def compile_entities(policy: CategoryType) \
funcs.append(allowed_domain_bool)
elif domains is not None:
def allowed_domain_dict(entity_id: str, keys: Tuple[str]) \
def allowed_domain_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
domain = entity_id.split(".", 1)[0]
return _entity_allowed(domains.get(domain), keys) # type: ignore
return _entity_allowed(domains.get(domain), key) # type: ignore
funcs.append(allowed_domain_dict)
if isinstance(all_entities, bool):
def allowed_all_entities_bool(entity_id: str, keys: Tuple[str]) \
def allowed_all_entities_bool(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return all_entities
funcs.append(allowed_all_entities_bool)
elif all_entities is not None:
def allowed_all_entities_dict(entity_id: str, keys: Tuple[str]) \
def allowed_all_entities_dict(entity_id: str, key: str) \
-> Union[None, bool]:
"""Test if allowed domain."""
return _entity_allowed(all_entities, keys)
return _entity_allowed(all_entities, key)
funcs.append(allowed_all_entities_dict)
# Can happen if no valid subcategories specified
if not funcs:
def apply_policy_deny_all_2(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_deny_all_2(entity_id: str, key: str) -> bool:
"""Decline all."""
return False
@ -128,16 +128,16 @@ def compile_entities(policy: CategoryType) \
func = funcs[0]
@wraps(func)
def apply_policy_func(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_func(entity_id: str, key: str) -> bool:
"""Apply a single policy function."""
return func(entity_id, keys) is True
return func(entity_id, key) is True
return apply_policy_func
def apply_policy_funcs(entity_id: str, keys: Tuple[str]) -> bool:
def apply_policy_funcs(entity_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(entity_id, keys)
result = func(entity_id, key)
if result is not None:
return result
return False

View file

@ -20,7 +20,8 @@ from homeassistant.const import (
URL_API_SERVICES, URL_API_STATES, URL_API_STATES_ENTITY, URL_API_STREAM,
URL_API_TEMPLATE, __version__)
import homeassistant.core as ha
from homeassistant.exceptions import TemplateError
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.exceptions import TemplateError, Unauthorized
from homeassistant.helpers import template
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.helpers.state import AsyncTrackStates
@ -81,6 +82,8 @@ class APIEventStream(HomeAssistantView):
async def get(self, request):
"""Provide a streaming interface for the event bus."""
if not request['hass_user'].is_admin:
raise Unauthorized()
hass = request.app['hass']
stop_obj = object()
to_write = asyncio.Queue(loop=hass.loop)
@ -185,7 +188,13 @@ class APIStatesView(HomeAssistantView):
@ha.callback
def get(self, request):
"""Get current states."""
return self.json(request.app['hass'].states.async_all())
user = request['hass_user']
entity_perm = user.permissions.check_entity
states = [
state for state in request.app['hass'].states.async_all()
if entity_perm(state.entity_id, 'read')
]
return self.json(states)
class APIEntityStateView(HomeAssistantView):
@ -197,6 +206,10 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback
def get(self, request, entity_id):
"""Retrieve state of entity."""
user = request['hass_user']
if not user.permissions.check_entity(entity_id, POLICY_READ):
raise Unauthorized(entity_id=entity_id)
state = request.app['hass'].states.get(entity_id)
if state:
return self.json(state)
@ -204,6 +217,8 @@ class APIEntityStateView(HomeAssistantView):
async def post(self, request, entity_id):
"""Update state of entity."""
if not request['hass_user'].is_admin:
raise Unauthorized(entity_id=entity_id)
hass = request.app['hass']
try:
data = await request.json()
@ -236,6 +251,8 @@ class APIEntityStateView(HomeAssistantView):
@ha.callback
def delete(self, request, entity_id):
"""Remove entity."""
if not request['hass_user'].is_admin:
raise Unauthorized(entity_id=entity_id)
if request.app['hass'].states.async_remove(entity_id):
return self.json_message("Entity removed.")
return self.json_message("Entity not found.", HTTP_NOT_FOUND)
@ -261,6 +278,8 @@ class APIEventView(HomeAssistantView):
async def post(self, request, event_type):
"""Fire events."""
if not request['hass_user'].is_admin:
raise Unauthorized()
body = await request.text()
try:
event_data = json.loads(body) if body else None
@ -346,6 +365,8 @@ class APITemplateView(HomeAssistantView):
async def post(self, request):
"""Render a template."""
if not request['hass_user'].is_admin:
raise Unauthorized()
try:
data = await request.json()
tpl = template.Template(data['template'], request.app['hass'])
@ -363,6 +384,8 @@ class APIErrorLog(HomeAssistantView):
async def get(self, request):
"""Retrieve API error log."""
if not request['hass_user'].is_admin:
raise Unauthorized()
return web.FileResponse(request.app['hass'].data[DATA_LOGGING])

View file

@ -10,6 +10,7 @@ import os
import voluptuous as vol
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.components import SERVICE_CHECK_CONFIG
from homeassistant.const import (
ATTR_NAME, SERVICE_HOMEASSISTANT_RESTART, SERVICE_HOMEASSISTANT_STOP)
@ -181,8 +182,14 @@ async def async_setup(hass, config):
if user and user.refresh_tokens:
refresh_token = list(user.refresh_tokens.values())[0]
# Migrate old hass.io users to be admin.
if not user.is_admin:
await hass.auth.async_update_user(
user, group_ids=[GROUP_ID_ADMIN])
if refresh_token is None:
user = await hass.auth.async_create_system_user('Hass.io')
user = await hass.auth.async_create_system_user(
'Hass.io', [GROUP_ID_ADMIN])
refresh_token = await hass.auth.async_create_refresh_token(user)
data['hassio_user'] = user.id
await store.async_save(data)

View file

@ -14,6 +14,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
from homeassistant.components.http.ban import process_success_login
from homeassistant.core import Context, is_callback
from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant import exceptions
from homeassistant.helpers.json import JSONEncoder
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
@ -107,10 +108,13 @@ def request_handler_factory(view, handler):
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info)
try:
result = handler(request, **request.match_info)
if asyncio.iscoroutine(result):
result = await result
if asyncio.iscoroutine(result):
result = await result
except exceptions.Unauthorized:
raise HTTPUnauthorized()
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it

View file

@ -192,9 +192,9 @@ async def entity_service_call(hass, platforms, func, call):
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
perms = user.permissions
entity_perms = user.permissions.check_entity
else:
perms = None
entity_perms = None
# Are we trying to target all entities
target_all_entities = ATTR_ENTITY_ID not in call.data
@ -218,7 +218,7 @@ async def entity_service_call(hass, platforms, func, call):
# the service on.
platforms_entities = []
if perms is None:
if entity_perms is None:
for platform in platforms:
if target_all_entities:
platforms_entities.append(list(platform.entities.values()))
@ -234,7 +234,7 @@ async def entity_service_call(hass, platforms, func, call):
for platform in platforms:
platforms_entities.append([
entity for entity in platform.entities.values()
if perms.check_entity(entity.entity_id, POLICY_CONTROL)])
if entity_perms(entity.entity_id, POLICY_CONTROL)])
else:
for platform in platforms:
@ -243,7 +243,7 @@ async def entity_service_call(hass, platforms, func, call):
if entity.entity_id not in entity_ids:
continue
if not perms.check_entity(entity.entity_id, POLICY_CONTROL):
if not entity_perms(entity.entity_id, POLICY_CONTROL):
raise Unauthorized(
context=call.context,
entity_id=entity.entity_id,

View file

@ -10,7 +10,7 @@ def test_entities_none():
"""Test entity ID policy."""
policy = None
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is False
assert compiled('light.kitchen', 'read') is False
def test_entities_empty():
@ -18,7 +18,7 @@ def test_entities_empty():
policy = {}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is False
assert compiled('light.kitchen', 'read') is False
def test_entities_false():
@ -33,7 +33,7 @@ def test_entities_true():
policy = True
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', 'read') is True
def test_entities_domains_true():
@ -43,7 +43,7 @@ def test_entities_domains_true():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', 'read') is True
def test_entities_domains_domain_true():
@ -55,8 +55,8 @@ def test_entities_domains_domain_true():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('switch.kitchen', ('read',)) is False
assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', 'read') is False
def test_entities_domains_domain_false():
@ -77,7 +77,7 @@ def test_entities_entity_ids_true():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', 'read') is True
def test_entities_entity_ids_false():
@ -98,8 +98,8 @@ def test_entities_entity_ids_entity_id_true():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('switch.kitchen', ('read',)) is False
assert compiled('light.kitchen', 'read') is True
assert compiled('switch.kitchen', 'read') is False
def test_entities_entity_ids_entity_id_false():
@ -124,9 +124,9 @@ def test_entities_control_only():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', ('control',)) is False
assert compiled('light.kitchen', ('edit',)) is False
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is False
assert compiled('light.kitchen', 'edit') is False
def test_entities_read_control():
@ -141,9 +141,9 @@ def test_entities_read_control():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', ('control',)) is True
assert compiled('light.kitchen', ('edit',)) is False
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True
assert compiled('light.kitchen', 'edit') is False
def test_entities_all_allow():
@ -153,9 +153,9 @@ def test_entities_all_allow():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', ('control',)) is True
assert compiled('switch.kitchen', ('read',)) is True
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', 'read') is True
def test_entities_all_read():
@ -167,9 +167,9 @@ def test_entities_all_read():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is True
assert compiled('light.kitchen', ('control',)) is False
assert compiled('switch.kitchen', ('read',)) is True
assert compiled('light.kitchen', 'read') is True
assert compiled('light.kitchen', 'control') is False
assert compiled('switch.kitchen', 'read') is True
def test_entities_all_control():
@ -181,7 +181,7 @@ def test_entities_all_control():
}
ENTITY_POLICY_SCHEMA(policy)
compiled = compile_entities(policy)
assert compiled('light.kitchen', ('read',)) is False
assert compiled('light.kitchen', ('control',)) is True
assert compiled('switch.kitchen', ('read',)) is False
assert compiled('switch.kitchen', ('control',)) is True
assert compiled('light.kitchen', 'read') is False
assert compiled('light.kitchen', 'control') is True
assert compiled('switch.kitchen', 'read') is False
assert compiled('switch.kitchen', 'control') is True

View file

@ -1,34 +0,0 @@
"""Tests for the auth permission system."""
from homeassistant.core import State
from homeassistant.auth import permissions
def test_policy_perm_filter_states():
"""Test filtering entitites."""
states = [
State('light.kitchen', 'on'),
State('light.living_room', 'off'),
State('light.balcony', 'on'),
]
perm = permissions.PolicyPermissions({
'entities': {
'entity_ids': {
'light.kitchen': True,
'light.balcony': True,
}
}
})
filtered = perm.filter_states(states)
assert len(filtered) == 2
assert filtered == [states[0], states[2]]
def test_owner_permissions():
"""Test owner permissions access all."""
assert permissions.OwnerPermissions.check_entity('light.kitchen', 'write')
states = [
State('light.kitchen', 'on'),
State('light.living_room', 'off'),
State('light.balcony', 'on'),
]
assert permissions.OwnerPermissions.filter_states(states) == states

View file

@ -14,7 +14,8 @@ from contextlib import contextmanager
from homeassistant import auth, core as ha, config_entries
from homeassistant.auth import (
models as auth_models, auth_store, providers as auth_providers)
models as auth_models, auth_store, providers as auth_providers,
permissions as auth_permissions)
from homeassistant.auth.permissions import system_policies
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
@ -400,6 +401,10 @@ class MockUser(auth_models.User):
auth_mgr._store._users[self.id] = self
return self
def mock_policy(self, policy):
"""Mock a policy for a user."""
self._permissions = auth_permissions.PolicyPermissions(policy)
async def register_auth_provider(hass, config):
"""Register an auth provider."""

View file

@ -80,11 +80,10 @@ def hass_ws_client(aiohttp_client):
@pytest.fixture
def hass_access_token(hass):
def hass_access_token(hass, hass_admin_user):
"""Return an access token to access Home Assistant."""
user = MockUser().add_to_hass(hass)
refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(user, CLIENT_ID))
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID))
yield hass.auth.async_create_access_token(refresh_token)

View file

@ -5,6 +5,7 @@ from unittest.mock import patch, Mock
import pytest
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.setup import async_setup_component
from homeassistant.components.hassio import (
STORAGE_KEY, async_check_config)
@ -106,6 +107,8 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
)
assert hassio_user is not None
assert hassio_user.system_generated
assert len(hassio_user.groups) == 1
assert hassio_user.groups[0].id == GROUP_ID_ADMIN
for token in hassio_user.refresh_tokens.values():
if token.token == refresh_token:
break
@ -113,6 +116,31 @@ async def test_setup_api_push_api_data_default(hass, aioclient_mock,
assert False, 'refresh token not found'
async def test_setup_adds_admin_group_to_user(hass, aioclient_mock,
hass_storage):
"""Test setup with API push default data."""
# Create user without admin
user = await hass.auth.async_create_system_user('Hass.io')
assert not user.is_admin
await hass.auth.async_create_refresh_token(user)
hass_storage[STORAGE_KEY] = {
'data': {'hassio_user': user.id},
'key': STORAGE_KEY,
'version': 1
}
with patch.dict(os.environ, MOCK_ENVIRON), \
patch('homeassistant.auth.AuthManager.active', return_value=True):
result = await async_setup_component(hass, 'hassio', {
'http': {},
'hassio': {}
})
assert result
assert user.is_admin
async def test_setup_api_push_api_data_no_auth(hass, aioclient_mock,
hass_storage):
"""Test setup with API push default data."""

View file

@ -16,10 +16,12 @@ from tests.common import async_mock_service
@pytest.fixture
def mock_api_client(hass, aiohttp_client):
"""Start the Hass HTTP component."""
def mock_api_client(hass, aiohttp_client, hass_access_token):
"""Start the Hass HTTP component and return admin API client."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
return hass.loop.run_until_complete(aiohttp_client(hass.http.app))
return hass.loop.run_until_complete(aiohttp_client(hass.http.app, headers={
'Authorization': 'Bearer {}'.format(hass_access_token)
}))
@asyncio.coroutine
@ -405,7 +407,8 @@ def _listen_count(hass):
return sum(hass.bus.async_listeners().values())
async def test_api_error_log(hass, aiohttp_client):
async def test_api_error_log(hass, aiohttp_client, hass_access_token,
hass_admin_user):
"""Test if we can fetch the error log."""
hass.data[DATA_LOGGING] = '/some/path'
await async_setup_component(hass, 'api', {
@ -416,7 +419,7 @@ async def test_api_error_log(hass, aiohttp_client):
client = await aiohttp_client(hass.http.app)
resp = await client.get(const.URL_API_ERROR_LOG)
# Verufy auth required
# Verify auth required
assert resp.status == 401
with patch(
@ -424,7 +427,7 @@ async def test_api_error_log(hass, aiohttp_client):
return_value=web.Response(status=200, text='Hello')
) as mock_file:
resp = await client.get(const.URL_API_ERROR_LOG, headers={
'x-ha-access': 'yolo'
'Authorization': 'Bearer {}'.format(hass_access_token)
})
assert len(mock_file.mock_calls) == 1
@ -432,6 +435,13 @@ async def test_api_error_log(hass, aiohttp_client):
assert resp.status == 200
assert await resp.text() == 'Hello'
# Verify we require admin user
hass_admin_user.groups = []
resp = await client.get(const.URL_API_ERROR_LOG, headers={
'Authorization': 'Bearer {}'.format(hass_access_token)
})
assert resp.status == 401
async def test_api_fire_event_context(hass, mock_api_client,
hass_access_token):
@ -494,3 +504,67 @@ async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
state = hass.states.get('light.kitchen')
assert state.context.user_id == refresh_token.user.id
async def test_event_stream_requires_admin(hass, mock_api_client,
hass_admin_user):
"""Test user needs to be admin to access event stream."""
hass_admin_user.groups = []
resp = await mock_api_client.get('/api/stream')
assert resp.status == 401
async def test_states_view_filters(hass, mock_api_client, hass_admin_user):
"""Test filtering only visible states."""
hass_admin_user.mock_policy({
'entities': {
'entity_ids': {
'test.entity': True
}
}
})
hass.states.async_set('test.entity', 'hello')
hass.states.async_set('test.not_visible_entity', 'invisible')
resp = await mock_api_client.get(const.URL_API_STATES)
assert resp.status == 200
json = await resp.json()
assert len(json) == 1
assert json[0]['entity_id'] == 'test.entity'
async def test_get_entity_state_read_perm(hass, mock_api_client,
hass_admin_user):
"""Test getting a state requires read permission."""
hass_admin_user.mock_policy({})
resp = await mock_api_client.get('/api/states/light.test')
assert resp.status == 401
async def test_post_entity_state_admin(hass, mock_api_client, hass_admin_user):
"""Test updating state requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/states/light.test')
assert resp.status == 401
async def test_delete_entity_state_admin(hass, mock_api_client,
hass_admin_user):
"""Test deleting entity requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.delete('/api/states/light.test')
assert resp.status == 401
async def test_post_event_admin(hass, mock_api_client, hass_admin_user):
"""Test sending event requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/events/state_changed')
assert resp.status == 401
async def test_rendering_template_admin(hass, mock_api_client,
hass_admin_user):
"""Test rendering a template requires admin."""
hass_admin_user.groups = []
resp = await mock_api_client.post('/api/template')
assert resp.status == 401