Cleaned up the core.

This commit is contained in:
Paulus Schoutsen 2013-10-08 18:50:30 -07:00
parent f1042cd136
commit 71bd03ed8c
3 changed files with 37 additions and 59 deletions

View file

@ -13,6 +13,8 @@ from collections import defaultdict, namedtuple
from itertools import chain
from datetime import datetime
logging.basicConfig(level=logging.INFO)
ALL_EVENTS = '*'
EVENT_START = "start"
EVENT_SHUTDOWN = "shutdown"
@ -32,7 +34,7 @@ def start_home_assistant(eventbus):
""" Start home assistant. """
Timer(eventbus)
eventbus.fire(Event(EVENT_START))
eventbus.fire(EVENT_START)
while True:
try:
@ -40,7 +42,7 @@ def start_home_assistant(eventbus):
except KeyboardInterrupt:
print ""
eventbus.fire(Event(EVENT_SHUTDOWN))
eventbus.fire(EVENT_SHUTDOWN)
break
@ -62,8 +64,6 @@ def track_state_change(eventbus, category, from_state, to_state, action):
def listener(event):
""" State change listener that listens for specific state changes. """
assert isinstance(event, Event), "event needs to be of Event type"
if category == event.data['category'] and \
matcher(event.data['old_state'].state, from_state) and \
matcher(event.data['new_state'].state, to_state):
@ -86,8 +86,6 @@ def track_time_change(eventbus, action,
def listener(event):
""" Listens for matching time_changed events. """
assert isinstance(event, Event), "event needs to be of Event type"
if (point_in_time and event.data['now'] > point_in_time) or \
(not point_in_time and \
matcher(event.data['now'].year, year) and \
@ -99,80 +97,60 @@ def track_time_change(eventbus, action,
# point_in_time are exact points in time
# so we always remove it after fire
event.remove_listener = listen_once or point_in_time is not None
if listen_once or point_in_time:
event.eventbus.remove_listener(EVENT_TIME_CHANGED, listener)
action(event.data['now'])
eventbus.listen(EVENT_TIME_CHANGED, listener)
Event = namedtuple("Event", ["eventbus", "event_type", "data"])
class EventBus(object):
""" Class that allows code to listen for- and fire events. """
def __init__(self):
self.listeners = defaultdict(list)
self.lock = threading.RLock()
self.logger = logging.getLogger(__name__)
def fire(self, event):
def fire(self, event_type, event_data=None):
""" Fire an event. """
assert isinstance(event, Event), \
"event needs to be an instance of Event"
if not event_data:
event_data = {}
self.logger.info("EventBus:Event {}: {}".format(
event_type, event_data))
def run():
""" We dont want the eventbus to be blocking - run in a thread. """
self.lock.acquire()
""" Fire listeners for event. """
event = Event(self, event_type, event_data)
self.logger.info("EventBus:Event {}: {}".format(
event.event_type, event.data))
for callback in chain(self.listeners[ALL_EVENTS],
for listener in chain(self.listeners[ALL_EVENTS],
self.listeners[event.event_type]):
try:
callback(event)
listener(event)
except Exception: #pylint: disable=broad-except
self.logger.exception("EventBus:Exception in listener")
if event.remove_listener:
if callback in self.listeners[ALL_EVENTS]:
self.listeners[ALL_EVENTS].remove(callback)
if callback in self.listeners[event.event_type]:
self.listeners[event.event_type].remove(callback)
event.remove_listener = False
if event.stop_propegating:
break
self.lock.release()
# We dont want the eventbus to be blocking - run in a thread.
threading.Thread(target=run).start()
def listen(self, event_type, callback):
def listen(self, event_type, listener):
""" Listen for all events or events of a specific type.
To listen to all events specify the constant ``ALL_EVENTS``
as event_type.
"""
self.lock.acquire()
self.listeners[event_type].append(listener)
self.listeners[event_type].append(callback)
self.lock.release()
# pylint: disable=too-few-public-methods
class Event(object):
""" An event to be sent over the eventbus. """
def __init__(self, event_type, data=None):
self.event_type = event_type
self.data = {} if data is None else data
self.stop_propegating = False
self.remove_listener = False
def __str__(self):
return str([self.event_type, self.data])
def remove_listener(self, event_type, listener):
""" Removes a listener of a specific event_type. """
try:
self.listeners[event_type].remove(listener)
except ValueError:
pass
class StateMachine(object):
""" Helper class that tracks the state of different objects. """
@ -180,7 +158,7 @@ class StateMachine(object):
def __init__(self, eventbus):
self.states = dict()
self.eventbus = eventbus
self.lock = threading.RLock()
self.lock = threading.Lock()
def set_state(self, category, new_state):
""" Set the state of a category, add category is it does not exist. """
@ -198,10 +176,10 @@ class StateMachine(object):
if old_state.state != new_state:
self.states[category] = State(new_state, datetime.now())
self.eventbus.fire(Event(EVENT_STATE_CHANGED,
self.eventbus.fire(EVENT_STATE_CHANGED,
{'category':category,
'old_state':old_state,
'new_state':self.states[category]}))
'new_state':self.states[category]})
self.lock.release()
@ -265,7 +243,7 @@ class Timer(threading.Thread):
if self._stop.isSet():
break
self.eventbus.fire(Event(EVENT_TIME_CHANGED, {'now':now}))
self.eventbus.fire(EVENT_TIME_CHANGED, {'now':now})
class HomeAssistantException(Exception):
""" General Home Assistant exception occured. """

View file

@ -33,7 +33,7 @@ from urlparse import urlparse, parse_qs
import requests
from . import EVENT_START, EVENT_SHUTDOWN, Event
from . import EVENT_START, EVENT_SHUTDOWN
SERVER_PORT = 8123
@ -229,7 +229,7 @@ class RequestHandler(BaseHTTPRequestHandler):
else:
event_data = json.loads(post_data['event_data'][0])
self.server.eventbus.fire(Event(event_name, event_data))
self.server.eventbus.fire(event_name, event_data)
self._message(use_json, "Event {} fired.".
format(event_name))

View file

@ -11,7 +11,7 @@ import time
import requests
from . import EventBus, StateMachine, Event, EVENT_START, EVENT_SHUTDOWN
from . import EventBus, StateMachine, EVENT_START, EVENT_SHUTDOWN
from .httpinterface import HTTPInterface, SERVER_PORT
@ -34,7 +34,7 @@ class TestHTTPInterface(unittest.TestCase):
self.statemachine.set_state("test", "INIT_STATE")
self.eventbus.fire(Event(EVENT_START))
self.eventbus.fire(EVENT_START)
# Give objects time to startup
time.sleep(1)
@ -48,7 +48,7 @@ class TestHTTPInterface(unittest.TestCase):
@classmethod
def tearDownClass(cls): # pylint: disable=invalid-name
""" things to be run when tests are done. """
cls.eventbus.fire(Event(EVENT_SHUTDOWN))
cls.eventbus.fire(EVENT_SHUTDOWN)
time.sleep(1)