Spread async love (#3575)

* Convert Entity.update_ha_state to be async

* Make Service.call async

* Update entity.py

* Add Entity.async_update

* Make automation zone trigger async

* Fix linting

* Reduce flakiness in hass.block_till_done

* Make automation.numeric_state async

* Make mqtt.subscribe async

* Make automation.mqtt async

* Make automation.time async

* Make automation.sun async

* Add async_track_point_in_utc_time

* Make helpers.track_sunrise/set async

* Add async_track_state_change

* Make automation.state async

* Clean up helpers/entity.py tests

* Lint

* Lint

* Core.is_state and Core.is_state_attr are async friendly

* Lint

* Lint
This commit is contained in:
Paulus Schoutsen
2016-09-30 12:57:24 -07:00
committed by GitHub
parent 7e50ccd32a
commit b650b2b0db
17 changed files with 323 additions and 151 deletions

View File

@@ -4,6 +4,7 @@ Offer MQTT listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#mqtt-trigger at https://home-assistant.io/components/automation/#mqtt-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
import homeassistant.components.mqtt as mqtt import homeassistant.components.mqtt as mqtt
@@ -26,10 +27,11 @@ def trigger(hass, config, action):
topic = config.get(CONF_TOPIC) topic = config.get(CONF_TOPIC)
payload = config.get(CONF_PAYLOAD) payload = config.get(CONF_PAYLOAD)
@asyncio.coroutine
def mqtt_automation_listener(msg_topic, msg_payload, qos): def mqtt_automation_listener(msg_topic, msg_payload, qos):
"""Listen for MQTT messages.""" """Listen for MQTT messages."""
if payload is None or payload == msg_payload: if payload is None or payload == msg_payload:
action({ hass.async_add_job(action, {
'trigger': { 'trigger': {
'platform': 'mqtt', 'platform': 'mqtt',
'topic': msg_topic, 'topic': msg_topic,

View File

@@ -4,6 +4,7 @@ Offer numeric state listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#numeric-state-trigger at https://home-assistant.io/components/automation/#numeric-state-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -34,7 +35,7 @@ def trigger(hass, config, action):
if value_template is not None: if value_template is not None:
value_template.hass = hass value_template.hass = hass
# pylint: disable=unused-argument @asyncio.coroutine
def state_automation_listener(entity, from_s, to_s): def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
if to_s is None: if to_s is None:
@@ -50,19 +51,19 @@ def trigger(hass, config, action):
} }
# If new one doesn't match, nothing to do # If new one doesn't match, nothing to do
if not condition.numeric_state( if not condition.async_numeric_state(
hass, to_s, below, above, value_template, variables): hass, to_s, below, above, value_template, variables):
return return
# Only match if old didn't exist or existed but didn't match # Only match if old didn't exist or existed but didn't match
# Written as: skip if old one did exist and matched # Written as: skip if old one did exist and matched
if from_s is not None and condition.numeric_state( if from_s is not None and condition.async_numeric_state(
hass, from_s, below, above, value_template, variables): hass, from_s, below, above, value_template, variables):
return return
variables['trigger']['from_state'] = from_s variables['trigger']['from_state'] = from_s
variables['trigger']['to_state'] = to_s variables['trigger']['to_state'] = to_s
action(variables) hass.async_add_job(action, variables)
return track_state_change(hass, entity_id, state_automation_listener) return track_state_change(hass, entity_id, state_automation_listener)

View File

@@ -4,12 +4,15 @@ Offer state listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#state-trigger at https://home-assistant.io/components/automation/#state-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.const import MATCH_ALL, CONF_PLATFORM from homeassistant.const import MATCH_ALL, CONF_PLATFORM
from homeassistant.helpers.event import track_state_change, track_point_in_time from homeassistant.helpers.event import (
async_track_state_change, async_track_point_in_utc_time)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.async import run_callback_threadsafe
CONF_ENTITY_ID = "entity_id" CONF_ENTITY_ID = "entity_id"
CONF_FROM = "from" CONF_FROM = "from"
@@ -38,16 +41,17 @@ def trigger(hass, config, action):
from_state = config.get(CONF_FROM, MATCH_ALL) from_state = config.get(CONF_FROM, MATCH_ALL)
to_state = config.get(CONF_TO) or config.get(CONF_STATE) or MATCH_ALL to_state = config.get(CONF_TO) or config.get(CONF_STATE) or MATCH_ALL
time_delta = config.get(CONF_FOR) time_delta = config.get(CONF_FOR)
remove_state_for_cancel = None async_remove_state_for_cancel = None
remove_state_for_listener = None async_remove_state_for_listener = None
@asyncio.coroutine
def state_automation_listener(entity, from_s, to_s): def state_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
nonlocal remove_state_for_cancel, remove_state_for_listener nonlocal async_remove_state_for_cancel, async_remove_state_for_listener
def call_action(): def call_action():
"""Call action with right context.""" """Call action with right context."""
action({ hass.async_add_job(action, {
'trigger': { 'trigger': {
'platform': 'state', 'platform': 'state',
'entity_id': entity, 'entity_id': entity,
@@ -61,35 +65,41 @@ def trigger(hass, config, action):
call_action() call_action()
return return
@asyncio.coroutine
def state_for_listener(now): def state_for_listener(now):
"""Fire on state changes after a delay and calls action.""" """Fire on state changes after a delay and calls action."""
remove_state_for_cancel() async_remove_state_for_cancel()
call_action() call_action()
@asyncio.coroutine
def state_for_cancel_listener(entity, inner_from_s, inner_to_s): def state_for_cancel_listener(entity, inner_from_s, inner_to_s):
"""Fire on changes and cancel for listener if changed.""" """Fire on changes and cancel for listener if changed."""
if inner_to_s.state == to_s.state: if inner_to_s.state == to_s.state:
return return
remove_state_for_listener() async_remove_state_for_listener()
remove_state_for_cancel() async_remove_state_for_cancel()
remove_state_for_listener = track_point_in_time( async_remove_state_for_listener = async_track_point_in_utc_time(
hass, state_for_listener, dt_util.utcnow() + time_delta) hass, state_for_listener, dt_util.utcnow() + time_delta)
remove_state_for_cancel = track_state_change( async_remove_state_for_cancel = async_track_state_change(
hass, entity, state_for_cancel_listener) hass, entity, state_for_cancel_listener)
unsub = track_state_change(hass, entity_id, state_automation_listener, unsub = async_track_state_change(
from_state, to_state) hass, entity_id, state_automation_listener, from_state, to_state)
def async_remove():
"""Remove state listeners async."""
unsub()
# pylint: disable=not-callable
if async_remove_state_for_cancel is not None:
async_remove_state_for_cancel()
if async_remove_state_for_listener is not None:
async_remove_state_for_listener()
def remove(): def remove():
"""Remove state listeners.""" """Remove state listeners."""
unsub() run_callback_threadsafe(hass.loop, async_remove).result()
# pylint: disable=not-callable
if remove_state_for_cancel is not None:
remove_state_for_cancel()
if remove_state_for_listener is not None:
remove_state_for_listener()
return remove return remove

View File

@@ -4,6 +4,7 @@ Offer sun based automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#sun-trigger at https://home-assistant.io/components/automation/#sun-trigger
""" """
import asyncio
from datetime import timedelta from datetime import timedelta
import logging import logging
@@ -30,9 +31,10 @@ def trigger(hass, config, action):
event = config.get(CONF_EVENT) event = config.get(CONF_EVENT)
offset = config.get(CONF_OFFSET) offset = config.get(CONF_OFFSET)
@asyncio.coroutine
def call_action(): def call_action():
"""Call action with right context.""" """Call action with right context."""
action({ hass.async_add_job(action, {
'trigger': { 'trigger': {
'platform': 'sun', 'platform': 'sun',
'event': event, 'event': event,

View File

@@ -4,6 +4,7 @@ Offer time listening automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#time-trigger at https://home-assistant.io/components/automation/#time-trigger
""" """
import asyncio
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -38,9 +39,10 @@ def trigger(hass, config, action):
minutes = config.get(CONF_MINUTES) minutes = config.get(CONF_MINUTES)
seconds = config.get(CONF_SECONDS) seconds = config.get(CONF_SECONDS)
@asyncio.coroutine
def time_automation_listener(now): def time_automation_listener(now):
"""Listen for time changes and calls action.""" """Listen for time changes and calls action."""
action({ hass.async_add_job(action, {
'trigger': { 'trigger': {
'platform': 'time', 'platform': 'time',
'now': now, 'now': now,

View File

@@ -4,6 +4,7 @@ Offer zone automation rules.
For more details about this automation rule, please refer to the documentation For more details about this automation rule, please refer to the documentation
at https://home-assistant.io/components/automation/#zone-trigger at https://home-assistant.io/components/automation/#zone-trigger
""" """
import asyncio
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
@@ -31,6 +32,7 @@ def trigger(hass, config, action):
zone_entity_id = config.get(CONF_ZONE) zone_entity_id = config.get(CONF_ZONE)
event = config.get(CONF_EVENT) event = config.get(CONF_EVENT)
@asyncio.coroutine
def zone_automation_listener(entity, from_s, to_s): def zone_automation_listener(entity, from_s, to_s):
"""Listen for state changes and calls action.""" """Listen for state changes and calls action."""
if from_s and not location.has_location(from_s) or \ if from_s and not location.has_location(from_s) or \
@@ -47,7 +49,7 @@ def trigger(hass, config, action):
# pylint: disable=too-many-boolean-expressions # pylint: disable=too-many-boolean-expressions
if event == EVENT_ENTER and not from_match and to_match or \ if event == EVENT_ENTER and not from_match and to_match or \
event == EVENT_LEAVE and from_match and not to_match: event == EVENT_LEAVE and from_match and not to_match:
action({ hass.async_add_job(action, {
'trigger': { 'trigger': {
'platform': 'zone', 'platform': 'zone',
'entity_id': entity, 'entity_id': entity,

View File

@@ -4,6 +4,7 @@ Event parser and human readable log generator.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/logbook/ https://home-assistant.io/components/logbook/
""" """
import asyncio
import logging import logging
from datetime import timedelta from datetime import timedelta
from itertools import groupby from itertools import groupby
@@ -20,6 +21,7 @@ from homeassistant.const import (EVENT_HOMEASSISTANT_START,
STATE_NOT_HOME, STATE_OFF, STATE_ON, STATE_NOT_HOME, STATE_OFF, STATE_ON,
ATTR_HIDDEN) ATTR_HIDDEN)
from homeassistant.core import State, split_entity_id, DOMAIN as HA_DOMAIN from homeassistant.core import State, split_entity_id, DOMAIN as HA_DOMAIN
from homeassistant.util.async import run_callback_threadsafe
DOMAIN = "logbook" DOMAIN = "logbook"
DEPENDENCIES = ['recorder', 'frontend'] DEPENDENCIES = ['recorder', 'frontend']
@@ -57,6 +59,13 @@ LOG_MESSAGE_SCHEMA = vol.Schema({
def log_entry(hass, name, message, domain=None, entity_id=None): def log_entry(hass, name, message, domain=None, entity_id=None):
"""Add an entry to the logbook."""
run_callback_threadsafe(
hass.loop, async_log_entry, hass, name, message, domain, entity_id
).result()
def async_log_entry(hass, name, message, domain=None, entity_id=None):
"""Add an entry to the logbook.""" """Add an entry to the logbook."""
data = { data = {
ATTR_NAME: name, ATTR_NAME: name,
@@ -67,11 +76,12 @@ def log_entry(hass, name, message, domain=None, entity_id=None):
data[ATTR_DOMAIN] = domain data[ATTR_DOMAIN] = domain
if entity_id is not None: if entity_id is not None:
data[ATTR_ENTITY_ID] = entity_id data[ATTR_ENTITY_ID] = entity_id
hass.bus.fire(EVENT_LOGBOOK_ENTRY, data) hass.bus.async_fire(EVENT_LOGBOOK_ENTRY, data)
def setup(hass, config): def setup(hass, config):
"""Listen for download events to download files.""" """Listen for download events to download files."""
@asyncio.coroutine
def log_message(service): def log_message(service):
"""Handle sending notification message service calls.""" """Handle sending notification message service calls."""
message = service.data[ATTR_MESSAGE] message = service.data[ATTR_MESSAGE]
@@ -80,8 +90,8 @@ def setup(hass, config):
entity_id = service.data.get(ATTR_ENTITY_ID) entity_id = service.data.get(ATTR_ENTITY_ID)
message.hass = hass message.hass = hass
message = message.render() message = message.async_render()
log_entry(hass, name, message, domain, entity_id) async_log_entry(hass, name, message, domain, entity_id)
hass.wsgi.register_view(LogbookView(hass, config)) hass.wsgi.register_view(LogbookView(hass, config))

View File

@@ -4,6 +4,7 @@ Support for MQTT message handling.
For more details about this component, please refer to the documentation at For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/ https://home-assistant.io/components/mqtt/
""" """
import asyncio
import logging import logging
import os import os
import socket import socket
@@ -11,6 +12,7 @@ import time
import voluptuous as vol import voluptuous as vol
from homeassistant.core import JobPriority
from homeassistant.bootstrap import prepare_setup_platform from homeassistant.bootstrap import prepare_setup_platform
from homeassistant.config import load_yaml_config_file from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -164,11 +166,20 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
def subscribe(hass, topic, callback, qos=DEFAULT_QOS): def subscribe(hass, topic, callback, qos=DEFAULT_QOS):
"""Subscribe to an MQTT topic.""" """Subscribe to an MQTT topic."""
@asyncio.coroutine
def mqtt_topic_subscriber(event): def mqtt_topic_subscriber(event):
"""Match subscribed MQTT topic.""" """Match subscribed MQTT topic."""
if _match_topic(topic, event.data[ATTR_TOPIC]): if not _match_topic(topic, event.data[ATTR_TOPIC]):
callback(event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], return
event.data[ATTR_QOS])
if asyncio.iscoroutinefunction(callback):
yield from callback(
event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD],
event.data[ATTR_QOS])
else:
hass.add_job(callback, event.data[ATTR_TOPIC],
event.data[ATTR_PAYLOAD], event.data[ATTR_QOS],
priority=JobPriority.EVENT_CALLBACK)
remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED,
mqtt_topic_subscriber) mqtt_topic_subscriber)

View File

@@ -248,12 +248,16 @@ class HomeAssistant(object):
def notify_when_done(): def notify_when_done():
"""Notify event loop when pool done.""" """Notify event loop when pool done."""
count = 0
while True: while True:
# Wait for the work queue to empty # Wait for the work queue to empty
self.pool.block_till_done() self.pool.block_till_done()
# Verify the loop is empty # Verify the loop is empty
if self._loop_empty(): if self._loop_empty():
count += 1
if count == 2:
break break
# sleep in the loop executor, this forces execution back into # sleep in the loop executor, this forces execution back into
@@ -675,40 +679,29 @@ class StateMachine(object):
return list(self._states.values()) return list(self._states.values())
def get(self, entity_id): def get(self, entity_id):
"""Retrieve state of entity_id or None if not found.""" """Retrieve state of entity_id or None if not found.
Async friendly.
"""
return self._states.get(entity_id.lower()) return self._states.get(entity_id.lower())
def is_state(self, entity_id, state): def is_state(self, entity_id, state):
"""Test if entity exists and is specified state."""
return run_callback_threadsafe(
self._loop, self.async_is_state, entity_id, state
).result()
def async_is_state(self, entity_id, state):
"""Test if entity exists and is specified state. """Test if entity exists and is specified state.
This method must be run in the event loop. Async friendly.
""" """
entity_id = entity_id.lower() state_obj = self.get(entity_id)
return (entity_id in self._states and return state_obj and state_obj.state == state
self._states[entity_id].state == state)
def is_state_attr(self, entity_id, name, value): def is_state_attr(self, entity_id, name, value):
"""Test if entity exists and has a state attribute set to value."""
return run_callback_threadsafe(
self._loop, self.async_is_state_attr, entity_id, name, value
).result()
def async_is_state_attr(self, entity_id, name, value):
"""Test if entity exists and has a state attribute set to value. """Test if entity exists and has a state attribute set to value.
This method must be run in the event loop. Async friendly.
""" """
entity_id = entity_id.lower() state_obj = self.get(entity_id)
return (entity_id in self._states and return state_obj and state_obj.attributes.get(name, None) == value
self._states[entity_id].attributes.get(name, None) == value)
def remove(self, entity_id): def remove(self, entity_id):
"""Remove the state of an entity. """Remove the state of an entity.
@@ -799,7 +792,8 @@ class StateMachine(object):
class Service(object): class Service(object):
"""Represents a callable service.""" """Represents a callable service."""
__slots__ = ['func', 'description', 'fields', 'schema'] __slots__ = ['func', 'description', 'fields', 'schema',
'iscoroutinefunction']
def __init__(self, func, description, fields, schema): def __init__(self, func, description, fields, schema):
"""Initialize a service.""" """Initialize a service."""
@@ -807,6 +801,7 @@ class Service(object):
self.description = description or '' self.description = description or ''
self.fields = fields or {} self.fields = fields or {}
self.schema = schema self.schema = schema
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
def as_dict(self): def as_dict(self):
"""Return dictionary representation of this service.""" """Return dictionary representation of this service."""
@@ -815,19 +810,6 @@ class Service(object):
'fields': self.fields, 'fields': self.fields,
} }
def __call__(self, call):
"""Execute the service."""
try:
if self.schema:
call.data = self.schema(call.data)
call.data = MappingProxyType(call.data)
self.func(call)
except vol.MultipleInvalid as ex:
_LOGGER.error('Invalid service data for %s.%s: %s',
call.domain, call.service,
humanize_error(call.data, ex))
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
class ServiceCall(object): class ServiceCall(object):
@@ -839,7 +821,7 @@ class ServiceCall(object):
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain.lower() self.domain = domain.lower()
self.service = service.lower() self.service = service.lower()
self.data = data or {} self.data = MappingProxyType(data or {})
self.call_id = call_id self.call_id = call_id
def __repr__(self): def __repr__(self):
@@ -983,9 +965,9 @@ class ServiceRegistry(object):
fut = asyncio.Future(loop=self._loop) fut = asyncio.Future(loop=self._loop)
@asyncio.coroutine @asyncio.coroutine
def service_executed(call): def service_executed(event):
"""Callback method that is called when service is executed.""" """Callback method that is called when service is executed."""
if call.data[ATTR_SERVICE_CALL_ID] == call_id: if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True) fut.set_result(True)
unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED, unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED,
@@ -1000,9 +982,10 @@ class ServiceRegistry(object):
unsub() unsub()
return success return success
@asyncio.coroutine
def _event_to_service_call(self, event): def _event_to_service_call(self, event):
"""Callback for SERVICE_CALLED events from the event bus.""" """Callback for SERVICE_CALLED events from the event bus."""
service_data = event.data.get(ATTR_SERVICE_DATA) service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower() domain = event.data.get(ATTR_DOMAIN).lower()
service = event.data.get(ATTR_SERVICE).lower() service = event.data.get(ATTR_SERVICE).lower()
call_id = event.data.get(ATTR_SERVICE_CALL_ID) call_id = event.data.get(ATTR_SERVICE_CALL_ID)
@@ -1014,19 +997,41 @@ class ServiceRegistry(object):
return return
service_handler = self._services[domain][service] service_handler = self._services[domain][service]
def fire_service_executed():
"""Fire service executed event."""
if not call_id:
return
data = {ATTR_SERVICE_CALL_ID: call_id}
if service_handler.iscoroutinefunction:
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
else:
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
try:
if service_handler.schema:
service_data = service_handler.schema(service_data)
except vol.Invalid as ex:
_LOGGER.error('Invalid service data for %s.%s: %s',
domain, service, humanize_error(service_data, ex))
fire_service_executed()
return
service_call = ServiceCall(domain, service, service_data, call_id) service_call = ServiceCall(domain, service, service_data, call_id)
# Add a job to the pool that calls _execute_service if not service_handler.iscoroutinefunction:
self._add_job(self._execute_service, service_handler, service_call, def execute_service():
priority=JobPriority.EVENT_SERVICE) """Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
def _execute_service(self, service, call): self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
"""Execute a service and fires a SERVICE_EXECUTED event.""" return
service(call)
if call.call_id is not None: yield from service_handler.func(service_call)
self._bus.fire( fire_service_executed()
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
def _generate_unique_id(self): def _generate_unique_id(self):
"""Generate a unique service call id.""" """Generate a unique service call id."""

View File

@@ -84,6 +84,15 @@ def or_from_config(config: ConfigType, config_validation: bool=True):
def numeric_state(hass: HomeAssistant, entity, below=None, above=None, def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None): value_template=None, variables=None):
"""Test a numeric state condition.""" """Test a numeric state condition."""
return run_callback_threadsafe(
hass.loop, async_numeric_state, hass, entity, below, above,
value_template, variables,
).result()
def async_numeric_state(hass: HomeAssistant, entity, below=None, above=None,
value_template=None, variables=None):
"""Test a numeric state condition."""
if isinstance(entity, str): if isinstance(entity, str):
entity = hass.states.get(entity) entity = hass.states.get(entity)
@@ -96,7 +105,7 @@ def numeric_state(hass: HomeAssistant, entity, below=None, above=None,
variables = dict(variables or {}) variables = dict(variables or {})
variables['state'] = entity variables['state'] = entity
try: try:
value = value_template.render(variables) value = value_template.async_render(variables)
except TemplateError as ex: except TemplateError as ex:
_LOGGER.error("Template error: %s", ex) _LOGGER.error("Template error: %s", ex)
return False return False
@@ -290,7 +299,10 @@ def time_from_config(config, config_validation=True):
def zone(hass, zone_ent, entity): def zone(hass, zone_ent, entity):
"""Test if zone-condition matches.""" """Test if zone-condition matches.
Can be run async.
"""
if isinstance(zone_ent, str): if isinstance(zone_ent, str):
zone_ent = hass.states.get(zone_ent) zone_ent = hass.states.get(zone_ent)

View File

@@ -1,4 +1,5 @@
"""An abstract class for entities.""" """An abstract class for entities."""
import asyncio
import logging import logging
from typing import Any, Optional, List, Dict from typing import Any, Optional, List, Dict
@@ -11,6 +12,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import NoEntitySpecifiedError from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.util import ensure_unique_string, slugify from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.async import run_coroutine_threadsafe
# Entity attributes that we will overwrite # Entity attributes that we will overwrite
_OVERWRITE = {} # type: Dict[str, Any] _OVERWRITE = {} # type: Dict[str, Any]
@@ -143,6 +145,23 @@ class Entity(object):
If force_refresh == True will update entity before setting state. If force_refresh == True will update entity before setting state.
""" """
# We're already in a thread, do the force refresh here.
if force_refresh and not hasattr(self, 'async_update'):
self.update()
force_refresh = False
run_coroutine_threadsafe(
self.async_update_ha_state(force_refresh), self.hass.loop
).result()
@asyncio.coroutine
def async_update_ha_state(self, force_refresh=False):
"""Update Home Assistant with current state of entity.
If force_refresh == True will update entity before setting state.
This method must be run in the event loop.
"""
if self.hass is None: if self.hass is None:
raise RuntimeError("Attribute hass is None for {}".format(self)) raise RuntimeError("Attribute hass is None for {}".format(self))
@@ -151,7 +170,13 @@ class Entity(object):
"No entity id specified for entity {}".format(self.name)) "No entity id specified for entity {}".format(self.name))
if force_refresh: if force_refresh:
self.update() if hasattr(self, 'async_update'):
# pylint: disable=no-member
self.async_update()
else:
# PS: Run this in our own thread pool once we have
# future support?
yield from self.hass.loop.run_in_executor(None, self.update)
state = STATE_UNKNOWN if self.state is None else str(self.state) state = STATE_UNKNOWN if self.state is None else str(self.state)
attr = self.state_attributes or {} attr = self.state_attributes or {}
@@ -192,7 +217,7 @@ class Entity(object):
# Could not convert state to float # Could not convert state to float
pass pass
return self.hass.states.set( self.hass.states.async_set(
self.entity_id, state, attr, self.force_update) self.entity_id, state, attr, self.force_update)
def remove(self) -> None: def remove(self) -> None:

View File

@@ -18,6 +18,28 @@ def track_state_change(hass, entity_ids, action, from_state=None,
Returns a function that can be called to remove the listener. Returns a function that can be called to remove the listener.
""" """
async_unsub = run_callback_threadsafe(
hass.loop, async_track_state_change, hass, entity_ids, action,
from_state, to_state).result()
def remove():
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove
def async_track_state_change(hass, entity_ids, action, from_state=None,
to_state=None):
"""Track specific state changes.
entity_ids, from_state and to_state can be string or list.
Use list to match multiple.
Returns a function that can be called to remove the listener.
Must be run within the event loop.
"""
from_state = _process_state_match(from_state) from_state = _process_state_match(from_state)
to_state = _process_state_match(to_state) to_state = _process_state_match(to_state)
@@ -52,7 +74,7 @@ def track_state_change(hass, entity_ids, action, from_state=None,
event.data.get('old_state'), event.data.get('old_state'),
event.data.get('new_state')) event.data.get('new_state'))
return hass.bus.listen(EVENT_STATE_CHANGED, state_change_listener) return hass.bus.async_listen(EVENT_STATE_CHANGED, state_change_listener)
def track_point_in_time(hass, action, point_in_time): def track_point_in_time(hass, action, point_in_time):
@@ -69,6 +91,19 @@ def track_point_in_time(hass, action, point_in_time):
def track_point_in_utc_time(hass, action, point_in_time): def track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time."""
async_unsub = run_callback_threadsafe(
hass.loop, async_track_point_in_utc_time, hass, action, point_in_time
).result()
def remove():
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_unsub).result()
return remove
def async_track_point_in_utc_time(hass, action, point_in_time):
"""Add a listener that fires once after a specific point in UTC time.""" """Add a listener that fires once after a specific point in UTC time."""
# Ensure point_in_time is UTC # Ensure point_in_time is UTC
point_in_time = dt_util.as_utc(point_in_time) point_in_time = dt_util.as_utc(point_in_time)
@@ -88,20 +123,14 @@ def track_point_in_utc_time(hass, action, point_in_time):
# listener gets lined up twice to be executed. This will make # listener gets lined up twice to be executed. This will make
# sure the second time it does nothing. # sure the second time it does nothing.
point_in_time_listener.run = True point_in_time_listener.run = True
async_remove() async_unsub()
hass.async_add_job(action, now) hass.async_add_job(action, now)
future = run_callback_threadsafe( async_unsub = hass.bus.async_listen(EVENT_TIME_CHANGED,
hass.loop, hass.bus.async_listen, EVENT_TIME_CHANGED, point_in_time_listener)
point_in_time_listener)
async_remove = future.result()
def remove(): return async_unsub
"""Remove listener."""
run_callback_threadsafe(hass.loop, async_remove).result()
return remove
def track_sunrise(hass, action, offset=None): def track_sunrise(hass, action, offset=None):
@@ -118,19 +147,21 @@ def track_sunrise(hass, action, offset=None):
return next_time return next_time
@asyncio.coroutine
def sunrise_automation_listener(now): def sunrise_automation_listener(now):
"""Called when it's time for action.""" """Called when it's time for action."""
nonlocal remove nonlocal remove
remove = track_point_in_utc_time(hass, sunrise_automation_listener, remove = async_track_point_in_utc_time(
next_rise()) hass, sunrise_automation_listener, next_rise())
action() hass.async_add_job(action)
remove = track_point_in_utc_time(hass, sunrise_automation_listener, remove = run_callback_threadsafe(
next_rise()) hass.loop, async_track_point_in_utc_time, hass,
sunrise_automation_listener, next_rise()).result()
def remove_listener(): def remove_listener():
"""Remove sunrise listener.""" """Remove sunset listener."""
remove() run_callback_threadsafe(hass.loop, remove).result()
return remove_listener return remove_listener
@@ -149,19 +180,21 @@ def track_sunset(hass, action, offset=None):
return next_time return next_time
@asyncio.coroutine
def sunset_automation_listener(now): def sunset_automation_listener(now):
"""Called when it's time for action.""" """Called when it's time for action."""
nonlocal remove nonlocal remove
remove = track_point_in_utc_time(hass, sunset_automation_listener, remove = async_track_point_in_utc_time(
next_set()) hass, sunset_automation_listener, next_set())
action() hass.async_add_job(action)
remove = track_point_in_utc_time(hass, sunset_automation_listener, remove = run_callback_threadsafe(
next_set()) hass.loop, async_track_point_in_utc_time, hass,
sunset_automation_listener, next_set()).result()
def remove_listener(): def remove_listener():
"""Remove sunset listener.""" """Remove sunset listener."""
remove() run_callback_threadsafe(hass.loop, remove).result()
return remove_listener return remove_listener

View File

@@ -149,8 +149,8 @@ class Template(object):
global_vars = ENV.make_globals({ global_vars = ENV.make_globals({
'closest': location_methods.closest, 'closest': location_methods.closest,
'distance': location_methods.distance, 'distance': location_methods.distance,
'is_state': self.hass.states.async_is_state, 'is_state': self.hass.states.is_state,
'is_state_attr': self.hass.states.async_is_state_attr, 'is_state_attr': self.hass.states.is_state_attr,
'states': AllStates(self.hass), 'states': AllStates(self.hass),
}) })

View File

@@ -77,7 +77,8 @@ class TestComponentsCore(unittest.TestCase):
service_call = ha.ServiceCall('homeassistant', 'turn_on', { service_call = ha.ServiceCall('homeassistant', 'turn_on', {
'entity_id': ['light.test', 'sensor.bla', 'light.bla'] 'entity_id': ['light.test', 'sensor.bla', 'light.bla']
}) })
self.hass.services._services['homeassistant']['turn_on'](service_call) service = self.hass.services._services['homeassistant']['turn_on']
service.func(service_call)
self.assertEqual(2, mock_call.call_count) self.assertEqual(2, mock_call.call_count)
self.assertEqual( self.assertEqual(

View File

@@ -1,7 +1,8 @@
"""The tests for the logbook component.""" """The tests for the logbook component."""
# pylint: disable=protected-access,too-many-public-methods # pylint: disable=protected-access,too-many-public-methods
import unittest
from datetime import timedelta from datetime import timedelta
import unittest
from unittest.mock import patch
from homeassistant.components import sun from homeassistant.components import sun
import homeassistant.core as ha import homeassistant.core as ha
@@ -18,13 +19,17 @@ from tests.common import mock_http_component, get_test_home_assistant
class TestComponentLogbook(unittest.TestCase): class TestComponentLogbook(unittest.TestCase):
"""Test the History component.""" """Test the History component."""
EMPTY_CONFIG = logbook.CONFIG_SCHEMA({ha.DOMAIN: {}, logbook.DOMAIN: {}}) EMPTY_CONFIG = logbook.CONFIG_SCHEMA({logbook.DOMAIN: {}})
def setUp(self): def setUp(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
mock_http_component(self.hass) mock_http_component(self.hass)
assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG) self.hass.config.components += ['frontend', 'recorder', 'api']
with patch('homeassistant.components.logbook.'
'register_built_in_panel'):
assert setup_component(self.hass, logbook.DOMAIN,
self.EMPTY_CONFIG)
def tearDown(self): def tearDown(self):
"""Stop everything that was started.""" """Stop everything that was started."""
@@ -44,7 +49,6 @@ class TestComponentLogbook(unittest.TestCase):
logbook.ATTR_DOMAIN: 'switch', logbook.ATTR_DOMAIN: 'switch',
logbook.ATTR_ENTITY_ID: 'switch.test_switch' logbook.ATTR_ENTITY_ID: 'switch.test_switch'
}, True) }, True)
self.hass.block_till_done()
self.assertEqual(1, len(calls)) self.assertEqual(1, len(calls))
last_call = calls[-1] last_call = calls[-1]
@@ -65,7 +69,6 @@ class TestComponentLogbook(unittest.TestCase):
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
self.hass.services.call(logbook.DOMAIN, 'log', {}, True) self.hass.services.call(logbook.DOMAIN, 'log', {}, True)
self.hass.block_till_done()
self.assertEqual(0, len(calls)) self.assertEqual(0, len(calls))

View File

@@ -1,6 +1,9 @@
"""Test the entity helper.""" """Test the entity helper."""
# pylint: disable=protected-access,too-many-public-methods # pylint: disable=protected-access,too-many-public-methods
import unittest import asyncio
from unittest.mock import MagicMock
import pytest
import homeassistant.helpers.entity as entity import homeassistant.helpers.entity as entity
from homeassistant.const import ATTR_HIDDEN from homeassistant.const import ATTR_HIDDEN
@@ -8,26 +11,75 @@ from homeassistant.const import ATTR_HIDDEN
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant
class TestHelpersEntity(unittest.TestCase): def test_generate_entity_id_requires_hass_or_ids():
"""Ensure we require at least hass or current ids."""
fmt = 'test.{}'
with pytest.raises(ValueError):
entity.generate_entity_id(fmt, 'hello world')
def test_generate_entity_id_given_keys():
"""Test generating an entity id given current ids."""
fmt = 'test.{}'
assert entity.generate_entity_id(
fmt, 'overwrite hidden true', current_ids=[
'test.overwrite_hidden_true']) == 'test.overwrite_hidden_true_2'
assert entity.generate_entity_id(
fmt, 'overwrite hidden true', current_ids=[
'test.another_entity']) == 'test.overwrite_hidden_true'
def test_async_update_support(event_loop):
"""Test async update getting called."""
sync_update = []
async_update = []
class AsyncEntity(entity.Entity):
hass = MagicMock()
entity_id = 'sensor.test'
def update(self):
sync_update.append([1])
ent = AsyncEntity()
ent.hass.loop = event_loop
@asyncio.coroutine
def test():
yield from ent.async_update_ha_state(True)
event_loop.run_until_complete(test())
assert len(sync_update) == 1
assert len(async_update) == 0
ent.async_update = lambda: async_update.append(1)
event_loop.run_until_complete(test())
assert len(sync_update) == 1
assert len(async_update) == 1
class TestHelpersEntity(object):
"""Test homeassistant.helpers.entity module.""" """Test homeassistant.helpers.entity module."""
def setUp(self): # pylint: disable=invalid-name def setup_method(self, method):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.entity = entity.Entity() self.entity = entity.Entity()
self.entity.entity_id = 'test.overwrite_hidden_true' self.entity.entity_id = 'test.overwrite_hidden_true'
self.hass = self.entity.hass = get_test_home_assistant() self.hass = self.entity.hass = get_test_home_assistant()
self.entity.update_ha_state() self.entity.update_ha_state()
def tearDown(self): # pylint: disable=invalid-name def teardown_method(self, method):
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop()
entity.set_customize({}) entity.set_customize({})
self.hass.stop()
def test_default_hidden_not_in_attributes(self): def test_default_hidden_not_in_attributes(self):
"""Test that the default hidden property is set to False.""" """Test that the default hidden property is set to False."""
self.assertNotIn( assert ATTR_HIDDEN not in self.hass.states.get(
ATTR_HIDDEN, self.entity.entity_id).attributes
self.hass.states.get(self.entity.entity_id).attributes)
def test_overwriting_hidden_property_to_true(self): def test_overwriting_hidden_property_to_true(self):
"""Test we can overwrite hidden property to True.""" """Test we can overwrite hidden property to True."""
@@ -35,31 +87,11 @@ class TestHelpersEntity(unittest.TestCase):
self.entity.update_ha_state() self.entity.update_ha_state()
state = self.hass.states.get(self.entity.entity_id) state = self.hass.states.get(self.entity.entity_id)
self.assertTrue(state.attributes.get(ATTR_HIDDEN)) assert state.attributes.get(ATTR_HIDDEN)
def test_generate_entity_id_requires_hass_or_ids(self):
"""Ensure we require at least hass or current ids."""
fmt = 'test.{}'
with self.assertRaises(ValueError):
entity.generate_entity_id(fmt, 'hello world')
def test_generate_entity_id_given_hass(self): def test_generate_entity_id_given_hass(self):
"""Test generating an entity id given hass object.""" """Test generating an entity id given hass object."""
fmt = 'test.{}' fmt = 'test.{}'
self.assertEqual( assert entity.generate_entity_id(
'test.overwrite_hidden_true_2', fmt, 'overwrite hidden true',
entity.generate_entity_id(fmt, 'overwrite hidden true', hass=self.hass) == 'test.overwrite_hidden_true_2'
hass=self.hass))
def test_generate_entity_id_given_keys(self):
"""Test generating an entity id given current ids."""
fmt = 'test.{}'
self.assertEqual(
'test.overwrite_hidden_true_2',
entity.generate_entity_id(
fmt, 'overwrite hidden true',
current_ids=['test.overwrite_hidden_true']))
self.assertEqual(
'test.overwrite_hidden_true',
entity.generate_entity_id(fmt, 'overwrite hidden true',
current_ids=['test.another_entity']))

View File

@@ -1,6 +1,7 @@
"""Test to verify that Home Assistant core works.""" """Test to verify that Home Assistant core works."""
# pylint: disable=protected-access,too-many-public-methods # pylint: disable=protected-access,too-many-public-methods
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
import asyncio
import os import os
import signal import signal
import unittest import unittest
@@ -362,7 +363,6 @@ class TestServiceRegistry(unittest.TestCase):
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
self.services = self.hass.services self.services = self.hass.services
self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None) self.services.register("Test_Domain", "TEST_SERVICE", lambda x: None)
self.hass.block_till_done()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop down stuff we started.""" """Stop down stuff we started."""
@@ -387,8 +387,13 @@ class TestServiceRegistry(unittest.TestCase):
def test_call_with_blocking_done_in_time(self): def test_call_with_blocking_done_in_time(self):
"""Test call with blocking.""" """Test call with blocking."""
calls = [] calls = []
def service_handler(call):
"""Service handler."""
calls.append(call)
self.services.register("test_domain", "register_calls", self.services.register("test_domain", "register_calls",
lambda x: calls.append(1)) service_handler)
self.assertTrue( self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True)) self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
@@ -404,6 +409,22 @@ class TestServiceRegistry(unittest.TestCase):
finally: finally:
ha.SERVICE_CALL_LIMIT = prior ha.SERVICE_CALL_LIMIT = prior
def test_async_service(self):
"""Test registering and calling an async service."""
calls = []
@asyncio.coroutine
def service_handler(call):
"""Service handler coroutine."""
calls.append(call)
self.services.register('test_domain', 'register_calls',
service_handler)
self.assertTrue(
self.services.call('test_domain', 'REGISTER_CALLS', blocking=True))
self.hass.block_till_done()
self.assertEqual(1, len(calls))
class TestConfig(unittest.TestCase): class TestConfig(unittest.TestCase):
"""Test configuration methods.""" """Test configuration methods."""