mirror of
https://github.com/home-assistant/core
synced 2024-10-05 20:07:58 +00:00
Context (#15674)
* Add context * Add context to switch/light services * Test set_state API * Lint * Fix tests * Do not include context yet in comparison * Do not pass in loop * Fix Z-Wave tests * Add websocket test without user
This commit is contained in:
parent
867f80715e
commit
c7f4bdafc0
|
@ -220,7 +220,8 @@ class APIEntityStateView(HomeAssistantView):
|
||||||
is_new_state = hass.states.get(entity_id) is None
|
is_new_state = hass.states.get(entity_id) is None
|
||||||
|
|
||||||
# Write state
|
# Write state
|
||||||
hass.states.async_set(entity_id, new_state, attributes, force_update)
|
hass.states.async_set(entity_id, new_state, attributes, force_update,
|
||||||
|
self.context(request))
|
||||||
|
|
||||||
# Read the state back for our response
|
# Read the state back for our response
|
||||||
status_code = HTTP_CREATED if is_new_state else 200
|
status_code = HTTP_CREATED if is_new_state else 200
|
||||||
|
@ -279,7 +280,8 @@ class APIEventView(HomeAssistantView):
|
||||||
event_data[key] = state
|
event_data[key] = state
|
||||||
|
|
||||||
request.app['hass'].bus.async_fire(
|
request.app['hass'].bus.async_fire(
|
||||||
event_type, event_data, ha.EventOrigin.remote)
|
event_type, event_data, ha.EventOrigin.remote,
|
||||||
|
self.context(request))
|
||||||
|
|
||||||
return self.json_message("Event {} fired.".format(event_type))
|
return self.json_message("Event {} fired.".format(event_type))
|
||||||
|
|
||||||
|
@ -316,7 +318,8 @@ class APIDomainServicesView(HomeAssistantView):
|
||||||
"Data should be valid JSON.", HTTP_BAD_REQUEST)
|
"Data should be valid JSON.", HTTP_BAD_REQUEST)
|
||||||
|
|
||||||
with AsyncTrackStates(hass) as changed_states:
|
with AsyncTrackStates(hass) as changed_states:
|
||||||
await hass.services.async_call(domain, service, data, True)
|
await hass.services.async_call(
|
||||||
|
domain, service, data, True, self.context(request))
|
||||||
|
|
||||||
return self.json(changed_states)
|
return self.json(changed_states)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
|
||||||
|
|
||||||
import homeassistant.remote as rem
|
import homeassistant.remote as rem
|
||||||
from homeassistant.components.http.ban import process_success_login
|
from homeassistant.components.http.ban import process_success_login
|
||||||
from homeassistant.core import is_callback
|
from homeassistant.core import Context, is_callback
|
||||||
from homeassistant.const import CONTENT_TYPE_JSON
|
from homeassistant.const import CONTENT_TYPE_JSON
|
||||||
|
|
||||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||||
|
@ -32,6 +32,14 @@ class HomeAssistantView:
|
||||||
cors_allowed = False
|
cors_allowed = False
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
|
def context(self, request):
|
||||||
|
"""Generate a context from a request."""
|
||||||
|
user = request.get('hass_user')
|
||||||
|
if user is None:
|
||||||
|
return Context()
|
||||||
|
|
||||||
|
return Context(user_id=user.id)
|
||||||
|
|
||||||
def json(self, result, status_code=200, headers=None):
|
def json(self, result, status_code=200, headers=None):
|
||||||
"""Return a JSON response."""
|
"""Return a JSON response."""
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -359,7 +359,9 @@ async def async_setup(hass, config):
|
||||||
|
|
||||||
if not light.should_poll:
|
if not light.should_poll:
|
||||||
continue
|
continue
|
||||||
update_tasks.append(light.async_update_ha_state(True))
|
|
||||||
|
update_tasks.append(
|
||||||
|
light.async_update_ha_state(True, service.context))
|
||||||
|
|
||||||
if update_tasks:
|
if update_tasks:
|
||||||
await asyncio.wait(update_tasks, loop=hass.loop)
|
await asyncio.wait(update_tasks, loop=hass.loop)
|
||||||
|
|
|
@ -114,7 +114,8 @@ async def async_setup(hass, config):
|
||||||
|
|
||||||
if not switch.should_poll:
|
if not switch.should_poll:
|
||||||
continue
|
continue
|
||||||
update_tasks.append(switch.async_update_ha_state(True))
|
update_tasks.append(
|
||||||
|
switch.async_update_ha_state(True, service.context))
|
||||||
|
|
||||||
if update_tasks:
|
if update_tasks:
|
||||||
await asyncio.wait(update_tasks, loop=hass.loop)
|
await asyncio.wait(update_tasks, loop=hass.loop)
|
||||||
|
|
|
@ -18,7 +18,7 @@ from voluptuous.humanize import humanize_error
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
||||||
__version__)
|
__version__)
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.remote import JSONEncoder
|
from homeassistant.remote import JSONEncoder
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
@ -262,6 +262,18 @@ class ActiveConnection:
|
||||||
self._handle_task = None
|
self._handle_task = None
|
||||||
self._writer_task = None
|
self._writer_task = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user(self):
|
||||||
|
"""Return the user associated with the connection."""
|
||||||
|
return self.request.get('hass_user')
|
||||||
|
|
||||||
|
def context(self, msg):
|
||||||
|
"""Return a context."""
|
||||||
|
user = self.user
|
||||||
|
if user is None:
|
||||||
|
return Context()
|
||||||
|
return Context(user_id=user.id)
|
||||||
|
|
||||||
def debug(self, message1, message2=''):
|
def debug(self, message1, message2=''):
|
||||||
"""Print a debug message."""
|
"""Print a debug message."""
|
||||||
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
|
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
|
||||||
|
@ -287,7 +299,7 @@ class ActiveConnection:
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def send_message_outside(self, message):
|
def send_message_outside(self, message):
|
||||||
"""Send a message to the client outside of the main task.
|
"""Send a message to the client.
|
||||||
|
|
||||||
Closes connection if the client is not reading the messages.
|
Closes connection if the client is not reading the messages.
|
||||||
|
|
||||||
|
@ -508,7 +520,8 @@ def handle_call_service(hass, connection, msg):
|
||||||
async def call_service_helper(msg):
|
async def call_service_helper(msg):
|
||||||
"""Call a service and fire complete message."""
|
"""Call a service and fire complete message."""
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
msg['domain'], msg['service'], msg.get('service_data'), True)
|
msg['domain'], msg['service'], msg.get('service_data'), True,
|
||||||
|
connection.context(msg))
|
||||||
connection.send_message_outside(result_message(msg['id']))
|
connection.send_message_outside(result_message(msg['id']))
|
||||||
|
|
||||||
hass.async_add_job(call_service_helper(msg))
|
hass.async_add_job(call_service_helper(msg))
|
||||||
|
|
|
@ -224,9 +224,6 @@ ATTR_ID = 'id'
|
||||||
# Name
|
# Name
|
||||||
ATTR_NAME = 'name'
|
ATTR_NAME = 'name'
|
||||||
|
|
||||||
# Data for a SERVICE_EXECUTED event
|
|
||||||
ATTR_SERVICE_CALL_ID = 'service_call_id'
|
|
||||||
|
|
||||||
# Contains one string or a list of strings, each being an entity id
|
# Contains one string or a list of strings, each being an entity id
|
||||||
ATTR_ENTITY_ID = 'entity_id'
|
ATTR_ENTITY_ID = 'entity_id'
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ import re
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from time import monotonic
|
from time import monotonic
|
||||||
|
import uuid
|
||||||
|
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
|
@ -23,12 +24,13 @@ from typing import ( # NOQA
|
||||||
TYPE_CHECKING, Awaitable, Iterator)
|
TYPE_CHECKING, Awaitable, Iterator)
|
||||||
|
|
||||||
from async_timeout import timeout
|
from async_timeout import timeout
|
||||||
|
import attr
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
|
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
|
||||||
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
||||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||||
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
||||||
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
|
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
|
||||||
|
@ -191,7 +193,7 @@ class HomeAssistant:
|
||||||
try:
|
try:
|
||||||
# Only block for EVENT_HOMEASSISTANT_START listener
|
# Only block for EVENT_HOMEASSISTANT_START listener
|
||||||
self.async_stop_track_tasks()
|
self.async_stop_track_tasks()
|
||||||
with timeout(TIMEOUT_EVENT_START, loop=self.loop):
|
with timeout(TIMEOUT_EVENT_START):
|
||||||
await self.async_block_till_done()
|
await self.async_block_till_done()
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
|
@ -201,7 +203,7 @@ class HomeAssistant:
|
||||||
', '.join(self.config.components))
|
', '.join(self.config.components))
|
||||||
|
|
||||||
# Allow automations to set up the start triggers before changing state
|
# Allow automations to set up the start triggers before changing state
|
||||||
await asyncio.sleep(0, loop=self.loop)
|
await asyncio.sleep(0)
|
||||||
self.state = CoreState.running
|
self.state = CoreState.running
|
||||||
_async_create_timer(self)
|
_async_create_timer(self)
|
||||||
|
|
||||||
|
@ -307,16 +309,16 @@ class HomeAssistant:
|
||||||
async def async_block_till_done(self) -> None:
|
async def async_block_till_done(self) -> None:
|
||||||
"""Block till all pending work is done."""
|
"""Block till all pending work is done."""
|
||||||
# To flush out any call_soon_threadsafe
|
# To flush out any call_soon_threadsafe
|
||||||
await asyncio.sleep(0, loop=self.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
while self._pending_tasks:
|
while self._pending_tasks:
|
||||||
pending = [task for task in self._pending_tasks
|
pending = [task for task in self._pending_tasks
|
||||||
if not task.done()]
|
if not task.done()]
|
||||||
self._pending_tasks.clear()
|
self._pending_tasks.clear()
|
||||||
if pending:
|
if pending:
|
||||||
await asyncio.wait(pending, loop=self.loop)
|
await asyncio.wait(pending)
|
||||||
else:
|
else:
|
||||||
await asyncio.sleep(0, loop=self.loop)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop Home Assistant and shuts down all threads."""
|
"""Stop Home Assistant and shuts down all threads."""
|
||||||
|
@ -343,6 +345,27 @@ class HomeAssistant:
|
||||||
self.loop.stop()
|
self.loop.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class Context:
|
||||||
|
"""The context that triggered something."""
|
||||||
|
|
||||||
|
user_id = attr.ib(
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
id = attr.ib(
|
||||||
|
type=str,
|
||||||
|
default=attr.Factory(lambda: uuid.uuid4().hex),
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
"""Return a dictionary representation of the context."""
|
||||||
|
return {
|
||||||
|
'id': self.id,
|
||||||
|
'user_id': self.user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class EventOrigin(enum.Enum):
|
class EventOrigin(enum.Enum):
|
||||||
"""Represent the origin of an event."""
|
"""Represent the origin of an event."""
|
||||||
|
|
||||||
|
@ -357,16 +380,18 @@ class EventOrigin(enum.Enum):
|
||||||
class Event:
|
class Event:
|
||||||
"""Representation of an event within the bus."""
|
"""Representation of an event within the bus."""
|
||||||
|
|
||||||
__slots__ = ['event_type', 'data', 'origin', 'time_fired']
|
__slots__ = ['event_type', 'data', 'origin', 'time_fired', 'context']
|
||||||
|
|
||||||
def __init__(self, event_type: str, data: Optional[Dict] = None,
|
def __init__(self, event_type: str, data: Optional[Dict] = None,
|
||||||
origin: EventOrigin = EventOrigin.local,
|
origin: EventOrigin = EventOrigin.local,
|
||||||
time_fired: Optional[int] = None) -> None:
|
time_fired: Optional[int] = None,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Initialize a new event."""
|
"""Initialize a new event."""
|
||||||
self.event_type = event_type
|
self.event_type = event_type
|
||||||
self.data = data or {}
|
self.data = data or {}
|
||||||
self.origin = origin
|
self.origin = origin
|
||||||
self.time_fired = time_fired or dt_util.utcnow()
|
self.time_fired = time_fired or dt_util.utcnow()
|
||||||
|
self.context = context or Context()
|
||||||
|
|
||||||
def as_dict(self) -> Dict:
|
def as_dict(self) -> Dict:
|
||||||
"""Create a dict representation of this Event.
|
"""Create a dict representation of this Event.
|
||||||
|
@ -378,6 +403,7 @@ class Event:
|
||||||
'data': dict(self.data),
|
'data': dict(self.data),
|
||||||
'origin': str(self.origin),
|
'origin': str(self.origin),
|
||||||
'time_fired': self.time_fired,
|
'time_fired': self.time_fired,
|
||||||
|
'context': self.context.as_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
@ -425,14 +451,16 @@ class EventBus:
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
def fire(self, event_type: str, event_data: Optional[Dict] = None,
|
def fire(self, event_type: str, event_data: Optional[Dict] = None,
|
||||||
origin: EventOrigin = EventOrigin.local) -> None:
|
origin: EventOrigin = EventOrigin.local,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Fire an event."""
|
"""Fire an event."""
|
||||||
self._hass.loop.call_soon_threadsafe(
|
self._hass.loop.call_soon_threadsafe(
|
||||||
self.async_fire, event_type, event_data, origin)
|
self.async_fire, event_type, event_data, origin, context)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
|
def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
|
||||||
origin: EventOrigin = EventOrigin.local) -> None:
|
origin: EventOrigin = EventOrigin.local,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Fire an event.
|
"""Fire an event.
|
||||||
|
|
||||||
This method must be run in the event loop.
|
This method must be run in the event loop.
|
||||||
|
@ -445,7 +473,7 @@ class EventBus:
|
||||||
event_type != EVENT_HOMEASSISTANT_CLOSE):
|
event_type != EVENT_HOMEASSISTANT_CLOSE):
|
||||||
listeners = match_all_listeners + listeners
|
listeners = match_all_listeners + listeners
|
||||||
|
|
||||||
event = Event(event_type, event_data, origin)
|
event = Event(event_type, event_data, origin, None, context)
|
||||||
|
|
||||||
if event_type != EVENT_TIME_CHANGED:
|
if event_type != EVENT_TIME_CHANGED:
|
||||||
_LOGGER.info("Bus:Handling %s", event)
|
_LOGGER.info("Bus:Handling %s", event)
|
||||||
|
@ -569,15 +597,17 @@ class State:
|
||||||
attributes: extra information on entity and state
|
attributes: extra information on entity and state
|
||||||
last_changed: last time the state was changed, not the attributes.
|
last_changed: last time the state was changed, not the attributes.
|
||||||
last_updated: last time this object was updated.
|
last_updated: last time this object was updated.
|
||||||
|
context: Context in which it was created
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = ['entity_id', 'state', 'attributes',
|
__slots__ = ['entity_id', 'state', 'attributes',
|
||||||
'last_changed', 'last_updated']
|
'last_changed', 'last_updated', 'context']
|
||||||
|
|
||||||
def __init__(self, entity_id: str, state: Any,
|
def __init__(self, entity_id: str, state: Any,
|
||||||
attributes: Optional[Dict] = None,
|
attributes: Optional[Dict] = None,
|
||||||
last_changed: Optional[datetime.datetime] = None,
|
last_changed: Optional[datetime.datetime] = None,
|
||||||
last_updated: Optional[datetime.datetime] = None) -> None:
|
last_updated: Optional[datetime.datetime] = None,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Initialize a new state."""
|
"""Initialize a new state."""
|
||||||
state = str(state)
|
state = str(state)
|
||||||
|
|
||||||
|
@ -596,6 +626,7 @@ class State:
|
||||||
self.attributes = MappingProxyType(attributes or {})
|
self.attributes = MappingProxyType(attributes or {})
|
||||||
self.last_updated = last_updated or dt_util.utcnow()
|
self.last_updated = last_updated or dt_util.utcnow()
|
||||||
self.last_changed = last_changed or self.last_updated
|
self.last_changed = last_changed or self.last_updated
|
||||||
|
self.context = context or Context()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def domain(self) -> str:
|
def domain(self) -> str:
|
||||||
|
@ -626,7 +657,8 @@ class State:
|
||||||
'state': self.state,
|
'state': self.state,
|
||||||
'attributes': dict(self.attributes),
|
'attributes': dict(self.attributes),
|
||||||
'last_changed': self.last_changed,
|
'last_changed': self.last_changed,
|
||||||
'last_updated': self.last_updated}
|
'last_updated': self.last_updated,
|
||||||
|
'context': self.context.as_dict()}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, json_dict: Dict) -> Any:
|
def from_dict(cls, json_dict: Dict) -> Any:
|
||||||
|
@ -650,8 +682,13 @@ class State:
|
||||||
if isinstance(last_updated, str):
|
if isinstance(last_updated, str):
|
||||||
last_updated = dt_util.parse_datetime(last_updated)
|
last_updated = dt_util.parse_datetime(last_updated)
|
||||||
|
|
||||||
|
context = json_dict.get('context')
|
||||||
|
if context:
|
||||||
|
context = Context(**context)
|
||||||
|
|
||||||
return cls(json_dict['entity_id'], json_dict['state'],
|
return cls(json_dict['entity_id'], json_dict['state'],
|
||||||
json_dict.get('attributes'), last_changed, last_updated)
|
json_dict.get('attributes'), last_changed, last_updated,
|
||||||
|
context)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Return the comparison of the state."""
|
"""Return the comparison of the state."""
|
||||||
|
@ -662,11 +699,11 @@ class State:
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return the representation of the states."""
|
"""Return the representation of the states."""
|
||||||
attr = "; {}".format(util.repr_helper(self.attributes)) \
|
attrs = "; {}".format(util.repr_helper(self.attributes)) \
|
||||||
if self.attributes else ""
|
if self.attributes else ""
|
||||||
|
|
||||||
return "<state {}={}{} @ {}>".format(
|
return "<state {}={}{} @ {}>".format(
|
||||||
self.entity_id, self.state, attr,
|
self.entity_id, self.state, attrs,
|
||||||
dt_util.as_local(self.last_changed).isoformat())
|
dt_util.as_local(self.last_changed).isoformat())
|
||||||
|
|
||||||
|
|
||||||
|
@ -761,7 +798,8 @@ class StateMachine:
|
||||||
|
|
||||||
def set(self, entity_id: str, new_state: Any,
|
def set(self, entity_id: str, new_state: Any,
|
||||||
attributes: Optional[Dict] = None,
|
attributes: Optional[Dict] = None,
|
||||||
force_update: bool = False) -> None:
|
force_update: bool = False,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Set the state of an entity, add entity if it does not exist.
|
"""Set the state of an entity, add entity if it does not exist.
|
||||||
|
|
||||||
Attributes is an optional dict to specify attributes of this state.
|
Attributes is an optional dict to specify attributes of this state.
|
||||||
|
@ -772,12 +810,14 @@ class StateMachine:
|
||||||
run_callback_threadsafe(
|
run_callback_threadsafe(
|
||||||
self._loop,
|
self._loop,
|
||||||
self.async_set, entity_id, new_state, attributes, force_update,
|
self.async_set, entity_id, new_state, attributes, force_update,
|
||||||
|
context,
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_set(self, entity_id: str, new_state: Any,
|
def async_set(self, entity_id: str, new_state: Any,
|
||||||
attributes: Optional[Dict] = None,
|
attributes: Optional[Dict] = None,
|
||||||
force_update: bool = False) -> None:
|
force_update: bool = False,
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Set the state of an entity, add entity if it does not exist.
|
"""Set the state of an entity, add entity if it does not exist.
|
||||||
|
|
||||||
Attributes is an optional dict to specify attributes of this state.
|
Attributes is an optional dict to specify attributes of this state.
|
||||||
|
@ -804,13 +844,17 @@ class StateMachine:
|
||||||
if same_state and same_attr:
|
if same_state and same_attr:
|
||||||
return
|
return
|
||||||
|
|
||||||
state = State(entity_id, new_state, attributes, last_changed)
|
if context is None:
|
||||||
|
context = Context()
|
||||||
|
|
||||||
|
state = State(entity_id, new_state, attributes, last_changed, None,
|
||||||
|
context)
|
||||||
self._states[entity_id] = state
|
self._states[entity_id] = state
|
||||||
self._bus.async_fire(EVENT_STATE_CHANGED, {
|
self._bus.async_fire(EVENT_STATE_CHANGED, {
|
||||||
'entity_id': entity_id,
|
'entity_id': entity_id,
|
||||||
'old_state': old_state,
|
'old_state': old_state,
|
||||||
'new_state': state,
|
'new_state': state,
|
||||||
})
|
}, EventOrigin.local, context)
|
||||||
|
|
||||||
|
|
||||||
class Service:
|
class Service:
|
||||||
|
@ -818,7 +862,8 @@ class Service:
|
||||||
|
|
||||||
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
|
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
|
||||||
|
|
||||||
def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None:
|
def __init__(self, func: Callable, schema: Optional[vol.Schema],
|
||||||
|
context: Optional[Context] = None) -> None:
|
||||||
"""Initialize a service."""
|
"""Initialize a service."""
|
||||||
self.func = func
|
self.func = func
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
|
@ -829,23 +874,25 @@ class Service:
|
||||||
class ServiceCall:
|
class ServiceCall:
|
||||||
"""Representation of a call to a service."""
|
"""Representation of a call to a service."""
|
||||||
|
|
||||||
__slots__ = ['domain', 'service', 'data', 'call_id']
|
__slots__ = ['domain', 'service', 'data', 'context']
|
||||||
|
|
||||||
def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
|
def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
|
||||||
call_id: Optional[str] = None) -> None:
|
context: Optional[Context] = None) -> None:
|
||||||
"""Initialize a service call."""
|
"""Initialize a service call."""
|
||||||
self.domain = domain.lower()
|
self.domain = domain.lower()
|
||||||
self.service = service.lower()
|
self.service = service.lower()
|
||||||
self.data = MappingProxyType(data or {})
|
self.data = MappingProxyType(data or {})
|
||||||
self.call_id = call_id
|
self.context = context or Context()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return the representation of the service."""
|
"""Return the representation of the service."""
|
||||||
if self.data:
|
if self.data:
|
||||||
return "<ServiceCall {}.{}: {}>".format(
|
return "<ServiceCall {}.{} (c:{}): {}>".format(
|
||||||
self.domain, self.service, util.repr_helper(self.data))
|
self.domain, self.service, self.context.id,
|
||||||
|
util.repr_helper(self.data))
|
||||||
|
|
||||||
return "<ServiceCall {}.{}>".format(self.domain, self.service)
|
return "<ServiceCall {}.{} (c:{})>".format(
|
||||||
|
self.domain, self.service, self.context.id)
|
||||||
|
|
||||||
|
|
||||||
class ServiceRegistry:
|
class ServiceRegistry:
|
||||||
|
@ -857,15 +904,6 @@ class ServiceRegistry:
|
||||||
self._hass = hass
|
self._hass = hass
|
||||||
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
|
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
|
||||||
|
|
||||||
def _gen_unique_id() -> Iterator[str]:
|
|
||||||
cur_id = 1
|
|
||||||
while True:
|
|
||||||
yield '{}-{}'.format(id(self), cur_id)
|
|
||||||
cur_id += 1
|
|
||||||
|
|
||||||
gen = _gen_unique_id()
|
|
||||||
self._generate_unique_id = lambda: next(gen)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def services(self) -> Dict[str, Dict[str, Service]]:
|
def services(self) -> Dict[str, Dict[str, Service]]:
|
||||||
"""Return dictionary with per domain a list of available services."""
|
"""Return dictionary with per domain a list of available services."""
|
||||||
|
@ -957,7 +995,8 @@ class ServiceRegistry:
|
||||||
|
|
||||||
def call(self, domain: str, service: str,
|
def call(self, domain: str, service: str,
|
||||||
service_data: Optional[Dict] = None,
|
service_data: Optional[Dict] = None,
|
||||||
blocking: bool = False) -> Optional[bool]:
|
blocking: bool = False,
|
||||||
|
context: Optional[Context] = None) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
|
|
||||||
|
@ -975,13 +1014,14 @@ class ServiceRegistry:
|
||||||
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
||||||
"""
|
"""
|
||||||
return run_coroutine_threadsafe( # type: ignore
|
return run_coroutine_threadsafe( # type: ignore
|
||||||
self.async_call(domain, service, service_data, blocking),
|
self.async_call(domain, service, service_data, blocking, context),
|
||||||
self._hass.loop
|
self._hass.loop
|
||||||
).result()
|
).result()
|
||||||
|
|
||||||
async def async_call(self, domain: str, service: str,
|
async def async_call(self, domain: str, service: str,
|
||||||
service_data: Optional[Dict] = None,
|
service_data: Optional[Dict] = None,
|
||||||
blocking: bool = False) -> Optional[bool]:
|
blocking: bool = False,
|
||||||
|
context: Optional[Context] = None) -> Optional[bool]:
|
||||||
"""
|
"""
|
||||||
Call a service.
|
Call a service.
|
||||||
|
|
||||||
|
@ -1000,44 +1040,42 @@ class ServiceRegistry:
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
call_id = self._generate_unique_id()
|
context = context or Context()
|
||||||
|
|
||||||
event_data = {
|
event_data = {
|
||||||
ATTR_DOMAIN: domain.lower(),
|
ATTR_DOMAIN: domain.lower(),
|
||||||
ATTR_SERVICE: service.lower(),
|
ATTR_SERVICE: service.lower(),
|
||||||
ATTR_SERVICE_DATA: service_data,
|
ATTR_SERVICE_DATA: service_data,
|
||||||
ATTR_SERVICE_CALL_ID: call_id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if blocking:
|
if not blocking:
|
||||||
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future
|
self._hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE, event_data, EventOrigin.local, context)
|
||||||
|
return None
|
||||||
|
|
||||||
@callback
|
fut = asyncio.Future() # type: asyncio.Future
|
||||||
def service_executed(event: Event) -> None:
|
|
||||||
"""Handle an executed service."""
|
|
||||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
|
||||||
fut.set_result(True)
|
|
||||||
|
|
||||||
unsub = self._hass.bus.async_listen(
|
@callback
|
||||||
EVENT_SERVICE_EXECUTED, service_executed)
|
def service_executed(event: Event) -> None:
|
||||||
|
"""Handle an executed service."""
|
||||||
|
if event.context == context:
|
||||||
|
fut.set_result(True)
|
||||||
|
|
||||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
unsub = self._hass.bus.async_listen(
|
||||||
|
EVENT_SERVICE_EXECUTED, service_executed)
|
||||||
|
|
||||||
done, _ = await asyncio.wait(
|
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data,
|
||||||
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
|
EventOrigin.local, context)
|
||||||
success = bool(done)
|
|
||||||
unsub()
|
|
||||||
return success
|
|
||||||
|
|
||||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
|
||||||
return None
|
success = bool(done)
|
||||||
|
unsub()
|
||||||
|
return success
|
||||||
|
|
||||||
async def _event_to_service_call(self, event: Event) -> None:
|
async def _event_to_service_call(self, event: Event) -> None:
|
||||||
"""Handle the SERVICE_CALLED events from the EventBus."""
|
"""Handle the SERVICE_CALLED events from the EventBus."""
|
||||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||||
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
||||||
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
||||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
|
||||||
|
|
||||||
if not self.has_service(domain, service):
|
if not self.has_service(domain, service):
|
||||||
if event.origin == EventOrigin.local:
|
if event.origin == EventOrigin.local:
|
||||||
|
@ -1049,16 +1087,13 @@ class ServiceRegistry:
|
||||||
|
|
||||||
def fire_service_executed() -> None:
|
def fire_service_executed() -> None:
|
||||||
"""Fire service executed event."""
|
"""Fire service executed event."""
|
||||||
if not call_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
|
||||||
|
|
||||||
if (service_handler.is_coroutinefunction or
|
if (service_handler.is_coroutinefunction or
|
||||||
service_handler.is_callback):
|
service_handler.is_callback):
|
||||||
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data)
|
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {},
|
||||||
|
EventOrigin.local, event.context)
|
||||||
else:
|
else:
|
||||||
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data)
|
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {},
|
||||||
|
EventOrigin.local, event.context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if service_handler.schema:
|
if service_handler.schema:
|
||||||
|
@ -1069,7 +1104,8 @@ class ServiceRegistry:
|
||||||
fire_service_executed()
|
fire_service_executed()
|
||||||
return
|
return
|
||||||
|
|
||||||
service_call = ServiceCall(domain, service, service_data, call_id)
|
service_call = ServiceCall(
|
||||||
|
domain, service, service_data, event.context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if service_handler.is_callback:
|
if service_handler.is_callback:
|
||||||
|
|
|
@ -179,7 +179,7 @@ class Entity:
|
||||||
# produce undesirable effects in the entity's operation.
|
# produce undesirable effects in the entity's operation.
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_update_ha_state(self, force_refresh=False):
|
def async_update_ha_state(self, force_refresh=False, context=None):
|
||||||
"""Update Home Assistant with current state of entity.
|
"""Update Home Assistant with current state of entity.
|
||||||
|
|
||||||
If force_refresh == True will update entity before setting state.
|
If force_refresh == True will update entity before setting state.
|
||||||
|
@ -279,7 +279,7 @@ class Entity:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.hass.states.async_set(
|
self.hass.states.async_set(
|
||||||
self.entity_id, state, attr, self.force_update)
|
self.entity_id, state, attr, self.force_update, context)
|
||||||
|
|
||||||
def schedule_update_ha_state(self, force_refresh=False):
|
def schedule_update_ha_state(self, force_refresh=False):
|
||||||
"""Schedule an update ha state change task.
|
"""Schedule an update ha state change task.
|
||||||
|
|
|
@ -187,7 +187,7 @@ def async_mock_service(hass, domain, service, schema=None):
|
||||||
"""Set up a fake service & return a calls log list to this service."""
|
"""Set up a fake service & return a calls log list to this service."""
|
||||||
calls = []
|
calls = []
|
||||||
|
|
||||||
@asyncio.coroutine
|
@ha.callback
|
||||||
def mock_service_log(call): # pylint: disable=unnecessary-lambda
|
def mock_service_log(call): # pylint: disable=unnecessary-lambda
|
||||||
"""Mock service call."""
|
"""Mock service call."""
|
||||||
calls.append(call)
|
calls.append(call)
|
||||||
|
|
|
@ -5,12 +5,12 @@ import unittest.mock as mock
|
||||||
import os
|
import os
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant import core, loader
|
||||||
import homeassistant.loader as loader
|
from homeassistant.setup import setup_component, async_setup_component
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ENTITY_ID, STATE_ON, STATE_OFF, CONF_PLATFORM,
|
ATTR_ENTITY_ID, STATE_ON, STATE_OFF, CONF_PLATFORM,
|
||||||
SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_SUPPORTED_FEATURES)
|
SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_SUPPORTED_FEATURES)
|
||||||
import homeassistant.components.light as light
|
from homeassistant.components import light
|
||||||
from homeassistant.helpers.intent import IntentHandleError
|
from homeassistant.helpers.intent import IntentHandleError
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
|
@ -475,3 +475,24 @@ async def test_intent_set_color_and_brightness(hass):
|
||||||
assert call.data.get(ATTR_ENTITY_ID) == 'light.hello_2'
|
assert call.data.get(ATTR_ENTITY_ID) == 'light.hello_2'
|
||||||
assert call.data.get(light.ATTR_RGB_COLOR) == (0, 0, 255)
|
assert call.data.get(light.ATTR_RGB_COLOR) == (0, 0, 255)
|
||||||
assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20
|
assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20
|
||||||
|
|
||||||
|
|
||||||
|
async def test_light_context(hass):
|
||||||
|
"""Test that light context works."""
|
||||||
|
assert await async_setup_component(hass, 'light', {
|
||||||
|
'light': {
|
||||||
|
'platform': 'test'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
state = hass.states.get('light.ceiling')
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
await hass.services.async_call('light', 'toggle', {
|
||||||
|
'entity_id': state.entity_id,
|
||||||
|
}, True, core.Context(user_id='abcd'))
|
||||||
|
|
||||||
|
state2 = hass.states.get('light.ceiling')
|
||||||
|
assert state2 is not None
|
||||||
|
assert state.state != state2.state
|
||||||
|
assert state2.context.user_id == 'abcd'
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from homeassistant.setup import setup_component
|
from homeassistant.setup import setup_component, async_setup_component
|
||||||
from homeassistant import loader
|
from homeassistant import core, loader
|
||||||
from homeassistant.components import switch
|
from homeassistant.components import switch
|
||||||
from homeassistant.const import STATE_ON, STATE_OFF, CONF_PLATFORM
|
from homeassistant.const import STATE_ON, STATE_OFF, CONF_PLATFORM
|
||||||
|
|
||||||
|
@ -91,3 +91,24 @@ class TestSwitch(unittest.TestCase):
|
||||||
'{} 2'.format(switch.DOMAIN): {CONF_PLATFORM: 'test2'},
|
'{} 2'.format(switch.DOMAIN): {CONF_PLATFORM: 'test2'},
|
||||||
}
|
}
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_switch_context(hass):
|
||||||
|
"""Test that switch context works."""
|
||||||
|
assert await async_setup_component(hass, 'switch', {
|
||||||
|
'switch': {
|
||||||
|
'platform': 'test'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
state = hass.states.get('switch.ac')
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
await hass.services.async_call('switch', 'toggle', {
|
||||||
|
'entity_id': state.entity_id,
|
||||||
|
}, True, core.Context(user_id='abcd'))
|
||||||
|
|
||||||
|
state2 = hass.states.get('switch.ac')
|
||||||
|
assert state2 is not None
|
||||||
|
assert state.state != state2.state
|
||||||
|
assert state2.context.user_id == 'abcd'
|
||||||
|
|
|
@ -12,6 +12,8 @@ from homeassistant.bootstrap import DATA_LOGGING
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.common import async_mock_service
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_api_client(hass, aiohttp_client):
|
def mock_api_client(hass, aiohttp_client):
|
||||||
|
@ -429,3 +431,58 @@ async def test_api_error_log(hass, aiohttp_client):
|
||||||
assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING]
|
assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING]
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert await resp.text() == 'Hello'
|
assert await resp.text() == 'Hello'
|
||||||
|
|
||||||
|
|
||||||
|
async def test_api_fire_event_context(hass, mock_api_client,
|
||||||
|
hass_access_token):
|
||||||
|
"""Test if the API sets right context if we fire an event."""
|
||||||
|
test_value = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
|
def listener(event):
|
||||||
|
"""Helper method that will verify our event got called."""
|
||||||
|
test_value.append(event)
|
||||||
|
|
||||||
|
hass.bus.async_listen("test.event", listener)
|
||||||
|
|
||||||
|
await mock_api_client.post(
|
||||||
|
const.URL_API_EVENTS_EVENT.format("test.event"),
|
||||||
|
headers={
|
||||||
|
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
||||||
|
})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(test_value) == 1
|
||||||
|
assert test_value[0].context.user_id == \
|
||||||
|
hass_access_token.refresh_token.user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_api_call_service_context(hass, mock_api_client,
|
||||||
|
hass_access_token):
|
||||||
|
"""Test if the API sets right context if we call a service."""
|
||||||
|
calls = async_mock_service(hass, 'test_domain', 'test_service')
|
||||||
|
|
||||||
|
await mock_api_client.post(
|
||||||
|
'/api/services/test_domain/test_service',
|
||||||
|
headers={
|
||||||
|
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
||||||
|
})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
assert calls[0].context.user_id == hass_access_token.refresh_token.user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
|
||||||
|
"""Test if the API sets right context if we set state."""
|
||||||
|
await mock_api_client.post(
|
||||||
|
'/api/states/light.kitchen',
|
||||||
|
json={
|
||||||
|
'state': 'on'
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
'authorization': 'Bearer {}'.format(hass_access_token.token)
|
||||||
|
})
|
||||||
|
|
||||||
|
state = hass.states.get('light.kitchen')
|
||||||
|
assert state.context.user_id == hass_access_token.refresh_token.user.id
|
||||||
|
|
|
@ -104,12 +104,14 @@ class TestMqttEventStream:
|
||||||
"state": "on",
|
"state": "on",
|
||||||
"entity_id": e_id,
|
"entity_id": e_id,
|
||||||
"attributes": {},
|
"attributes": {},
|
||||||
"last_changed": now.isoformat()
|
"last_changed": now.isoformat(),
|
||||||
}
|
}
|
||||||
event['event_data'] = {"new_state": new_state, "entity_id": e_id}
|
event['event_data'] = {"new_state": new_state, "entity_id": e_id}
|
||||||
|
|
||||||
# Verify that the message received was that expected
|
# Verify that the message received was that expected
|
||||||
assert json.loads(msg) == event
|
result = json.loads(msg)
|
||||||
|
result['event_data']['new_state'].pop('context')
|
||||||
|
assert result == event
|
||||||
|
|
||||||
@patch('homeassistant.components.mqtt.async_publish')
|
@patch('homeassistant.components.mqtt.async_publish')
|
||||||
def test_time_event_does_not_send_message(self, mock_pub):
|
def test_time_event_does_not_send_message(self, mock_pub):
|
||||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.core import callback
|
||||||
from homeassistant.components import websocket_api as wapi
|
from homeassistant.components import websocket_api as wapi
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import mock_coro
|
from tests.common import mock_coro, async_mock_service
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
API_PASSWORD = 'test1234'
|
||||||
|
|
||||||
|
@ -443,3 +443,94 @@ async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||||
|
|
||||||
auth_msg = await ws.receive_json()
|
auth_msg = await ws.receive_json()
|
||||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_service_context_with_user(hass, aiohttp_client,
|
||||||
|
hass_access_token):
|
||||||
|
"""Test that the user is set in the service call context."""
|
||||||
|
assert await async_setup_component(hass, 'websocket_api', {
|
||||||
|
'http': {
|
||||||
|
'api_password': API_PASSWORD
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||||
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
|
async with client.ws_connect(wapi.URL) as ws:
|
||||||
|
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||||
|
auth_active.return_value = True
|
||||||
|
auth_msg = await ws.receive_json()
|
||||||
|
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
|
await ws.send_json({
|
||||||
|
'type': wapi.TYPE_AUTH,
|
||||||
|
'access_token': hass_access_token.token
|
||||||
|
})
|
||||||
|
|
||||||
|
auth_msg = await ws.receive_json()
|
||||||
|
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||||
|
|
||||||
|
await ws.send_json({
|
||||||
|
'id': 5,
|
||||||
|
'type': wapi.TYPE_CALL_SERVICE,
|
||||||
|
'domain': 'domain_test',
|
||||||
|
'service': 'test_service',
|
||||||
|
'service_data': {
|
||||||
|
'hello': 'world'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
msg = await ws.receive_json()
|
||||||
|
assert msg['success']
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert call.domain == 'domain_test'
|
||||||
|
assert call.service == 'test_service'
|
||||||
|
assert call.data == {'hello': 'world'}
|
||||||
|
assert call.context.user_id == hass_access_token.refresh_token.user.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_call_service_context_no_user(hass, aiohttp_client):
|
||||||
|
"""Test that connection without user sets context."""
|
||||||
|
assert await async_setup_component(hass, 'websocket_api', {
|
||||||
|
'http': {
|
||||||
|
'api_password': API_PASSWORD
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||||
|
client = await aiohttp_client(hass.http.app)
|
||||||
|
|
||||||
|
async with client.ws_connect(wapi.URL) as ws:
|
||||||
|
auth_msg = await ws.receive_json()
|
||||||
|
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||||
|
|
||||||
|
await ws.send_json({
|
||||||
|
'type': wapi.TYPE_AUTH,
|
||||||
|
'api_password': API_PASSWORD
|
||||||
|
})
|
||||||
|
|
||||||
|
auth_msg = await ws.receive_json()
|
||||||
|
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||||
|
|
||||||
|
await ws.send_json({
|
||||||
|
'id': 5,
|
||||||
|
'type': wapi.TYPE_CALL_SERVICE,
|
||||||
|
'domain': 'domain_test',
|
||||||
|
'service': 'test_service',
|
||||||
|
'service_data': {
|
||||||
|
'hello': 'world'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
msg = await ws.receive_json()
|
||||||
|
assert msg['success']
|
||||||
|
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert call.domain == 'domain_test'
|
||||||
|
assert call.service == 'test_service'
|
||||||
|
assert call.data == {'hello': 'world'}
|
||||||
|
assert call.context.user_id is None
|
||||||
|
|
|
@ -163,10 +163,10 @@ def test_zwave_ready_wait(hass, mock_openzwave):
|
||||||
asyncio_sleep = asyncio.sleep
|
asyncio_sleep = asyncio.sleep
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def sleep(duration, loop):
|
def sleep(duration, loop=None):
|
||||||
if duration > 0:
|
if duration > 0:
|
||||||
sleeps.append(duration)
|
sleeps.append(duration)
|
||||||
yield from asyncio_sleep(0, loop=loop)
|
yield from asyncio_sleep(0)
|
||||||
|
|
||||||
with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow):
|
with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow):
|
||||||
with patch('asyncio.sleep', new=sleep):
|
with patch('asyncio.sleep', new=sleep):
|
||||||
|
@ -248,10 +248,10 @@ async def test_unparsed_node_discovery(hass, mock_openzwave):
|
||||||
|
|
||||||
asyncio_sleep = asyncio.sleep
|
asyncio_sleep = asyncio.sleep
|
||||||
|
|
||||||
async def sleep(duration, loop):
|
async def sleep(duration, loop=None):
|
||||||
if duration > 0:
|
if duration > 0:
|
||||||
sleeps.append(duration)
|
sleeps.append(duration)
|
||||||
await asyncio_sleep(0, loop=loop)
|
await asyncio_sleep(0)
|
||||||
|
|
||||||
with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow):
|
with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow):
|
||||||
with patch('asyncio.sleep', new=sleep):
|
with patch('asyncio.sleep', new=sleep):
|
||||||
|
|
|
@ -277,6 +277,10 @@ class TestEvent(unittest.TestCase):
|
||||||
'data': data,
|
'data': data,
|
||||||
'origin': 'LOCAL',
|
'origin': 'LOCAL',
|
||||||
'time_fired': now,
|
'time_fired': now,
|
||||||
|
'context': {
|
||||||
|
'id': event.context.id,
|
||||||
|
'user_id': event.context.user_id,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
self.assertEqual(expected, event.as_dict())
|
self.assertEqual(expected, event.as_dict())
|
||||||
|
|
||||||
|
@ -598,18 +602,16 @@ class TestStateMachine(unittest.TestCase):
|
||||||
self.assertEqual(1, len(events))
|
self.assertEqual(1, len(events))
|
||||||
|
|
||||||
|
|
||||||
class TestServiceCall(unittest.TestCase):
|
def test_service_call_repr():
|
||||||
"""Test ServiceCall class."""
|
"""Test ServiceCall repr."""
|
||||||
|
call = ha.ServiceCall('homeassistant', 'start')
|
||||||
|
assert str(call) == \
|
||||||
|
"<ServiceCall homeassistant.start (c:{})>".format(call.context.id)
|
||||||
|
|
||||||
def test_repr(self):
|
call2 = ha.ServiceCall('homeassistant', 'start', {'fast': 'yes'})
|
||||||
"""Test repr method."""
|
assert str(call2) == \
|
||||||
self.assertEqual(
|
"<ServiceCall homeassistant.start (c:{}): fast=yes>".format(
|
||||||
"<ServiceCall homeassistant.start>",
|
call2.context.id)
|
||||||
str(ha.ServiceCall('homeassistant', 'start')))
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
"<ServiceCall homeassistant.start: fast=yes>",
|
|
||||||
str(ha.ServiceCall('homeassistant', 'start', {"fast": "yes"})))
|
|
||||||
|
|
||||||
|
|
||||||
class TestServiceRegistry(unittest.TestCase):
|
class TestServiceRegistry(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in a new issue