Cleanup http (#12424)

* Clean up HTTP component

* Clean up HTTP mock

* Remove unused import

* Fix test

* Lint
This commit is contained in:
Paulus Schoutsen 2018-02-15 13:06:14 -08:00 committed by Pascal Vizeli
parent ad8fe8a93a
commit f32911d036
28 changed files with 811 additions and 1014 deletions

View file

@ -14,7 +14,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.components.http import REQUIREMENTS # NOQA
from homeassistant.components.http import HomeAssistantWSGI
from homeassistant.components.http import HomeAssistantHTTP
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.deprecation import get_deprecated
import homeassistant.helpers.config_validation as cv
@ -86,7 +86,7 @@ def setup(hass, yaml_config):
"""Activate the emulated_hue component."""
config = Config(hass, yaml_config.get(DOMAIN, {}))
server = HomeAssistantWSGI(
server = HomeAssistantHTTP(
hass,
server_host=config.host_ip_addr,
server_port=config.listen_port,

View file

@ -17,7 +17,7 @@ import jinja2
import homeassistant.helpers.config_validation as cv
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import is_trusted_ip
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.config import find_config_file, load_yaml_config_file
from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED
from homeassistant.core import callback
@ -490,7 +490,7 @@ class IndexView(HomeAssistantView):
panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5
no_auth = '1'
if hass.config.api.api_password and not is_trusted_ip(request):
if hass.config.api.api_password and not request[KEY_AUTHENTICATED]:
# do not try to auto connect on load
no_auth = '0'

View file

@ -12,35 +12,28 @@ import os
import ssl
from aiohttp import web
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
import voluptuous as vol
from homeassistant.const import (
SERVER_PORT, CONTENT_TYPE_JSON, HTTP_HEADER_HA_AUTH,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,
HTTP_HEADER_X_REQUESTED_WITH)
SERVER_PORT, CONTENT_TYPE_JSON,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
from homeassistant.core import is_callback
import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem
import homeassistant.util as hass_util
from homeassistant.util.logging import HideSensitiveDataFilter
from .auth import auth_middleware
from .ban import ban_middleware
from .const import (
KEY_BANS_ENABLED, KEY_AUTHENTICATED, KEY_LOGIN_THRESHOLD,
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR)
from .auth import setup_auth
from .ban import setup_bans
from .cors import setup_cors
from .real_ip import setup_real_ip
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
from .static import (
CachingFileResponse, CachingStaticResource, staticresource_middleware)
from .util import get_real_ip
REQUIREMENTS = ['aiohttp_cors==0.6.0']
ALLOWED_CORS_HEADERS = [
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
HTTP_HEADER_HA_AUTH]
DOMAIN = 'http'
CONF_API_PASSWORD = 'api_password'
@ -127,7 +120,7 @@ def async_setup(hass, config):
logging.getLogger('aiohttp.access').addFilter(
HideSensitiveDataFilter(api_password))
server = HomeAssistantWSGI(
server = HomeAssistantHTTP(
hass,
server_host=server_host,
server_port=server_port,
@ -173,25 +166,29 @@ def async_setup(hass, config):
return True
class HomeAssistantWSGI(object):
"""WSGI server for Home Assistant."""
class HomeAssistantHTTP(object):
"""HTTP server for Home Assistant."""
def __init__(self, hass, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks,
login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server."""
middlewares = [auth_middleware, staticresource_middleware]
"""Initialize the HTTP Home Assistant server."""
app = self.app = web.Application(
middlewares=[staticresource_middleware])
# This order matters
setup_real_ip(app, use_x_forwarded_for)
if is_ban_enabled:
middlewares.insert(0, ban_middleware)
setup_bans(hass, app, login_threshold)
self.app = web.Application(middlewares=middlewares)
self.app['hass'] = hass
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks
self.app[KEY_BANS_ENABLED] = is_ban_enabled
self.app[KEY_LOGIN_THRESHOLD] = login_threshold
setup_auth(app, trusted_networks, api_password)
if cors_origins:
setup_cors(app, cors_origins)
app['hass'] = hass
self.hass = hass
self.api_password = api_password
@ -199,21 +196,10 @@ class HomeAssistantWSGI(object):
self.ssl_key = ssl_key
self.server_host = server_host
self.server_port = server_port
self.is_ban_enabled = is_ban_enabled
self._handler = None
self.server = None
if cors_origins:
import aiohttp_cors
self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in cors_origins
})
else:
self.cors = None
def register_view(self, view):
"""Register a view with the WSGI server.
@ -292,15 +278,7 @@ class HomeAssistantWSGI(object):
@asyncio.coroutine
def start(self):
"""Start the WSGI server."""
cors_added = set()
if self.cors is not None:
for route in list(self.app.router.routes()):
if hasattr(route, 'resource'):
route = route.resource
if route in cors_added:
continue
self.cors.add(route)
cors_added.add(route)
yield from self.app.startup()
if self.ssl_certificate:
try:
@ -420,7 +398,7 @@ def request_handler_factory(view, handler):
raise HTTPUnauthorized()
_LOGGER.info('Serving %s to %s (auth: %s)',
request.path, get_real_ip(request), authenticated)
request.path, request.get(KEY_REAL_IP), authenticated)
result = handler(request, **request.match_info)

View file

@ -7,55 +7,66 @@ import logging
from aiohttp import hdrs
from aiohttp.web import middleware
from homeassistant.core import callback
from homeassistant.const import HTTP_HEADER_HA_AUTH
from .util import get_real_ip
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
DATA_API_PASSWORD = 'api_password'
_LOGGER = logging.getLogger(__name__)
@middleware
@asyncio.coroutine
def auth_middleware(request, handler):
"""Authenticate as middleware."""
# If no password set, just always set authenticated=True
if request.app['hass'].http.api_password is None:
request[KEY_AUTHENTICATED] = True
@callback
def setup_auth(app, trusted_networks, api_password):
"""Create auth middleware for the app."""
@middleware
@asyncio.coroutine
def auth_middleware(request, handler):
"""Authenticate as middleware."""
# If no password set, just always set authenticated=True
if api_password is None:
request[KEY_AUTHENTICATED] = True
return (yield from handler(request))
# Check authentication
authenticated = False
if (HTTP_HEADER_HA_AUTH in request.headers and
hmac.compare_digest(
api_password, request.headers[HTTP_HEADER_HA_AUTH])):
# A valid auth header has been set
authenticated = True
elif (DATA_API_PASSWORD in request.query and
hmac.compare_digest(api_password,
request.query[DATA_API_PASSWORD])):
authenticated = True
elif (hdrs.AUTHORIZATION in request.headers and
validate_authorization_header(api_password, request)):
authenticated = True
elif _is_trusted_ip(request, trusted_networks):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return (yield from handler(request))
# Check authentication
authenticated = False
@asyncio.coroutine
def auth_startup(app):
"""Initialize auth middleware when app starts up."""
app.middlewares.append(auth_middleware)
if (HTTP_HEADER_HA_AUTH in request.headers and
validate_password(
request, request.headers[HTTP_HEADER_HA_AUTH])):
# A valid auth header has been set
authenticated = True
elif (DATA_API_PASSWORD in request.query and
validate_password(request, request.query[DATA_API_PASSWORD])):
authenticated = True
elif (hdrs.AUTHORIZATION in request.headers and
validate_authorization_header(request)):
authenticated = True
elif is_trusted_ip(request):
authenticated = True
request[KEY_AUTHENTICATED] = authenticated
return (yield from handler(request))
app.on_startup.append(auth_startup)
def is_trusted_ip(request):
def _is_trusted_ip(request, trusted_networks):
"""Test if request is from a trusted ip."""
ip_addr = get_real_ip(request)
ip_addr = request[KEY_REAL_IP]
return ip_addr and any(
return any(
ip_addr in trusted_network for trusted_network
in request.app[KEY_TRUSTED_NETWORKS])
in trusted_networks)
def validate_password(request, api_password):
@ -64,7 +75,7 @@ def validate_password(request, api_password):
api_password, request.app['hass'].http.api_password)
def validate_authorization_header(request):
def validate_authorization_header(api_password, request):
"""Test an authorization header if valid password."""
if hdrs.AUTHORIZATION not in request.headers:
return False
@ -80,4 +91,4 @@ def validate_authorization_header(request):
if username != 'homeassistant':
return False
return validate_password(request, password)
return hmac.compare_digest(api_password, password)

View file

@ -10,18 +10,20 @@ from aiohttp.web import middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.util.yaml import dump
from .const import (
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
KEY_FAILED_LOGIN_ATTEMPTS)
from .util import get_real_ip
from .const import KEY_REAL_IP
_LOGGER = logging.getLogger(__name__)
KEY_BANNED_IPS = 'ha_banned_ips'
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
NOTIFICATION_ID_BAN = 'ip-ban'
NOTIFICATION_ID_LOGIN = 'http-login'
@ -33,21 +35,31 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
})
@callback
def setup_bans(hass, app, login_threshold):
"""Create IP Ban middleware for the app."""
@asyncio.coroutine
def ban_startup(app):
"""Initialize bans when app starts up."""
app.middlewares.append(ban_middleware)
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
app[KEY_LOGIN_THRESHOLD] = login_threshold
app.on_startup.append(ban_startup)
@middleware
@asyncio.coroutine
def ban_middleware(request, handler):
"""IP Ban middleware."""
if not request.app[KEY_BANS_ENABLED]:
if KEY_BANNED_IPS not in request.app:
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
return (yield from handler(request))
if KEY_BANNED_IPS not in request.app:
hass = request.app['hass']
request.app[KEY_BANNED_IPS] = yield from hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
# Verify if IP is not banned
ip_address_ = get_real_ip(request)
ip_address_ = request[KEY_REAL_IP]
is_banned = any(ip_ban.ip_address == ip_address_
for ip_ban in request.app[KEY_BANNED_IPS])
@ -64,7 +76,7 @@ def ban_middleware(request, handler):
@asyncio.coroutine
def process_wrong_login(request):
"""Process a wrong login attempt."""
remote_addr = get_real_ip(request)
remote_addr = request[KEY_REAL_IP]
msg = ('Login attempt or request with invalid authentication '
'from {}'.format(remote_addr))
@ -73,13 +85,11 @@ def process_wrong_login(request):
request.app['hass'], msg, 'Login attempt failed',
NOTIFICATION_ID_LOGIN)
if (not request.app[KEY_BANS_ENABLED] or
# Check if ban middleware is loaded
if (KEY_BANNED_IPS not in request.app or
request.app[KEY_LOGIN_THRESHOLD] < 1):
return
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >

View file

@ -1,11 +1,3 @@
"""HTTP specific constants."""
KEY_AUTHENTICATED = 'ha_authenticated'
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
KEY_REAL_IP = 'ha_real_ip'
KEY_BANS_ENABLED = 'ha_bans_enabled'
KEY_BANNED_IPS = 'ha_banned_ips'
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'

View file

@ -0,0 +1,43 @@
"""Provide cors support for the HTTP component."""
import asyncio
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
from homeassistant.const import (
HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_HA_AUTH)
from homeassistant.core import callback
ALLOWED_CORS_HEADERS = [
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
HTTP_HEADER_HA_AUTH]
@callback
def setup_cors(app, origins):
"""Setup cors."""
import aiohttp_cors
cors = aiohttp_cors.setup(app, defaults={
host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods='*',
) for host in origins
})
@asyncio.coroutine
def cors_startup(app):
"""Initialize cors when app starts up."""
cors_added = set()
for route in list(app.router.routes()):
if hasattr(route, 'resource'):
route = route.resource
if route in cors_added:
continue
cors.add(route)
cors_added.add(route)
app.on_startup.append(cors_startup)

View file

@ -0,0 +1,35 @@
"""Middleware to fetch real IP."""
import asyncio
from ipaddress import ip_address
from aiohttp.web import middleware
from aiohttp.hdrs import X_FORWARDED_FOR
from homeassistant.core import callback
from .const import KEY_REAL_IP
@callback
def setup_real_ip(app, use_x_forwarded_for):
"""Create IP Ban middleware for the app."""
@middleware
@asyncio.coroutine
def real_ip_middleware(request, handler):
"""Real IP middleware."""
if (use_x_forwarded_for and
X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(X_FORWARDED_FOR).split(',')[0])
else:
request[KEY_REAL_IP] = \
ip_address(request.transport.get_extra_info('peername')[0])
return (yield from handler(request))
@asyncio.coroutine
def app_startup(app):
"""Initialize bans when app starts up."""
app.middlewares.append(real_ip_middleware)
app.on_startup.append(app_startup)

View file

@ -1,25 +0,0 @@
"""HTTP utilities."""
from ipaddress import ip_address
from .const import (
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
def get_real_ip(request):
"""Get IP address of client."""
if KEY_REAL_IP in request:
return request[KEY_REAL_IP]
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
request[KEY_REAL_IP] = ip_address(
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
else:
peername = request.transport.get_extra_info('peername')
if peername:
request[KEY_REAL_IP] = ip_address(peername[0])
else:
request[KEY_REAL_IP] = None
return request[KEY_REAL_IP]

View file

@ -12,7 +12,7 @@ import logging
import voluptuous as vol
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.util import get_real_ip
from homeassistant.components.http.const import KEY_REAL_IP
from homeassistant.components.telegram_bot import (
CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA)
from homeassistant.const import (
@ -110,7 +110,7 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
@asyncio.coroutine
def post(self, request):
"""Accept the POST from telegram."""
real_ip = get_real_ip(request)
real_ip = request[KEY_REAL_IP]
if not any(real_ip in net for net in self.trusted_networks):
_LOGGER.warning("Access denied from %s", real_ip)
return self.json_message('Access denied', HTTP_UNAUTHORIZED)

View file

@ -9,8 +9,6 @@ import logging
import threading
from contextlib import contextmanager
from aiohttp import web
from homeassistant import core as ha, loader
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.config import async_process_component_config
@ -25,9 +23,6 @@ from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE)
from homeassistant.components import mqtt, recorder
from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import (
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
from homeassistant.util.async import (
run_callback_threadsafe, run_coroutine_threadsafe)
@ -262,35 +257,6 @@ def mock_state_change_event(hass, new_state, old_state=None):
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
def mock_http_component(hass, api_password=None):
"""Mock the HTTP component."""
hass.http = MagicMock(api_password=api_password)
mock_component(hass, 'http')
hass.http.views = {}
def mock_register_view(view):
"""Store registered view."""
if isinstance(view, type):
# Instantiate the view, if needed
view = view()
hass.http.views[view.name] = view
hass.http.register_view = mock_register_view
def mock_http_component_app(hass, api_password=None):
"""Create an aiohttp.web.Application instance for testing."""
if 'http' not in hass.config.components:
mock_http_component(hass, api_password)
app = web.Application(middlewares=[auth_middleware])
app['hass'] = hass
app[KEY_USE_X_FORWARDED_FOR] = False
app[KEY_BANS_ENABLED] = False
app[KEY_TRUSTED_NETWORKS] = []
return app
@asyncio.coroutine
def async_mock_mqtt_component(hass, config=None):
"""Mock the MQTT component."""

View file

@ -9,7 +9,7 @@ from uvcclient import nvr
from homeassistant.setup import setup_component
from homeassistant.components.camera import uvc
from tests.common import get_test_home_assistant, mock_http_component
from tests.common import get_test_home_assistant
class TestUVCSetup(unittest.TestCase):
@ -18,7 +18,6 @@ class TestUVCSetup(unittest.TestCase):
def setUp(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_http_component(self.hass)
def tearDown(self):
"""Stop everything that was started."""

View file

@ -14,7 +14,7 @@ def test_setup_check_env_prevents_load(hass, loop):
with patch.dict(os.environ, clear=True), \
patch.object(config, 'SECTIONS', ['hassbian']), \
patch('homeassistant.components.http.'
'HomeAssistantWSGI.register_view') as reg_view:
'HomeAssistantHTTP.register_view') as reg_view:
loop.run_until_complete(async_setup_component(hass, 'config', {}))
assert 'config' in hass.config.components
assert reg_view.called is False
@ -25,7 +25,7 @@ def test_setup_check_env_works(hass, loop):
with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \
patch.object(config, 'SECTIONS', ['hassbian']), \
patch('homeassistant.components.http.'
'HomeAssistantWSGI.register_view') as reg_view:
'HomeAssistantHTTP.register_view') as reg_view:
loop.run_until_complete(async_setup_component(hass, 'config', {}))
assert 'config' in hass.config.components
assert len(reg_view.mock_calls) == 2

View file

@ -2,19 +2,11 @@
import asyncio
from unittest.mock import patch
import pytest
from homeassistant.const import EVENT_COMPONENT_LOADED
from homeassistant.setup import async_setup_component, ATTR_COMPONENT
from homeassistant.components import config
from tests.common import mock_http_component, mock_coro, mock_component
@pytest.fixture(autouse=True)
def stub_http(hass):
"""Stub the HTTP component."""
mock_http_component(hass)
from tests.common import mock_coro, mock_component
@asyncio.coroutine

View file

@ -3,28 +3,30 @@ import asyncio
import json
from unittest.mock import MagicMock, patch
import pytest
from homeassistant.bootstrap import async_setup_component
from homeassistant.components import config
from homeassistant.components.zwave import DATA_NETWORK, const
from homeassistant.components.config.zwave import (
ZWaveNodeValueView, ZWaveNodeGroupView, ZWaveNodeConfigView,
ZWaveUserCodeView, ZWaveConfigWriteView)
from tests.common import mock_http_component_app
from tests.mock.zwave import MockNode, MockValue, MockEntityValues
VIEW_NAME = 'api:config:zwave:device_config'
@asyncio.coroutine
def test_get_device_config(hass, test_client):
"""Test getting device config."""
@pytest.fixture
def client(loop, hass, test_client):
"""Client to communicate with Z-Wave config views."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
loop.run_until_complete(async_setup_component(hass, 'config', {}))
client = yield from test_client(hass.http.app)
return loop.run_until_complete(test_client(hass.http.app))
@asyncio.coroutine
def test_get_device_config(client):
"""Test getting device config."""
def mock_read(path):
"""Mock reading data."""
return {
@ -47,13 +49,8 @@ def test_get_device_config(hass, test_client):
@asyncio.coroutine
def test_update_device_config(hass, test_client):
def test_update_device_config(client):
"""Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
orig_data = {
'hello.beer': {
'ignored': True,
@ -90,13 +87,8 @@ def test_update_device_config(hass, test_client):
@asyncio.coroutine
def test_update_device_config_invalid_key(hass, test_client):
def test_update_device_config_invalid_key(client):
"""Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
resp = yield from client.post(
'/api/config/zwave/device_config/invalid_entity', data=json.dumps({
'polling_intensity': 2
@ -106,13 +98,8 @@ def test_update_device_config_invalid_key(hass, test_client):
@asyncio.coroutine
def test_update_device_config_invalid_data(hass, test_client):
def test_update_device_config_invalid_data(client):
"""Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
resp = yield from client.post(
'/api/config/zwave/device_config/hello.beer', data=json.dumps({
'invalid_option': 2
@ -122,13 +109,8 @@ def test_update_device_config_invalid_data(hass, test_client):
@asyncio.coroutine
def test_update_device_config_invalid_json(hass, test_client):
def test_update_device_config_invalid_json(client):
"""Test updating device config."""
with patch.object(config, 'SECTIONS', ['zwave']):
yield from async_setup_component(hass, 'config', {})
client = yield from test_client(hass.http.app)
resp = yield from client.post(
'/api/config/zwave/device_config/hello.beer', data='not json')
@ -136,11 +118,8 @@ def test_update_device_config_invalid_json(hass, test_client):
@asyncio.coroutine
def test_get_values(hass, test_client):
def test_get_values(hass, client):
"""Test getting values on node."""
app = mock_http_component_app(hass)
ZWaveNodeValueView().register(app.router)
node = MockNode(node_id=1)
value = MockValue(value_id=123456, node=node, label='Test Label',
instance=1, index=2, poll_intensity=4)
@ -150,8 +129,6 @@ def test_get_values(hass, test_client):
values2 = MockEntityValues(primary=value2)
hass.data[const.DATA_ENTITY_VALUES] = [values, values2]
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/values/1')
assert resp.status == 200
@ -168,11 +145,8 @@ def test_get_values(hass, test_client):
@asyncio.coroutine
def test_get_groups(hass, test_client):
def test_get_groups(hass, client):
"""Test getting groupdata on node."""
app = mock_http_component_app(hass)
ZWaveNodeGroupView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2)
node.groups.associations = 'assoc'
@ -182,8 +156,6 @@ def test_get_groups(hass, test_client):
node.groups = {1: node.groups}
network.nodes = {2: node}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/groups/2')
assert resp.status == 200
@ -200,18 +172,13 @@ def test_get_groups(hass, test_client):
@asyncio.coroutine
def test_get_groups_nogroups(hass, test_client):
def test_get_groups_nogroups(hass, client):
"""Test getting groupdata on node with no groups."""
app = mock_http_component_app(hass)
ZWaveNodeGroupView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2)
network.nodes = {2: node}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/groups/2')
assert resp.status == 200
@ -221,16 +188,11 @@ def test_get_groups_nogroups(hass, test_client):
@asyncio.coroutine
def test_get_groups_nonode(hass, test_client):
def test_get_groups_nonode(hass, client):
"""Test getting groupdata on nonexisting node."""
app = mock_http_component_app(hass)
ZWaveNodeGroupView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
network.nodes = {1: 1, 5: 5}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/groups/2')
assert resp.status == 404
@ -240,11 +202,8 @@ def test_get_groups_nonode(hass, test_client):
@asyncio.coroutine
def test_get_config(hass, test_client):
def test_get_config(hass, client):
"""Test getting config on node."""
app = mock_http_component_app(hass)
ZWaveNodeConfigView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2)
value = MockValue(
@ -261,8 +220,6 @@ def test_get_config(hass, test_client):
network.nodes = {2: node}
node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/config/2')
assert resp.status == 200
@ -278,19 +235,14 @@ def test_get_config(hass, test_client):
@asyncio.coroutine
def test_get_config_noconfig_node(hass, test_client):
def test_get_config_noconfig_node(hass, client):
"""Test getting config on node without config."""
app = mock_http_component_app(hass)
ZWaveNodeConfigView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=2)
network.nodes = {2: node}
node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/config/2')
assert resp.status == 200
@ -300,16 +252,11 @@ def test_get_config_noconfig_node(hass, test_client):
@asyncio.coroutine
def test_get_config_nonode(hass, test_client):
def test_get_config_nonode(hass, client):
"""Test getting config on nonexisting node."""
app = mock_http_component_app(hass)
ZWaveNodeConfigView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
network.nodes = {1: 1, 5: 5}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/config/2')
assert resp.status == 404
@ -319,16 +266,11 @@ def test_get_config_nonode(hass, test_client):
@asyncio.coroutine
def test_get_usercodes_nonode(hass, test_client):
def test_get_usercodes_nonode(hass, client):
"""Test getting usercodes on nonexisting node."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
network.nodes = {1: 1, 5: 5}
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/2')
assert resp.status == 404
@ -338,11 +280,8 @@ def test_get_usercodes_nonode(hass, test_client):
@asyncio.coroutine
def test_get_usercodes(hass, test_client):
def test_get_usercodes(hass, client):
"""Test getting usercodes on node."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=18,
command_classes=[const.COMMAND_CLASS_USER_CODE])
@ -356,8 +295,6 @@ def test_get_usercodes(hass, test_client):
network.nodes = {18: node}
node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/18')
assert resp.status == 200
@ -369,19 +306,14 @@ def test_get_usercodes(hass, test_client):
@asyncio.coroutine
def test_get_usercode_nousercode_node(hass, test_client):
def test_get_usercode_nousercode_node(hass, client):
"""Test getting usercodes on node without usercodes."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=18)
network.nodes = {18: node}
node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/18')
assert resp.status == 200
@ -391,11 +323,8 @@ def test_get_usercode_nousercode_node(hass, test_client):
@asyncio.coroutine
def test_get_usercodes_no_genreuser(hass, test_client):
def test_get_usercodes_no_genreuser(hass, client):
"""Test getting usercodes on node missing genre user."""
app = mock_http_component_app(hass)
ZWaveUserCodeView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
node = MockNode(node_id=18,
command_classes=[const.COMMAND_CLASS_USER_CODE])
@ -409,8 +338,6 @@ def test_get_usercodes_no_genreuser(hass, test_client):
network.nodes = {18: node}
node.get_values.return_value = node.values
client = yield from test_client(app)
resp = yield from client.get('/api/zwave/usercodes/18')
assert resp.status == 200
@ -420,13 +347,8 @@ def test_get_usercodes_no_genreuser(hass, test_client):
@asyncio.coroutine
def test_save_config_no_network(hass, test_client):
def test_save_config_no_network(hass, client):
"""Test saving configuration without network data."""
app = mock_http_component_app(hass)
ZWaveConfigWriteView().register(app.router)
client = yield from test_client(app)
resp = yield from client.post('/api/zwave/saveconfig')
assert resp.status == 404
@ -435,15 +357,10 @@ def test_save_config_no_network(hass, test_client):
@asyncio.coroutine
def test_save_config(hass, test_client):
def test_save_config(hass, client):
"""Test saving configuration."""
app = mock_http_component_app(hass)
ZWaveConfigWriteView().register(app.router)
network = hass.data[DATA_NETWORK] = MagicMock()
client = yield from test_client(app)
resp = yield from client.post('/api/zwave/saveconfig')
assert resp.status == 200

View file

@ -5,11 +5,10 @@ import logging
from unittest.mock import patch, MagicMock
import aioautomatic
from homeassistant.setup import async_setup_component
from homeassistant.components.device_tracker.automatic import (
async_setup_scanner)
from tests.common import mock_http_component
_LOGGER = logging.getLogger(__name__)
@ -23,8 +22,7 @@ def test_invalid_credentials(
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
mock_create_session, hass):
"""Test with invalid credentials."""
mock_http_component(hass)
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
mock_json_load.return_value = {'refresh_token': 'bad_token'}
@asyncio.coroutine
@ -59,8 +57,7 @@ def test_valid_credentials(
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
mock_ws_connect, mock_create_session, hass):
"""Test with valid credentials."""
mock_http_component(hass)
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
mock_json_load.return_value = {'refresh_token': 'good_token'}
session = MagicMock()

View file

@ -1 +1,38 @@
"""Tests for the HTTP component."""
import asyncio
from ipaddress import ip_address
from aiohttp import web
from homeassistant.components.http.const import KEY_REAL_IP
def mock_real_ip(app):
"""Inject middleware to mock real IP.
Returns a function to set the real IP.
"""
ip_to_mock = None
def set_ip_to_mock(value):
nonlocal ip_to_mock
ip_to_mock = value
@asyncio.coroutine
@web.middleware
def mock_real_ip(request, handler):
"""Mock Real IP middleware."""
nonlocal ip_to_mock
request[KEY_REAL_IP] = ip_address(ip_to_mock)
return (yield from handler(request))
@asyncio.coroutine
def real_ip_startup(app):
"""Startup of real ip."""
app.middlewares.insert(0, mock_real_ip)
app.on_startup.append(real_ip_startup)
return set_ip_to_mock

View file

@ -1,195 +1,156 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import asyncio
from ipaddress import ip_address, ip_network
from ipaddress import ip_network
from unittest.mock import patch
import aiohttp
from aiohttp import BasicAuth, web
from aiohttp.web_exceptions import HTTPUnauthorized
import pytest
from homeassistant import const
from homeassistant.const import HTTP_HEADER_HA_AUTH
from homeassistant.setup import async_setup_component
import homeassistant.components.http as http
from homeassistant.components.http.const import (
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
from homeassistant.components.http.auth import setup_auth
from homeassistant.components.http.real_ip import setup_real_ip
from homeassistant.components.http.const import KEY_AUTHENTICATED
from . import mock_real_ip
API_PASSWORD = 'test1234'
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
'FD01:DB8::1']
TRUSTED_NETWORKS = [
ip_network('192.0.2.0/24'),
ip_network('2001:DB8:ABCD::/48'),
ip_network('100.64.0.1'),
ip_network('FD01:DB8::1'),
]
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
@pytest.fixture
def mock_api_client(hass, test_client):
"""Start the Hass HTTP component."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
'http': {
http.CONF_API_PASSWORD: API_PASSWORD,
}
}))
return hass.loop.run_until_complete(test_client(hass.http.app))
@asyncio.coroutine
def mock_handler(request):
"""Return if request was authenticated."""
if not request[KEY_AUTHENTICATED]:
raise HTTPUnauthorized
return web.Response(status=200)
@pytest.fixture
def mock_trusted_networks(hass, mock_api_client):
"""Mock trusted networks."""
hass.http.app[KEY_TRUSTED_NETWORKS] = [
ip_network(trusted_network)
for trusted_network in TRUSTED_NETWORKS]
def app():
"""Fixture to setup a web.Application."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, False)
return app
@asyncio.coroutine
def test_access_denied_without_password(mock_api_client):
def test_auth_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_auth') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def test_access_without_password(app, test_client):
"""Test access without password."""
resp = yield from mock_api_client.get(const.URL_API)
setup_auth(app, [], None)
client = yield from test_client(app)
resp = yield from client.get('/')
assert resp.status == 200
@asyncio.coroutine
def test_access_with_password_in_header(app, test_client):
"""Test access with password in URL."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
req = yield from client.get(
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status == 200
req = yield from client.get(
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
assert req.status == 401
@asyncio.coroutine
def test_access_with_password_in_query(app, test_client):
"""Test access without password."""
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
resp = yield from client.get('/', params={
'api_password': API_PASSWORD
})
assert resp.status == 200
resp = yield from client.get('/')
assert resp.status == 401
@asyncio.coroutine
def test_access_denied_with_wrong_password_in_header(mock_api_client):
"""Test access with wrong password."""
resp = yield from mock_api_client.get(const.URL_API, headers={
const.HTTP_HEADER_HA_AUTH: 'wrongpassword'
resp = yield from client.get('/', params={
'api_password': 'wrong-password'
})
assert resp.status == 401
@asyncio.coroutine
def test_access_denied_with_x_forwarded_for(hass, mock_api_client,
mock_trusted_networks):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.use_x_forwarded_for = True
for remote_addr in UNTRUSTED_ADDRESSES:
resp = yield from mock_api_client.get(const.URL_API, headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
@asyncio.coroutine
def test_access_denied_with_untrusted_ip(mock_api_client,
mock_trusted_networks):
"""Test access with an untrusted ip address."""
for remote_addr in UNTRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'util.get_real_ip',
return_value=ip_address(remote_addr)):
resp = yield from mock_api_client.get(
const.URL_API, params={'api_password': ''})
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
@asyncio.coroutine
def test_access_with_password_in_header(mock_api_client, caplog):
"""Test access with password in URL."""
# Hide logging from requests package that we use to test logging
req = yield from mock_api_client.get(
const.URL_API, headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
@asyncio.coroutine
def test_access_denied_with_wrong_password_in_url(mock_api_client):
"""Test access with wrong password."""
resp = yield from mock_api_client.get(
const.URL_API, params={'api_password': 'wrongpassword'})
assert resp.status == 401
@asyncio.coroutine
def test_access_with_password_in_url(mock_api_client, caplog):
"""Test access with password in URL."""
req = yield from mock_api_client.get(
const.URL_API, params={'api_password': API_PASSWORD})
assert req.status == 200
logs = caplog.text
assert const.URL_API in logs
assert API_PASSWORD not in logs
@asyncio.coroutine
def test_access_granted_with_x_forwarded_for(hass, mock_api_client, caplog,
mock_trusted_networks):
"""Test access denied through the X-Forwarded-For http header."""
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
for remote_addr in TRUSTED_ADDRESSES:
resp = yield from mock_api_client.get(const.URL_API, headers={
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
assert resp.status == 200, \
"{} should be trusted".format(remote_addr)
@asyncio.coroutine
def test_access_granted_with_trusted_ip(mock_api_client, caplog,
mock_trusted_networks):
"""Test access with trusted addresses."""
for remote_addr in TRUSTED_ADDRESSES:
with patch('homeassistant.components.http.'
'auth.get_real_ip',
return_value=ip_address(remote_addr)):
resp = yield from mock_api_client.get(
const.URL_API, params={'api_password': ''})
assert resp.status == 200, \
'{} should be trusted'.format(remote_addr)
@asyncio.coroutine
def test_basic_auth_works(mock_api_client, caplog):
def test_basic_auth_works(app, test_client):
"""Test access with basic authentication."""
req = yield from mock_api_client.get(
const.URL_API,
auth=aiohttp.BasicAuth('homeassistant', API_PASSWORD))
setup_auth(app, [], API_PASSWORD)
client = yield from test_client(app)
req = yield from client.get(
'/',
auth=BasicAuth('homeassistant', API_PASSWORD))
assert req.status == 200
assert const.URL_API in caplog.text
@asyncio.coroutine
def test_basic_auth_username_homeassistant(mock_api_client, caplog):
"""Test access with basic auth requires username homeassistant."""
req = yield from mock_api_client.get(
const.URL_API,
auth=aiohttp.BasicAuth('wrong_username', API_PASSWORD))
req = yield from client.get(
'/',
auth=BasicAuth('wrong_username', API_PASSWORD))
assert req.status == 401
@asyncio.coroutine
def test_basic_auth_wrong_password(mock_api_client, caplog):
"""Test access with basic auth not allowed with wrong password."""
req = yield from mock_api_client.get(
const.URL_API,
auth=aiohttp.BasicAuth('homeassistant', 'wrong password'))
req = yield from client.get(
'/',
auth=BasicAuth('homeassistant', 'wrong password'))
assert req.status == 401
@asyncio.coroutine
def test_authorization_header_must_be_basic_type(mock_api_client, caplog):
"""Test only basic authorization is allowed for auth header."""
req = yield from mock_api_client.get(
const.URL_API,
req = yield from client.get(
'/',
headers={
'authorization': 'NotBasic abcdefg'
})
assert req.status == 401
@asyncio.coroutine
def test_access_with_trusted_ip(test_client):
"""Test access with an untrusted ip address."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
set_mock_ip = mock_real_ip(app)
client = yield from test_client(app)
for remote_addr in UNTRUSTED_ADDRESSES:
set_mock_ip(remote_addr)
resp = yield from client.get('/')
assert resp.status == 401, \
"{} shouldn't be trusted".format(remote_addr)
for remote_addr in TRUSTED_ADDRESSES:
set_mock_ip(remote_addr)
resp = yield from client.get('/')
assert resp.status == 200, \
"{} should be trusted".format(remote_addr)

View file

@ -1,91 +1,96 @@
"""The tests for the Home Assistant HTTP component."""
# pylint: disable=protected-access
import asyncio
from ipaddress import ip_address
from unittest.mock import patch, mock_open
import pytest
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized
from homeassistant import const
from homeassistant.setup import async_setup_component
import homeassistant.components.http as http
from homeassistant.components.http.const import (
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
from homeassistant.components.http.ban import (
IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS)
from . import mock_real_ip
API_PASSWORD = 'test1234'
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
@pytest.fixture
def mock_api_client(hass, test_client):
"""Start the Hass HTTP component."""
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
'http': {
http.CONF_API_PASSWORD: API_PASSWORD,
}
}))
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
in BANNED_IPS]
return hass.loop.run_until_complete(test_client(hass.http.app))
@asyncio.coroutine
def test_access_from_banned_ip(hass, mock_api_client):
def test_access_from_banned_ip(hass, test_client):
"""Test accessing to server from banned IP. Both trusted and not."""
hass.http.app[KEY_BANS_ENABLED] = True
app = web.Application()
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]):
client = yield from test_client(app)
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address(remote_addr)):
resp = yield from mock_api_client.get(
const.URL_API)
assert resp.status == 403
set_real_ip(remote_addr)
resp = yield from client.get('/')
assert resp.status == 403
@asyncio.coroutine
def test_access_from_banned_ip_when_ban_is_off(hass, mock_api_client):
def test_ban_middleware_not_loaded_by_config(hass):
"""Test accessing to server from banned IP when feature is off."""
hass.http.app[KEY_BANS_ENABLED] = False
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address(remote_addr)):
resp = yield from mock_api_client.get(
const.URL_API,
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert resp.status == 200
with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {
http.CONF_IP_BAN_ENABLED: False,
}
})
assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine
def test_ip_bans_file_creation(hass, mock_api_client):
def test_ban_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_bans') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def test_ip_bans_file_creation(hass, test_client):
"""Testing if banned IP file created."""
hass.http.app[KEY_BANS_ENABLED] = True
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
app = web.Application()
app['hass'] = hass
@asyncio.coroutine
def unauth_handler(request):
"""Return a mock web response."""
raise HTTPUnauthorized
app.router.add_get('/', unauth_handler)
setup_bans(hass, app, 1)
mock_real_ip(app)("200.201.202.204")
with patch('homeassistant.components.http.ban.load_ip_bans_config',
return_value=[IpBan(banned_ip) for banned_ip
in BANNED_IPS]):
client = yield from test_client(app)
m = mock_open()
@asyncio.coroutine
def call_server():
with patch('homeassistant.components.http.'
'ban.get_real_ip',
return_value=ip_address("200.201.202.204")):
resp = yield from mock_api_client.get(
const.URL_API,
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
return resp
with patch('homeassistant.components.http.ban.open', m, create=True):
resp = yield from call_server()
resp = yield from client.get('/')
assert resp.status == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
assert m.call_count == 0
resp = yield from call_server()
resp = yield from client.get('/')
assert resp.status == 401
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
resp = yield from call_server()
resp = yield from client.get('/')
assert resp.status == 403
assert m.call_count == 1

View file

@ -0,0 +1,104 @@
"""Test cors for the HTTP component."""
import asyncio
from unittest.mock import patch
from aiohttp import web
from aiohttp.hdrs import (
ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_REQUEST_HEADERS,
ACCESS_CONTROL_REQUEST_METHOD,
ORIGIN
)
import pytest
from homeassistant.const import HTTP_HEADER_HA_AUTH
from homeassistant.setup import async_setup_component
from homeassistant.components.http.cors import setup_cors
TRUSTED_ORIGIN = 'https://home-assistant.io'
@asyncio.coroutine
def test_cors_middleware_not_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {}
})
assert len(mock_setup.mock_calls) == 0
@asyncio.coroutine
def test_cors_middleware_loaded_from_config(hass):
"""Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup:
yield from async_setup_component(hass, 'http', {
'http': {
'cors_allowed_origins': ['http://home-assistant.io']
}
})
assert len(mock_setup.mock_calls) == 1
@asyncio.coroutine
def mock_handler(request):
"""Return if request was authenticated."""
return web.Response(status=200)
@pytest.fixture
def client(loop, test_client):
"""Fixture to setup a web.Application."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_cors(app, [TRUSTED_ORIGIN])
return loop.run_until_complete(test_client(app))
@asyncio.coroutine
def test_cors_requests(client):
"""Test cross origin requests."""
req = yield from client.get('/', headers={
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
TRUSTED_ORIGIN
# With password in URL
req = yield from client.get('/', params={
'api_password': 'some-pass'
}, headers={
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
TRUSTED_ORIGIN
# With password in headers
req = yield from client.get('/', headers={
HTTP_HEADER_HA_AUTH: 'some-pass',
ORIGIN: TRUSTED_ORIGIN
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
TRUSTED_ORIGIN
@asyncio.coroutine
def test_cors_preflight_allowed(client):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
req = yield from client.options('/', headers={
ORIGIN: TRUSTED_ORIGIN,
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
})
assert req.status == 200
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN
assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \
HTTP_HEADER_HA_AUTH.upper()

View file

@ -1,124 +1,10 @@
"""The tests for the Home Assistant HTTP component."""
import asyncio
from aiohttp.hdrs import (
ORIGIN, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_HEADERS,
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS,
CONTENT_TYPE)
import requests
from tests.common import get_test_instance_port, get_test_home_assistant
from homeassistant.setup import async_setup_component
from homeassistant import const, setup
import homeassistant.components.http as http
API_PASSWORD = 'test1234'
SERVER_PORT = get_test_instance_port()
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
CONTENT_TYPE: const.CONTENT_TYPE_JSON,
}
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None
def _url(path=''):
"""Helper method to generate URLs."""
return HTTP_BASE_URL + path
# pylint: disable=invalid-name
def setUpModule():
"""Initialize a Home Assistant server."""
global hass
hass = get_test_home_assistant()
setup.setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
}
}
)
setup.setup_component(hass, 'api')
# Registering static path as it caused CORS to blow up
hass.http.register_static_path(
'/custom_components', hass.config.path('custom_components'))
hass.start()
# pylint: disable=invalid-name
def tearDownModule():
"""Stop the Home Assistant server."""
hass.stop()
class TestCors:
"""Test HTTP component."""
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={ORIGIN: HTTP_BASE_URL})
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API), headers=headers)
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
ORIGIN: HTTP_BASE_URL,
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
}
req = requests.options(_url(const.URL_API), headers=headers)
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper()
class TestView(http.HomeAssistantView):
"""Test the HTTP views."""
@ -133,12 +19,12 @@ class TestView(http.HomeAssistantView):
@asyncio.coroutine
def test_registering_view_while_running(hass, test_client):
def test_registering_view_while_running(hass, test_client, unused_port):
"""Test that we can register a view while the server is running."""
yield from setup.async_setup_component(
yield from async_setup_component(
hass, http.DOMAIN, {
http.DOMAIN: {
http.CONF_SERVER_PORT: get_test_instance_port(),
http.CONF_SERVER_PORT: unused_port(),
}
}
)
@ -151,7 +37,7 @@ def test_registering_view_while_running(hass, test_client):
@asyncio.coroutine
def test_api_base_url_with_domain(hass):
"""Test setting API URL."""
result = yield from setup.async_setup_component(hass, 'http', {
result = yield from async_setup_component(hass, 'http', {
'http': {
'base_url': 'example.com'
}
@ -163,7 +49,7 @@ def test_api_base_url_with_domain(hass):
@asyncio.coroutine
def test_api_base_url_with_ip(hass):
"""Test setting api url."""
result = yield from setup.async_setup_component(hass, 'http', {
result = yield from async_setup_component(hass, 'http', {
'http': {
'server_host': '1.1.1.1'
}
@ -175,7 +61,7 @@ def test_api_base_url_with_ip(hass):
@asyncio.coroutine
def test_api_base_url_with_ip_port(hass):
"""Test setting api url."""
result = yield from setup.async_setup_component(hass, 'http', {
result = yield from async_setup_component(hass, 'http', {
'http': {
'base_url': '1.1.1.1:8124'
}
@ -187,9 +73,34 @@ def test_api_base_url_with_ip_port(hass):
@asyncio.coroutine
def test_api_no_base_url(hass):
"""Test setting api url."""
result = yield from setup.async_setup_component(hass, 'http', {
result = yield from async_setup_component(hass, 'http', {
'http': {
}
})
assert result
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
@asyncio.coroutine
def test_not_log_password(hass, unused_port, test_client, caplog):
"""Test access with password doesn't get logged."""
result = yield from async_setup_component(hass, 'api', {
'http': {
http.CONF_SERVER_PORT: unused_port(),
http.CONF_API_PASSWORD: 'some-pass'
}
})
assert result
client = yield from test_client(hass.http.app)
resp = yield from client.get('/api/', params={
'api_password': 'some-pass'
})
assert resp.status == 200
logs = caplog.text
# Ensure we don't log API passwords
assert '/api/' in logs
assert 'some-pass' not in logs

View file

@ -0,0 +1,48 @@
"""Test real IP middleware."""
import asyncio
from aiohttp import web
from aiohttp.hdrs import X_FORWARDED_FOR
from homeassistant.components.http.real_ip import setup_real_ip
from homeassistant.components.http.const import KEY_REAL_IP
@asyncio.coroutine
def mock_handler(request):
"""Handler that returns the real IP as text."""
return web.Response(text=str(request[KEY_REAL_IP]))
@asyncio.coroutine
def test_ignore_x_forwarded_for(test_client):
"""Test that we get the IP from the transport."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, False)
mock_api_client = yield from test_client(app)
resp = yield from mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255'
})
assert resp.status == 200
text = yield from resp.text()
assert text != '255.255.255.255'
@asyncio.coroutine
def test_use_x_forwarded_for(test_client):
"""Test that we get the IP from the transport."""
app = web.Application()
app.router.add_get('/', mock_handler)
setup_real_ip(app, True)
mock_api_client = yield from test_client(app)
resp = yield from mock_api_client.get('/', headers={
X_FORWARDED_FOR: '255.255.255.255'
})
assert resp.status == 200
text = yield from resp.text()
assert text == '255.255.255.255'

View file

@ -4,8 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
from homeassistant.setup import setup_component
import homeassistant.components.mqtt as mqtt
from tests.common import (
get_test_home_assistant, mock_coro, mock_http_component)
from tests.common import get_test_home_assistant, mock_coro
class TestMQTT:
@ -14,7 +13,9 @@ class TestMQTT:
def setup_method(self, method):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_http_component(self.hass, 'super_secret')
setup_component(self.hass, 'http', {
'api_password': 'super_secret'
})
def teardown_method(self, method):
"""Stop everything that was started."""

View file

@ -4,12 +4,10 @@ import json
from unittest.mock import patch, MagicMock, mock_open
from aiohttp.hdrs import AUTHORIZATION
from homeassistant.setup import async_setup_component
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import save_json
from homeassistant.components.notify import html5
from tests.common import mock_http_component_app
CONFIG_FILE = 'file.conf'
SUBSCRIPTION_1 = {
@ -52,6 +50,23 @@ REGISTER_URL = '/api/notify.html5'
PUBLISH_URL = '/api/notify.html5/callback'
@asyncio.coroutine
def mock_client(hass, test_client, registrations=None):
"""Create a test client for HTML5 views."""
if registrations is None:
registrations = {}
with patch('homeassistant.components.notify.html5._load_config',
return_value=registrations):
yield from async_setup_component(hass, 'notify', {
'notify': {
'platform': 'html5'
}
})
return (yield from test_client(hass.http.app))
class TestHtml5Notify(object):
"""Tests for HTML5 notify platform."""
@ -89,8 +104,6 @@ class TestHtml5Notify(object):
service.send_message('Hello', target=['device', 'non_existing'],
data={'icon': 'beer.png'})
print(mock_wp.mock_calls)
assert len(mock_wp.mock_calls) == 3
# WebPusher constructor
@ -104,421 +117,224 @@ class TestHtml5Notify(object):
assert payload['body'] == 'Hello'
assert payload['icon'] == 'beer.png'
@asyncio.coroutine
def test_registering_new_device_view(self, loop, test_client):
"""Test that the HTML view works."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_1,
}
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
@asyncio.coroutine
def test_registering_new_device_view(hass, test_client):
"""Test that the HTML view works."""
client = yield from mock_client(hass, test_client)
assert service is not None
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
assert resp.status == 200
assert len(mock_save.mock_calls) == 1
assert mock_save.mock_calls[0][1][1] == {
'unnamed device': SUBSCRIPTION_1,
}
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_new_device_expiration_view(self, loop, test_client):
"""Test that the HTML view works."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_4,
}
@asyncio.coroutine
def test_registering_new_device_expiration_view(hass, test_client):
"""Test that the HTML view works."""
client = yield from mock_client(hass, test_client)
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == {}
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
assert resp.status == 200
assert mock_save.mock_calls[0][1][1] == {
'unnamed device': SUBSCRIPTION_4,
}
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_new_device_fails_view(self, loop, test_client):
"""Test subs. are not altered when registering a new device fails."""
hass = MagicMock()
expected = {}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
hass.async_add_job.side_effect = HomeAssistantError()
@asyncio.coroutine
def test_registering_new_device_fails_view(hass, test_client):
"""Test subs. are not altered when registering a new device fails."""
registrations = {}
client = yield from mock_client(hass, test_client, registrations)
with patch('homeassistant.components.notify.html5.save_json',
side_effect=HomeAssistantError()):
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 500, content
assert view.registrations == expected
assert resp.status == 500
assert registrations == {}
@asyncio.coroutine
def test_registering_existing_device_view(self, loop, test_client):
"""Test subscription is updated when registering existing device."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_4,
}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@asyncio.coroutine
def test_registering_existing_device_view(hass, test_client):
"""Test subscription is updated when registering existing device."""
registrations = {}
client = yield from mock_client(hass, test_client, registrations)
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 200, content
assert view.registrations == expected
assert resp.status == 200
assert mock_save.mock_calls[0][1][1] == {
'unnamed device': SUBSCRIPTION_4,
}
assert registrations == {
'unnamed device': SUBSCRIPTION_4,
}
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_existing_device_fails_view(self, loop, test_client):
"""Test sub. is not updated when registering existing device fails."""
hass = MagicMock()
expected = {
'unnamed device': SUBSCRIPTION_1,
}
hass.config.path.return_value = CONFIG_FILE
html5.get_service(hass, {})
view = hass.mock_calls[1][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
@asyncio.coroutine
def test_registering_existing_device_fails_view(hass, test_client):
"""Test sub. is not updated when registering existing device fails."""
registrations = {}
client = yield from mock_client(hass, test_client, registrations)
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1))
hass.async_add_job.side_effect = HomeAssistantError()
mock_save.side_effect = HomeAssistantError
resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_4))
content = yield from resp.text()
assert resp.status == 500, content
assert view.registrations == expected
assert resp.status == 500
assert registrations == {
'unnamed device': SUBSCRIPTION_1,
}
@asyncio.coroutine
def test_registering_new_device_validation(self, loop, test_client):
"""Test various errors when registering a new device."""
hass = MagicMock()
m = mock_open()
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
@asyncio.coroutine
def test_registering_new_device_validation(hass, test_client):
"""Test various errors when registering a new device."""
client = yield from mock_client(hass, test_client)
assert service is not None
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'invalid browser',
'subscription': 'sub info',
}))
assert resp.status == 400
# assert hass.called
assert len(hass.mock_calls) == 3
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'chrome',
}))
assert resp.status == 400
view = hass.mock_calls[1][1][0]
with patch('homeassistant.components.notify.html5.save_json',
return_value=False):
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'chrome',
'subscription': 'sub info',
}))
assert resp.status == 400
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'invalid browser',
'subscription': 'sub info',
}))
assert resp.status == 400
@asyncio.coroutine
def test_unregistering_device_view(hass, test_client):
"""Test that the HTML unregister view works."""
registrations = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
client = yield from mock_client(hass, test_client, registrations)
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'chrome',
}))
assert resp.status == 400
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
with patch('homeassistant.components.notify.html5.save_json',
return_value=False):
# resp = view.post(Request(builder.get_environ()))
resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'chrome',
'subscription': 'sub info',
}))
assert resp.status == 200
assert len(mock_save.mock_calls) == 1
assert registrations == {
'other device': SUBSCRIPTION_2
}
assert resp.status == 400
@asyncio.coroutine
def test_unregistering_device_view(self, loop, test_client):
"""Test that the HTML unregister view works."""
hass = MagicMock()
@asyncio.coroutine
def test_unregister_device_view_handle_unknown_subscription(hass, test_client):
"""Test that the HTML unregister view handles unknown subscriptions."""
registrations = {}
client = yield from mock_client(hass, test_client, registrations)
config = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_3['subscription']
}))
m = mock_open(read_data=json.dumps(config))
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert resp.status == 200, resp.response
assert registrations == {}
assert len(mock_save.mock_calls) == 0
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
@asyncio.coroutine
def test_unregistering_device_view_handles_save_error(hass, test_client):
"""Test that the HTML unregister view handles save errors."""
registrations = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
client = yield from mock_client(hass, test_client, registrations)
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
with patch('homeassistant.components.notify.html5.save_json',
side_effect=HomeAssistantError()):
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
assert resp.status == 500, resp.response
assert registrations == {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
config.pop('some device')
@asyncio.coroutine
def test_callback_view_no_jwt(hass, test_client):
"""Test that the notification callback view works without JWT."""
client = yield from mock_client(hass, test_client)
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push',
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
}))
assert resp.status == 200, resp.response
assert view.registrations == config
assert resp.status == 401, resp.response
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
config)
@asyncio.coroutine
def test_unregister_device_view_handle_unknown_subscription(
self, loop, test_client):
"""Test that the HTML unregister view handles unknown subscriptions."""
hass = MagicMock()
@asyncio.coroutine
def test_callback_view_with_jwt(hass, test_client):
"""Test that the notification callback view works with JWT."""
registrations = {
'device': SUBSCRIPTION_1
}
client = yield from mock_client(hass, test_client, registrations)
config = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
with patch('pywebpush.WebPusher') as mock_wp:
yield from hass.services.async_call('notify', 'notify', {
'message': 'Hello',
'target': ['device'],
'data': {'icon': 'beer.png'}
}, blocking=True)
m = mock_open(read_data=json.dumps(config))
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert len(mock_wp.mock_calls) == 3
assert service is not None
# WebPusher constructor
assert mock_wp.mock_calls[0][1][0] == \
SUBSCRIPTION_1['subscription']
# Third mock_call checks the status_code of the response.
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
# assert hass.called
assert len(hass.mock_calls) == 3
# Call to send
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
assert push_payload['body'] == 'Hello'
assert push_payload['icon'] == 'beer.png'
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_3['subscription']
}))
resp = yield from client.post(PUBLISH_URL, json={
'type': 'push',
}, headers={AUTHORIZATION: bearer_token})
assert resp.status == 200, resp.response
assert view.registrations == config
hass.async_add_job.assert_not_called()
@asyncio.coroutine
def test_unregistering_device_view_handles_save_error(
self, loop, test_client):
"""Test that the HTML unregister view handles save errors."""
hass = MagicMock()
config = {
'some device': SUBSCRIPTION_1,
'other device': SUBSCRIPTION_2,
}
m = mock_open(read_data=json.dumps(config))
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[1][1][0]
assert view.json_path == hass.config.path.return_value
assert view.registrations == config
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
hass.async_add_job.side_effect = HomeAssistantError()
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'],
}))
assert resp.status == 500, resp.response
assert view.registrations == config
@asyncio.coroutine
def test_callback_view_no_jwt(self, loop, test_client):
"""Test that the notification callback view works without JWT."""
hass = MagicMock()
m = mock_open()
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
view = hass.mock_calls[2][1][0]
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push',
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
}))
assert resp.status == 401, resp.response
@asyncio.coroutine
def test_callback_view_with_jwt(self, loop, test_client):
"""Test that the notification callback view works with JWT."""
hass = MagicMock()
data = {
'device': SUBSCRIPTION_1
}
m = mock_open(read_data=json.dumps(data))
with patch(
'homeassistant.util.json.open',
m, create=True
):
hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {'gcm_sender_id': '100'})
assert service is not None
# assert hass.called
assert len(hass.mock_calls) == 3
with patch('pywebpush.WebPusher') as mock_wp:
service.send_message(
'Hello', target=['device'], data={'icon': 'beer.png'})
assert len(mock_wp.mock_calls) == 3
# WebPusher constructor
assert mock_wp.mock_calls[0][1][0] == \
SUBSCRIPTION_1['subscription']
# Third mock_call checks the status_code of the response.
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
# Call to send
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
assert push_payload['body'] == 'Hello'
assert push_payload['icon'] == 'beer.png'
view = hass.mock_calls[2][1][0]
view.registrations = data
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
hass.loop = loop
app = mock_http_component_app(hass)
view.register(app.router)
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push',
}), headers={AUTHORIZATION: bearer_token})
assert resp.status == 200
body = yield from resp.json()
assert body == {"event": "push", "status": "ok"}
assert resp.status == 200
body = yield from resp.json()
assert body == {"event": "push", "status": "ok"}

View file

@ -10,8 +10,7 @@ import homeassistant.util.dt as dt_util
from homeassistant.components import history, recorder
from tests.common import (
init_recorder_component, mock_http_component, mock_state_change_event,
get_test_home_assistant)
init_recorder_component, mock_state_change_event, get_test_home_assistant)
class TestComponentHistory(unittest.TestCase):
@ -38,7 +37,6 @@ class TestComponentHistory(unittest.TestCase):
def test_setup(self):
"""Test setup method of history."""
mock_http_component(self.hass)
config = history.CONFIG_SCHEMA({
# ha.DOMAIN: {},
history.DOMAIN: {

View file

@ -14,7 +14,7 @@ from homeassistant.components import logbook
from homeassistant.setup import setup_component
from tests.common import (
mock_http_component, init_recorder_component, get_test_home_assistant)
init_recorder_component, get_test_home_assistant)
_LOGGER = logging.getLogger(__name__)
@ -29,10 +29,7 @@ class TestComponentLogbook(unittest.TestCase):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
init_recorder_component(self.hass) # Force an in memory DB
mock_http_component(self.hass)
self.hass.config.components |= set(['frontend', 'recorder', 'api'])
assert setup_component(self.hass, logbook.DOMAIN,
self.EMPTY_CONFIG)
assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG)
self.hass.start()
def tearDown(self):

View file

@ -150,7 +150,6 @@ def test_api_update_fails(hass, test_client):
assert resp.status == 404
beer_id = hass.data['shopping_list'].items[0]['id']
client = yield from test_client(hass.http.app)
resp = yield from client.post(
'/api/shopping_list/item/{}'.format(beer_id), json={
'name': 123,

View file

@ -8,8 +8,9 @@ import pytest
from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi, frontend
from homeassistant.setup import async_setup_component
from tests.common import mock_http_component_app, mock_coro
from tests.common import mock_coro
API_PASSWORD = 'test1234'
@ -17,10 +18,10 @@ API_PASSWORD = 'test1234'
@pytest.fixture
def websocket_client(loop, hass, test_client):
"""Websocket client fixture connected to websocket server."""
websocket_app = mock_http_component_app(hass)
wapi.WebsocketAPIView().register(websocket_app.router)
assert loop.run_until_complete(
async_setup_component(hass, 'websocket_api'))
client = loop.run_until_complete(test_client(websocket_app))
client = loop.run_until_complete(test_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json())
@ -35,10 +36,14 @@ def websocket_client(loop, hass, test_client):
@pytest.fixture
def no_auth_websocket_client(hass, loop, test_client):
"""Websocket connection that requires authentication."""
websocket_app = mock_http_component_app(hass, API_PASSWORD)
wapi.WebsocketAPIView().register(websocket_app.router)
assert loop.run_until_complete(
async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
}))
client = loop.run_until_complete(test_client(websocket_app))
client = loop.run_until_complete(test_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json())