Events and States are no longer dicts but objects.

This commit is contained in:
Paulus Schoutsen
2014-01-19 19:10:40 -08:00
parent ae2058de70
commit 3c3e7e5825
9 changed files with 212 additions and 129 deletions

View File

@@ -2,19 +2,21 @@
homeassistant homeassistant
~~~~~~~~~~~~~ ~~~~~~~~~~~~~
Module to control the lights based on devices at home and the state of the sun. Home Assistant is a Home Automation framework for observing the state
of objects and react to changes.
""" """
import time import time
import logging import logging
import threading import threading
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from datetime import datetime import datetime as dt
import homeassistant.util as util
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
ALL_EVENTS = '*' MATCH_ALL = '*'
DOMAIN = "homeassistant" DOMAIN = "homeassistant"
@@ -38,8 +40,6 @@ TIMER_INTERVAL = 10 # seconds
# every minute. # every minute.
assert 60 % TIMER_INTERVAL == 0, "60 % TIMER_INTERVAL should be 0!" assert 60 % TIMER_INTERVAL == 0, "60 % TIMER_INTERVAL should be 0!"
DATE_STR_FORMAT = "%H:%M:%S %d-%m-%Y"
def start_home_assistant(bus): def start_home_assistant(bus):
""" Start home assistant. """ """ Start home assistant. """
@@ -60,37 +60,22 @@ def start_home_assistant(bus):
break break
def datetime_to_str(dattim): def _process_match_param(parameter):
""" Converts datetime to a string format. """ Wraps parameter in a list if it is not one and returns it. """
if parameter is None:
@rtype : str return MATCH_ALL
""" elif isinstance(parameter, list):
return dattim.strftime(DATE_STR_FORMAT) return parameter
else:
return [parameter]
def str_to_datetime(dt_str):
""" Converts a string to a datetime object.
@rtype: datetime
"""
return datetime.strptime(dt_str, DATE_STR_FORMAT)
def _ensure_list(parameter):
""" Wraps parameter in a list if it is not one and returns it.
@rtype : list
"""
return parameter if isinstance(parameter, list) else [parameter]
def _matcher(subject, pattern): def _matcher(subject, pattern):
""" Returns True if subject matches the pattern. """ Returns True if subject matches the pattern.
Pattern is either a list of allowed subjects or a '*'. Pattern is either a list of allowed subjects or a `MATCH_ALL`.
@rtype : bool
""" """
return '*' in pattern or subject in pattern return MATCH_ALL == pattern or subject in pattern
def split_state_category(category): def split_state_category(category):
@@ -98,36 +83,26 @@ def split_state_category(category):
return category.split(".", 1) return category.split(".", 1)
def filter_categories(categories, domain_filter=None, object_id_only=False): def filter_categories(categories, domain_filter=None, strip_domain=False):
""" Filter a list of categories based on domain. Setting object_id_only """ Filter a list of categories based on domain. Setting strip_domain
will only return the object_ids. """ will only return the object_ids. """
return [ return [
split_state_category(cat)[1] if object_id_only else cat split_state_category(cat)[1] if strip_domain else cat
for cat in categories if for cat in categories if
not domain_filter or cat.startswith(domain_filter) not domain_filter or cat.startswith(domain_filter)
] ]
def create_state(state, attributes=None, last_changed=None):
""" Creates a new state and initializes defaults where necessary. """
attributes = attributes or {}
last_changed = last_changed or datetime.now()
return {'state': state,
'attributes': attributes,
'last_changed': datetime_to_str(last_changed)}
def track_state_change(bus, category, action, from_state=None, to_state=None): def track_state_change(bus, category, action, from_state=None, to_state=None):
""" Helper method to track specific state changes. """ """ Helper method to track specific state changes. """
from_state = _ensure_list(from_state) if from_state else [ALL_EVENTS] from_state = _process_match_param(from_state)
to_state = _ensure_list(to_state) if to_state else [ALL_EVENTS] to_state = _process_match_param(to_state)
def listener(event): def listener(event):
""" State change listener that listens for specific state changes. """ """ State change listener that listens for specific state changes. """
if category == event.data['category'] and \ if category == event.data['category'] and \
_matcher(event.data['old_state']['state'], from_state) and \ _matcher(event.data['old_state'].state, from_state) and \
_matcher(event.data['new_state']['state'], to_state): _matcher(event.data['new_state'].state, to_state):
action(event.data['category'], action(event.data['category'],
event.data['old_state'], event.data['old_state'],
@@ -138,19 +113,19 @@ def track_state_change(bus, category, action, from_state=None, to_state=None):
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def track_time_change(bus, action, def track_time_change(bus, action,
year='*', month='*', day='*', year=None, month=None, day=None,
hour='*', minute='*', second='*', hour=None, minute=None, second=None,
point_in_time=None, listen_once=False): point_in_time=None, listen_once=False):
""" Adds a listener that will listen for a specified or matching time. """ """ Adds a listener that will listen for a specified or matching time. """
year, month = _ensure_list(year), _ensure_list(month) year, month = _process_match_param(year), _process_match_param(month)
day = _ensure_list(day) day = _process_match_param(day)
hour, minute = _ensure_list(hour), _ensure_list(minute) hour, minute = _process_match_param(hour), _process_match_param(minute)
second = _ensure_list(second) second = _process_match_param(second)
def listener(event): def listener(event):
""" Listens for matching time_changed events. """ """ Listens for matching time_changed events. """
now = str_to_datetime(event.data['now']) now = event.data['now']
if (point_in_time and now > point_in_time) or \ if (point_in_time and now > point_in_time) or \
(not point_in_time and (not point_in_time and
@@ -180,7 +155,7 @@ class Bus(object):
""" """
def __init__(self): def __init__(self):
self._event_listeners = defaultdict(list) self._event_listeners = {}
self._services = {} self._services = {}
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@@ -196,8 +171,7 @@ class Bus(object):
of listeners. of listeners.
""" """
return {key: len(self._event_listeners[key]) return {key: len(self._event_listeners[key])
for key in self._event_listeners.keys() for key in self._event_listeners}
if len(self._event_listeners[key]) > 0}
def call_service(self, domain, service, service_data=None): def call_service(self, domain, service, service_data=None):
""" Calls a service. """ """ Calls a service. """
@@ -236,8 +210,16 @@ class Bus(object):
def fire_event(self, event_type, event_data=None): def fire_event(self, event_type, event_data=None):
""" Fire an event. """ """ Fire an event. """
if not event_data: # Copy the list of the current listeners because some listeners
event_data = {} # choose to remove themselves as a listener while being executed
# which causes the iterator to be confused.
listeners = self._event_listeners.get(MATCH_ALL, []) + \
self._event_listeners.get(event_type, [])
if not listeners:
return
event_data = event_data or {}
self.logger.info("Bus:Event {}: {}".format( self.logger.info("Bus:Event {}: {}".format(
event_type, event_data)) event_type, event_data))
@@ -246,10 +228,7 @@ class Bus(object):
""" Fire listeners for event. """ """ Fire listeners for event. """
event = Event(self, event_type, event_data) event = Event(self, event_type, event_data)
# We do not use itertools.chain() because some listeners might for listener in listeners:
# choose to remove themselves as a listener while being executed
for listener in self._event_listeners[ALL_EVENTS] + \
self._event_listeners[event.event_type]:
try: try:
listener(event) listener(event)
@@ -262,15 +241,19 @@ class Bus(object):
def listen_event(self, event_type, listener): def listen_event(self, event_type, listener):
""" Listen for all events or events of a specific type. """ Listen for all events or events of a specific type.
To listen to all events specify the constant ``ALL_EVENTS`` To listen to all events specify the constant ``MATCH_ALL``
as event_type. as event_type.
""" """
self._event_listeners[event_type].append(listener) try:
self._event_listeners[event_type].append(listener)
except KeyError: # event_type did not exist
self._event_listeners[event_type] = [listener]
def listen_once_event(self, event_type, listener): def listen_once_event(self, event_type, listener):
""" Listen once for event of a specific type. """ Listen once for event of a specific type.
To listen to all events specify the constant ``ALL_EVENTS`` To listen to all events specify the constant ``MATCH_ALL``
as event_type. as event_type.
Note: at the moment it is impossible to remove a one time listener. Note: at the moment it is impossible to remove a one time listener.
@@ -292,10 +275,67 @@ class Bus(object):
if len(self._event_listeners[event_type]) == 0: if len(self._event_listeners[event_type]) == 0:
del self._event_listeners[event_type] del self._event_listeners[event_type]
except ValueError: except (KeyError, ValueError):
pass pass
class State(object):
""" Object to represent a state within the state machine. """
def __init__(self, state, attributes=None, last_changed=None):
self.state = state
self.attributes = attributes or {}
last_changed = last_changed or dt.datetime.now()
# Strip microsecond from last_changed else we cannot guarantee
# state == State.from_json_dict(state.to_json_dict())
# This behavior occurs because to_json_dict strips microseconds
if last_changed.microsecond:
self.last_changed = last_changed - dt.timedelta(
microseconds=last_changed.microsecond)
else:
self.last_changed = last_changed
def to_json_dict(self, category=None):
""" Converts State to a dict to be used within JSON.
Ensures: state == State.from_json_dict(state.to_json_dict()) """
json_dict = {'state': self.state,
'attributes': self.attributes,
'last_changed': util.datetime_to_str(self.last_changed)}
if category:
json_dict['category'] = category
return json_dict
def copy(self):
""" Creates a copy of itself. """
return State(self.state, dict(self.attributes), self.last_changed)
@staticmethod
def from_json_dict(json_dict):
""" Static method to create a state from a dict.
Ensures: state == State.from_json_dict(state.to_json_dict()) """
try:
last_changed = json_dict.get('last_changed')
if last_changed:
last_changed = util.str_to_datetime(last_changed)
return State(json_dict['state'],
json_dict.get('attributes'),
last_changed)
except KeyError: # if key 'state' did not exist
return None
def __repr__(self):
return "{}({}, {})".format(
self.state, self.attributes,
util.datetime_to_str(self.last_changed))
class StateMachine(object): class StateMachine(object):
""" Helper class that tracks the state of different categories. """ """ Helper class that tracks the state of different categories. """
@@ -333,16 +373,16 @@ class StateMachine(object):
# Add category if it does not exist # Add category if it does not exist
if category not in self.states: if category not in self.states:
self.states[category] = create_state(new_state, attributes) self.states[category] = State(new_state, attributes)
# Change state and fire listeners # Change state and fire listeners
else: else:
old_state = self.states[category] old_state = self.states[category]
if old_state['state'] != new_state or \ if old_state.state != new_state or \
old_state['attributes'] != attributes: old_state.attributes != attributes:
self.states[category] = create_state(new_state, attributes) self.states[category] = State(new_state, attributes)
self.bus.fire_event(EVENT_STATE_CHANGED, self.bus.fire_event(EVENT_STATE_CHANGED,
{'category': category, {'category': category,
@@ -356,7 +396,7 @@ class StateMachine(object):
the state of the specified category. """ the state of the specified category. """
try: try:
# Make a copy so people won't mutate the state # Make a copy so people won't mutate the state
return dict(self.states[category]) return self.states[category].copy()
except KeyError: except KeyError:
# If category does not exist # If category does not exist
@@ -366,7 +406,7 @@ class StateMachine(object):
""" Returns True if category exists and is specified state. """ """ Returns True if category exists and is specified state. """
cur_state = self.get_state(category) cur_state = self.get_state(category)
return cur_state and cur_state['state'] == state return cur_state and cur_state.state == state
class Timer(threading.Thread): class Timer(threading.Thread):
@@ -389,7 +429,7 @@ class Timer(threading.Thread):
last_fired_on_second = -1 last_fired_on_second = -1
while True: while True:
now = datetime.now() now = dt.datetime.now()
# First check checks if we are not on a second matching the # First check checks if we are not on a second matching the
# timer interval. Second check checks if we did not already fire # timer interval. Second check checks if we did not already fire
@@ -407,12 +447,12 @@ class Timer(threading.Thread):
time.sleep(slp_seconds) time.sleep(slp_seconds)
now = datetime.now() now = dt.datetime.now()
last_fired_on_second = now.second last_fired_on_second = now.second
self.bus.fire_event(EVENT_TIME_CHANGED, self.bus.fire_event(EVENT_TIME_CHANGED,
{'now': datetime_to_str(now)}) {'now': now})
class HomeAssistantException(Exception): class HomeAssistantException(Exception):

View File

@@ -36,10 +36,10 @@ def turn_off(statemachine, cc_id=None):
state = statemachine.get_state(cat) state = statemachine.get_state(cat)
if state and \ if state and \
state['state'] != STATE_NO_APP or \ state.state != STATE_NO_APP or \
state['state'] != pychromecast.APP_ID_HOME: state.state != pychromecast.APP_ID_HOME:
pychromecast.quit_app(state['attributes'][ATTR_HOST]) pychromecast.quit_app(state.attributes[ATTR_HOST])
def setup(bus, statemachine, host): def setup(bus, statemachine, host):

View File

@@ -92,7 +92,7 @@ def setup(bus, statemachine, light_group=None):
# Specific device came home ? # Specific device came home ?
if (category != device_tracker.STATE_CATEGORY_ALL_DEVICES and if (category != device_tracker.STATE_CATEGORY_ALL_DEVICES and
new_state['state'] == ha.STATE_HOME): new_state.state == ha.STATE_HOME):
# These variables are needed for the elif check # These variables are needed for the elif check
now = datetime.now() now = datetime.now()
@@ -128,7 +128,7 @@ def setup(bus, statemachine, light_group=None):
# Did all devices leave the house? # Did all devices leave the house?
elif (category == device_tracker.STATE_CATEGORY_ALL_DEVICES and elif (category == device_tracker.STATE_CATEGORY_ALL_DEVICES and
new_state['state'] == ha.STATE_NOT_HOME and lights_are_on): new_state.state == ha.STATE_NOT_HOME and lights_are_on):
logger.info( logger.info(
"Everyone has left but there are devices on. Turning them off") "Everyone has left but there are devices on. Turning them off")

View File

@@ -35,12 +35,11 @@ def is_on(statemachine, group):
state = statemachine.get_state(group) state = statemachine.get_state(group)
if state: if state:
group_type = _get_group_type(state['state']) group_type = _get_group_type(state.state)
if group_type: if group_type:
group_on = _GROUP_TYPES[group_type][0] # We found group_type, compare to ON-state
return state.state == _GROUP_TYPES[group_type][0]
return state['state'] == group_on
else: else:
return False return False
else: else:
@@ -51,7 +50,7 @@ def get_categories(statemachine, group):
""" Get the categories that make up this group. """ """ Get the categories that make up this group. """
state = statemachine.get_state(group) state = statemachine.get_state(group)
return state['attributes'][STATE_ATTR_CATEGORIES] if state else [] return state.attributes[STATE_ATTR_CATEGORIES] if state else []
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
@@ -73,7 +72,7 @@ def setup(bus, statemachine, name, categories):
# Try to determine group type if we didn't yet # Try to determine group type if we didn't yet
if not group_type and state: if not group_type and state:
group_type = _get_group_type(state['state']) group_type = _get_group_type(state.state)
if group_type: if group_type:
group_on, group_off = _GROUP_TYPES[group_type] group_on, group_off = _GROUP_TYPES[group_type]
@@ -82,7 +81,7 @@ def setup(bus, statemachine, name, categories):
else: else:
# We did not find a matching group_type # We did not find a matching group_type
errors.append("Found unexpected state '{}'".format( errors.append("Found unexpected state '{}'".format(
name, state['state'])) name, state.state))
break break
@@ -91,13 +90,13 @@ def setup(bus, statemachine, name, categories):
errors.append("Category {} does not exist".format(cat)) errors.append("Category {} does not exist".format(cat))
# Check if category is valid state # Check if category is valid state
elif state['state'] != group_off and state['state'] != group_on: elif state.state != group_off and state.state != group_on:
errors.append("State of {} is {} (expected: {}, {})".format( errors.append("State of {} is {} (expected: {}, {})".format(
cat, state['state'], group_off, group_on)) cat, state.state, group_off, group_on))
# Keep track of the group state to init later on # Keep track of the group state to init later on
elif group_state == group_off and state['state'] == group_on: elif group_state == group_off and state.state == group_on:
group_state = group_on group_state = group_on
if errors: if errors:
@@ -114,17 +113,17 @@ def setup(bus, statemachine, name, categories):
""" Updates the group state based on a state change by a tracked """ Updates the group state based on a state change by a tracked
category. """ category. """
cur_group_state = statemachine.get_state(group_cat)['state'] cur_group_state = statemachine.get_state(group_cat).state
# if cur_group_state = OFF and new_state = ON: set ON # if cur_group_state = OFF and new_state = ON: set ON
# if cur_group_state = ON and new_state = OFF: research # if cur_group_state = ON and new_state = OFF: research
# else: ignore # else: ignore
if cur_group_state == group_off and new_state['state'] == group_on: if cur_group_state == group_off and new_state.state == group_on:
statemachine.set_state(group_cat, group_on, state_attr) statemachine.set_state(group_cat, group_on, state_attr)
elif cur_group_state == group_on and new_state['state'] == group_off: elif cur_group_state == group_on and new_state.state == group_off:
# Check if any of the other states is still on # Check if any of the other states is still on
if not any([statemachine.is_state(cat, group_on) if not any([statemachine.is_state(cat, group_on)

View File

@@ -341,16 +341,16 @@ class RequestHandler(BaseHTTPRequestHandler):
state = self.server.statemachine.get_state(category) state = self.server.statemachine.get_state(category)
attributes = "<br>".join( attributes = "<br>".join(
["{}: {}".format(attr, state['attributes'][attr]) ["{}: {}".format(attr, state.attributes[attr])
for attr in state['attributes']]) for attr in state.attributes])
write(("<tr>" write(("<tr>"
"<td>{}</td><td>{}</td><td>{}</td><td>{}</td>" "<td>{}</td><td>{}</td><td>{}</td><td>{}</td>"
"</tr>").format( "</tr>").format(
category, category,
state['state'], state.state,
attributes, attributes,
state['last_changed'])) state.last_changed))
# Change state form # Change state form
write(("<tr><td><input name='category' class='form-control' " write(("<tr><td><input name='category' class='form-control' "
@@ -518,9 +518,8 @@ class RequestHandler(BaseHTTPRequestHandler):
if self.use_json: if self.use_json:
state = self.server.statemachine.get_state(category) state = self.server.statemachine.get_state(category)
state['category'] = category self._write_json(state.to_json_dict(category),
status_code=HTTP_CREATED,
self._write_json(state, status_code=HTTP_CREATED,
location= location=
URL_API_STATES_CATEGORY.format(category)) URL_API_STATES_CATEGORY.format(category))
else: else:
@@ -619,10 +618,7 @@ class RequestHandler(BaseHTTPRequestHandler):
state = self.server.statemachine.get_state(category) state = self.server.statemachine.get_state(category)
if state: if state:
state['category'] = category self._write_json(state.to_json_dict(category))
self._write_json(state)
else: else:
# If category does not exist # If category does not exist
self._message("State does not exist.", HTTP_UNPROCESSABLE_ENTITY) self._message("State does not exist.", HTTP_UNPROCESSABLE_ENTITY)

View File

@@ -8,6 +8,7 @@ import logging
from datetime import timedelta from datetime import timedelta
import homeassistant as ha import homeassistant as ha
import homeassistant.util as util
STATE_CATEGORY = "weather.sun" STATE_CATEGORY = "weather.sun"
@@ -27,16 +28,16 @@ def next_setting(statemachine):
""" Returns the datetime object representing the next sun setting. """ """ Returns the datetime object representing the next sun setting. """
state = statemachine.get_state(STATE_CATEGORY) state = statemachine.get_state(STATE_CATEGORY)
return None if not state else ha.str_to_datetime( return None if not state else util.str_to_datetime(
state['attributes'][STATE_ATTR_NEXT_SETTING]) state.attributes[STATE_ATTR_NEXT_SETTING])
def next_rising(statemachine): def next_rising(statemachine):
""" Returns the datetime object representing the next sun setting. """ """ Returns the datetime object representing the next sun setting. """
state = statemachine.get_state(STATE_CATEGORY) state = statemachine.get_state(STATE_CATEGORY)
return None if not state else ha.str_to_datetime( return None if not state else util.str_to_datetime(
state['attributes'][STATE_ATTR_NEXT_RISING]) state.attributes[STATE_ATTR_NEXT_RISING])
def setup(bus, statemachine, latitude, longitude): def setup(bus, statemachine, latitude, longitude):
@@ -74,8 +75,8 @@ def setup(bus, statemachine, latitude, longitude):
next_change.strftime("%H:%M"))) next_change.strftime("%H:%M")))
state_attributes = { state_attributes = {
STATE_ATTR_NEXT_RISING: ha.datetime_to_str(next_rising_dt), STATE_ATTR_NEXT_RISING: util.datetime_to_str(next_rising_dt),
STATE_ATTR_NEXT_SETTING: ha.datetime_to_str(next_setting_dt) STATE_ATTR_NEXT_SETTING: util.datetime_to_str(next_setting_dt)
} }
statemachine.set_state(STATE_CATEGORY, new_state, state_attributes) statemachine.set_state(STATE_CATEGORY, new_state, state_attributes)

View File

@@ -49,6 +49,18 @@ def _setup_call_api(host, port, api_password):
return _call_api return _call_api
class JSONEncoder(json.JSONEncoder):
""" JSONEncoder that supports Home Assistant objects. """
def default(self, obj): # pylint: disable=method-hidden
""" Checks if Home Assistat object and encodes if possible.
Else hand it off to original method. """
if isinstance(obj, ha.State):
return obj.to_json_dict()
return json.JSONEncoder.default(self, obj)
class Bus(ha.Bus): class Bus(ha.Bus):
""" Drop-in replacement for a normal bus that will forward interaction to """ Drop-in replacement for a normal bus that will forward interaction to
a remote bus. a remote bus.
@@ -140,7 +152,10 @@ class Bus(ha.Bus):
def fire_event(self, event_type, event_data=None): def fire_event(self, event_type, event_data=None):
""" Fire an event. """ """ Fire an event. """
data = {'event_data': json.dumps(event_data)} if event_data else None if event_data:
data = {'event_data': json.dumps(event_data, cls=JSONEncoder)}
else:
data = None
req = self._call_api(METHOD_POST, req = self._call_api(METHOD_POST,
hah.URL_API_EVENTS_EVENT.format(event_type), hah.URL_API_EVENTS_EVENT.format(event_type),
@@ -159,6 +174,12 @@ class Bus(ha.Bus):
Will throw NotImplementedError. """ Will throw NotImplementedError. """
raise NotImplementedError raise NotImplementedError
def listen_once_event(self, event_type, listener):
""" Not implemented for remote bus.
Will throw NotImplementedError. """
raise NotImplementedError
def remove_event_listener(self, event_type, listener): def remove_event_listener(self, event_type, listener):
""" Not implemented for remote bus. """ Not implemented for remote bus.
@@ -201,6 +222,13 @@ class StateMachine(ha.StateMachine):
self.logger.exception("StateMachine:Got unexpected result (2)") self.logger.exception("StateMachine:Got unexpected result (2)")
return [] return []
def remove_category(self, category):
""" This method is not implemented for remote statemachine.
Throws NotImplementedError. """
raise NotImplementedError
def set_state(self, category, new_state, attributes=None): def set_state(self, category, new_state, attributes=None):
""" Set the state of a category, add category if it does not exist. """ Set the state of a category, add category if it does not exist.
@@ -243,9 +271,7 @@ class StateMachine(ha.StateMachine):
if req.status_code == 200: if req.status_code == 200:
data = req.json() data = req.json()
return ha.create_state(data['state'], data['attributes'], return ha.State.from_json_dict(data)
ha.str_to_datetime(
data['last_changed']))
elif req.status_code == 422: elif req.status_code == 422:
# Category does not exist # Category does not exist

View File

@@ -96,7 +96,7 @@ class TestHTTPInterface(unittest.TestCase):
"new_state": "debug_state_change2", "new_state": "debug_state_change2",
"api_password": API_PASSWORD}) "api_password": API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test.test")['state'], self.assertEqual(self.statemachine.get_state("test.test").state,
"debug_state_change2") "debug_state_change2")
def test_debug_fire_event(self): def test_debug_fire_event(self):
@@ -138,14 +138,13 @@ class TestHTTPInterface(unittest.TestCase):
_url(hah.URL_API_STATES_CATEGORY.format("test")), _url(hah.URL_API_STATES_CATEGORY.format("test")),
data={"api_password": API_PASSWORD}) data={"api_password": API_PASSWORD})
data = req.json() data = ha.State.from_json_dict(req.json())
state = self.statemachine.get_state("test") state = self.statemachine.get_state("test")
self.assertEqual(data['category'], "test") self.assertEqual(data.state, state.state)
self.assertEqual(data['state'], state['state']) self.assertEqual(data.last_changed, state.last_changed)
self.assertEqual(data['last_changed'], state['last_changed']) self.assertEqual(data.attributes, state.attributes)
self.assertEqual(data['attributes'], state['attributes'])
def test_api_get_non_existing_state(self): def test_api_get_non_existing_state(self):
""" Test if the debug interface allows us to get a state. """ """ Test if the debug interface allows us to get a state. """
@@ -164,7 +163,7 @@ class TestHTTPInterface(unittest.TestCase):
data={"new_state": "debug_state_change2", data={"new_state": "debug_state_change2",
"api_password": API_PASSWORD}) "api_password": API_PASSWORD})
self.assertEqual(self.statemachine.get_state("test.test")['state'], self.assertEqual(self.statemachine.get_state("test.test").state,
"debug_state_change2") "debug_state_change2")
# pylint: disable=invalid-name # pylint: disable=invalid-name
@@ -181,7 +180,7 @@ class TestHTTPInterface(unittest.TestCase):
"api_password": API_PASSWORD}) "api_password": API_PASSWORD})
cur_state = (self.statemachine. cur_state = (self.statemachine.
get_state("test_category_that_does_not_exist")['state']) get_state("test_category_that_does_not_exist").state)
self.assertEqual(req.status_code, 201) self.assertEqual(req.status_code, 201)
self.assertEqual(cur_state, new_state) self.assertEqual(cur_state, new_state)
@@ -339,9 +338,9 @@ class TestRemote(unittest.TestCase):
state = self.statemachine.get_state("test") state = self.statemachine.get_state("test")
self.assertEqual(remote_state['state'], state['state']) self.assertEqual(remote_state.state, state.state)
self.assertEqual(remote_state['last_changed'], state['last_changed']) self.assertEqual(remote_state.last_changed, state.last_changed)
self.assertEqual(remote_state['attributes'], state['attributes']) self.assertEqual(remote_state.attributes, state.attributes)
def test_remote_sm_get_non_existing_state(self): def test_remote_sm_get_non_existing_state(self):
""" Test if the debug interface allows us to list state categories. """ """ Test if the debug interface allows us to list state categories. """
@@ -354,8 +353,8 @@ class TestRemote(unittest.TestCase):
state = self.statemachine.get_state("test") state = self.statemachine.get_state("test")
self.assertEqual(state['state'], "set_remotely") self.assertEqual(state.state, "set_remotely")
self.assertEqual(state['attributes']['test'], 1) self.assertEqual(state.attributes['test'], 1)
def test_remote_eb_listening_for_same(self): def test_remote_eb_listening_for_same(self):
""" Test if remote EB correctly reports listener overview. """ """ Test if remote EB correctly reports listener overview. """

View File

@@ -1,10 +1,13 @@
""" Helper methods for various modules. """ """ Helper methods for various modules. """
import datetime
import re import re
RE_SANITIZE_FILENAME = re.compile(r"(~|(\.\.)|/|\+)") RE_SANITIZE_FILENAME = re.compile(r"(~|(\.\.)|/|\+)")
RE_SLUGIFY = re.compile(r'[^A-Za-z0-9_]+') RE_SLUGIFY = re.compile(r'[^A-Za-z0-9_]+')
DATE_STR_FORMAT = "%H:%M:%S %d-%m-%Y"
def sanitize_filename(filename): def sanitize_filename(filename):
""" Sanitizes a filename by removing .. / and \\. """ """ Sanitizes a filename by removing .. / and \\. """
@@ -16,3 +19,22 @@ def slugify(text):
text = text.strip().replace(" ", "_") text = text.strip().replace(" ", "_")
return RE_SLUGIFY.sub("", text) return RE_SLUGIFY.sub("", text)
def datetime_to_str(dattim):
""" Converts datetime to a string format.
@rtype : str
"""
return dattim.strftime(DATE_STR_FORMAT)
def str_to_datetime(dt_str):
""" Converts a string to a datetime object.
@rtype: datetime
"""
try:
return datetime.datetime.strptime(dt_str, DATE_STR_FORMAT)
except ValueError: # If dt_str did not match our format
return None