Files
homeassistant-core/tests/common.py
T

1209 lines
37 KiB
Python
Raw Normal View History

2016-03-09 10:25:50 +01:00
"""Helper functions and mock objects shared by the Home Assistant tests."""
2021-03-18 15:13:22 +01:00
from __future__ import annotations
import asyncio
import collections
2019-01-29 00:52:42 +01:00
from collections import OrderedDict
from contextlib import contextmanager
from datetime import datetime, timedelta
import functools as ft
2019-01-29 00:52:42 +01:00
from io import StringIO
import json
import logging
import os
2020-11-25 15:10:04 +01:00
import pathlib
import threading
2020-06-29 11:39:24 -05:00
import time
from time import monotonic
import types
2021-03-18 15:13:22 +01:00
from typing import Any, Awaitable, Collection
2021-01-01 22:31:56 +01:00
from unittest.mock import AsyncMock, Mock, patch
2019-01-29 00:52:42 +01:00
2021-03-02 10:02:04 +02:00
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
2020-04-02 09:55:34 -07:00
from homeassistant import auth, config_entries, core as ha, loader
2018-07-13 15:31:20 +02:00
from homeassistant.auth import (
2019-07-31 12:25:30 -07:00
auth_store,
models as auth_models,
2019-07-31 12:25:30 -07:00
permissions as auth_permissions,
providers as auth_providers,
2019-07-31 12:25:30 -07:00
)
2018-11-08 12:57:00 +01:00
from homeassistant.auth.permissions import system_policies
from homeassistant.components import device_automation, recorder
from homeassistant.components.device_automation import ( # noqa: F401
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
)
2021-07-06 14:38:48 +02:00
from homeassistant.components.mqtt.models import ReceiveMessage
2019-01-29 00:52:42 +01:00
from homeassistant.config import async_process_component_config
from homeassistant.const import (
2019-07-31 12:25:30 -07:00
DEVICE_DEFAULT_NAME,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED,
STATE_OFF,
STATE_ON,
2019-07-31 12:25:30 -07:00
)
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, State
2019-01-29 00:52:42 +01:00
from homeassistant.helpers import (
2019-07-31 12:25:30 -07:00
area_registry,
device_registry,
entity,
entity_platform,
entity_registry,
intent,
restore_state,
storage,
)
from homeassistant.helpers.json import JSONEncoder
from homeassistant.setup import async_setup_component, setup_component
2019-10-01 16:59:06 +02:00
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as date_util
from homeassistant.util.unit_system import METRIC_SYSTEM
2021-03-21 18:44:29 -10:00
import homeassistant.util.uuid as uuid_util
import homeassistant.util.yaml.loader as yaml_loader
2016-08-23 06:42:05 +02:00
_LOGGER = logging.getLogger(__name__)

# Instances created by async_test_home_assistant; each one removes itself
# again when EVENT_HOMEASSISTANT_CLOSE fires.
INSTANCES = []

# Example client id / redirect URI; presumably used by auth tests — confirm.
CLIENT_ID = "https://example.com/app"
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
2016-02-14 12:54:16 -08:00
2014-11-25 00:20:36 -08:00
async def async_get_device_automations(
    hass: HomeAssistant, automation_type: str, device_id: str
) -> Any:
    """Return the ``automation_type`` automations for a single device.

    ``device_automation.async_get_device_automations`` works on a list of
    device ids and returns a mapping; this asks for one id and unpacks it.
    Returns ``None`` when the id is absent from the returned mapping.
    """
    by_device = await device_automation.async_get_device_automations(
        hass, automation_type, [device_id]
    )
    return by_device.get(device_id)
2017-05-01 23:29:01 -07:00
def threadsafe_callback_factory(func):
    """Turn an event-loop callback into a function callable from any thread.

    ``func`` must take ``hass`` as its first positional argument. The returned
    wrapper schedules ``func(*args, **kwargs)`` on ``hass.loop`` through
    ``run_callback_threadsafe`` and returns the value of ``.result()`` on what
    that call hands back.
    """

    @ft.wraps(func)
    def threadsafe(*args, **kwargs):
        """Run ``func`` on the event loop and return its result."""
        target_loop = args[0].loop
        bound_call = ft.partial(func, *args, **kwargs)
        return run_callback_threadsafe(target_loop, bound_call).result()

    return threadsafe
2017-06-03 18:51:29 -07:00
def threadsafe_coroutine_factory(func):
    """Turn a coroutine function into a blocking function usable from any thread.

    ``func`` must take ``hass`` as its first positional argument. The wrapper
    submits the coroutine to ``hass.loop`` with
    ``asyncio.run_coroutine_threadsafe`` and blocks until it has a result.
    """

    @ft.wraps(func)
    def threadsafe(*args, **kwargs):
        """Run the coroutine on the event loop and return its result."""
        hass = args[0]
        coro = func(*args, **kwargs)
        future = asyncio.run_coroutine_threadsafe(coro, hass.loop)
        return future.result()

    return threadsafe
2016-08-23 06:42:05 +02:00
def get_test_config_dir(*add_path):
    """Return the path of ``testing_config`` next to this module.

    Extra ``add_path`` segments are joined on, e.g.
    ``get_test_config_dir("media")``.
    """
    base_dir = os.path.dirname(__file__)
    return os.path.join(base_dir, "testing_config", *add_path)
2016-10-31 08:47:29 -07:00
def get_test_home_assistant():
    """Return a Home Assistant object pointing at test config directory.

    A new event loop is created and made current, the instance is built with
    ``async_test_home_assistant`` on it, and the loop is then run forever on a
    non-daemon ``LoopThread``. ``hass.start`` and ``hass.stop`` are replaced
    with blocking wrappers so synchronous tests can drive the instance.
    ``hass.stop()`` calls the original stop, waits for the loop thread's
    ``run_forever`` to return, and closes the loop.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    hass = loop.run_until_complete(async_test_home_assistant(loop))

    # Set by the loop thread once run_forever() has returned.
    loop_stop_event = threading.Event()

    def run_loop():
        """Run event loop."""
        # pylint: disable=protected-access
        # Record the loop thread's ident on the loop object.
        # NOTE(review): asyncio's BaseEventLoop tracks its owner thread as
        # ``_thread_id``; ``_thread_ident`` may be unused — confirm.
        loop._thread_ident = threading.get_ident()
        loop.run_forever()
        loop_stop_event.set()

    orig_stop = hass.stop
    # hass._stopped becomes a Mock whose set() stops the loop. Presumably the
    # original stop path calls hass._stopped.set() when shutdown completes,
    # which then ends run_forever() in the loop thread — TODO confirm.
    hass._stopped = Mock(set=loop.stop)

    def start_hass(*mocks):
        """Start hass."""
        # ``*mocks`` is accepted but unused.
        asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result()

    def stop_hass():
        """Stop hass."""
        orig_stop()
        loop_stop_event.wait()
        loop.close()

    hass.start = start_hass
    hass.stop = stop_hass

    threading.Thread(name="LoopThread", target=run_loop, daemon=False).start()

    return hass
2016-11-18 23:05:03 +01:00
# pylint: disable=protected-access
def _unwrap_partial(target):
    """Return the callable underneath any nesting of ``functools.partial``."""
    while isinstance(target, ft.partial):
        target = target.func
    return target


async def _mock_async_wait_for_task_count(
    self, max_remaining_tasks: int = 0
) -> None:
    """Block until at most max_remaining_tasks remain.

    Bound onto test instances as ``hass.async_wait_for_task_count``.
    Based on HomeAssistant.async_block_till_done
    """
    # To flush out any call_soon_threadsafe
    await asyncio.sleep(0)
    start_time: float | None = None

    while len(self._pending_tasks) > max_remaining_tasks:
        pending: Collection[Awaitable[Any]] = [
            task for task in self._pending_tasks if not task.done()
        ]
        self._pending_tasks.clear()
        if len(pending) > max_remaining_tasks:
            remaining_pending = await self._await_count_and_log_pending(
                pending, max_remaining_tasks=max_remaining_tasks
            )
            self._pending_tasks.extend(remaining_pending)

            if start_time is None:
                # Avoid calling monotonic() until we know
                # we may need to start logging blocked tasks.
                start_time = 0
            elif start_time == 0:
                # If we have waited twice then we set the start
                # time
                start_time = monotonic()
            elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
                # We have waited at least three loops and new tasks
                # continue to block. At this point we start
                # logging all waiting tasks.
                for task in pending:
                    _LOGGER.debug("Waiting for task: %s", task)
        else:
            self._pending_tasks.extend(pending)
            await asyncio.sleep(0)


async def _mock_await_count_and_log_pending(
    self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
) -> Collection[Awaitable[Any]]:
    """Block until at most max_remaining_tasks remain; log long-running tasks.

    Bound onto test instances as ``hass._await_count_and_log_pending``.
    Returns the tasks that are still outstanding.
    Based on HomeAssistant._await_and_log_pending
    """
    wait_time = 0

    # With a non-zero allowance one completion per round is enough;
    # otherwise wait for everything.
    return_when = asyncio.ALL_COMPLETED
    if max_remaining_tasks:
        return_when = asyncio.FIRST_COMPLETED

    while len(pending) > max_remaining_tasks:
        _, pending = await asyncio.wait(
            pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
        )
        if not pending or max_remaining_tasks:
            return pending
        wait_time += BLOCK_LOG_TIMEOUT
        for task in pending:
            _LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)

    # Only reached when called with no more than max_remaining_tasks tasks.
    # Those are still outstanding, so hand them back instead of dropping
    # them from tracking.
    return pending


async def async_test_home_assistant(loop, load_registries=True):
    """Return a Home Assistant object pointing at test config dir.

    The instance gets:

    * an ``AuthManager`` whose store is marked loaded;
    * job helpers where plain ``Mock`` targets run immediately and resolve an
      already-completed future;
    * ``async_wait_for_task_count`` for waiting on pending tasks;
    * a fixed test configuration and an empty config-entry manager.

    When ``load_registries`` is true, the device, entity and area registries
    are loaded before returning. The instance is tracked in ``INSTANCES``
    until ``EVENT_HOMEASSISTANT_CLOSE`` fires. ``loop`` is accepted but not
    referenced here.
    """
    hass = ha.HomeAssistant()
    store = auth_store.AuthStore(hass)
    hass.auth = auth.AuthManager(hass, store, {}, {})
    ensure_auth_manager_loaded(hass.auth)
    INSTANCES.append(hass)

    orig_async_add_job = hass.async_add_job
    orig_async_add_executor_job = hass.async_add_executor_job
    orig_async_create_task = hass.async_create_task

    def async_add_job(target, *args):
        """Add job; synchronous Mock targets are called immediately."""
        check_target = _unwrap_partial(target)
        # Check the *unwrapped* callable for AsyncMock as well. AsyncMock is
        # a Mock subclass, so checking only ``target`` sent partial(AsyncMock)
        # down this synchronous path, where the future resolved to an
        # un-awaited coroutine.
        if isinstance(check_target, Mock) and not isinstance(
            check_target, AsyncMock
        ):
            fut = asyncio.Future()
            fut.set_result(target(*args))
            return fut

        return orig_async_add_job(target, *args)

    def async_add_executor_job(target, *args):
        """Add executor job; Mock targets are called immediately instead."""
        if isinstance(_unwrap_partial(target), Mock):
            fut = asyncio.Future()
            fut.set_result(target(*args))
            return fut

        return orig_async_add_executor_job(target, *args)

    def async_create_task(coroutine):
        """Create task; a synchronous Mock resolves to ``None`` immediately."""
        if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
            fut = asyncio.Future()
            fut.set_result(None)
            return fut

        return orig_async_create_task(coroutine)

    hass.async_add_job = async_add_job
    hass.async_add_executor_job = async_add_executor_job
    hass.async_create_task = async_create_task
    hass.async_wait_for_task_count = types.MethodType(
        _mock_async_wait_for_task_count, hass
    )
    hass._await_count_and_log_pending = types.MethodType(
        _mock_await_count_and_log_pending, hass
    )

    hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

    hass.config.location_name = "test home"
    hass.config.config_dir = get_test_config_dir()
    hass.config.latitude = 32.87336
    hass.config.longitude = -117.22743
    hass.config.elevation = 0
    hass.config.time_zone = "US/Pacific"
    hass.config.units = METRIC_SYSTEM
    hass.config.media_dirs = {"local": get_test_config_dir("media")}
    hass.config.skip_pip = True

    hass.config_entries = config_entries.ConfigEntries(hass, {})
    hass.config_entries._entries = {}
    hass.config_entries._store._async_ensure_stop_listener = lambda: None

    # Load the registries
    if load_registries:
        await asyncio.gather(
            device_registry.async_load(hass),
            entity_registry.async_load(hass),
            area_registry.async_load(hass),
        )
        await hass.async_block_till_done()

    hass.state = ha.CoreState.running

    # Mock async_start
    orig_start = hass.async_start

    async def mock_async_start():
        """Start the mocking."""
        # We only mock time during tests and we want to track tasks
        with patch("homeassistant.core._async_create_timer"), patch.object(
            hass, "async_stop_track_tasks"
        ):
            await orig_start()

    hass.async_start = mock_async_start

    @ha.callback
    def clear_instance(event):
        """Clear global instance."""
        INSTANCES.remove(hass)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance)

    return hass
def async_mock_service(hass, domain, service, schema=None):
    """Register a stub ``domain.service`` and return the list of its calls.

    Every call of the stub service is appended to the returned list, so tests
    can assert on how often, and with what, it was invoked.
    """
    recorded_calls = []

    @ha.callback
    def mock_service_log(call):
        """Record the service call."""
        recorded_calls.append(call)

    hass.services.async_register(domain, service, mock_service_log, schema=schema)

    return recorded_calls
2017-06-25 10:53:15 -07:00
# Thread-safe, blocking counterpart of async_mock_service, built with
# threadsafe_callback_factory for use from outside the event loop.
mock_service = threadsafe_callback_factory(async_mock_service)
2017-07-21 21:38:53 -07:00
@ha.callback
def async_mock_intent(hass, intent_typ):
    """Register a stub handler for intents of type ``intent_typ``.

    Returns the list that every handled intent object is appended to.
    """
    intents = []

    class MockIntentHandler(intent.IntentHandler):
        """Intent handler that records every intent it receives."""

        intent_type = intent_typ

        async def async_handle(self, intent_obj):
            """Handle the intent."""
            # Named ``intent_obj`` so it does not shadow the ``intent``
            # helper module used elsewhere in this function.
            intents.append(intent_obj)
            return intent_obj.create_response()

    intent.async_register(hass, MockIntentHandler())

    return intents
2017-02-07 18:13:24 +01:00
@ha.callback
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
    """Pass an MQTT message to ``hass.data["mqtt"]._mqtt_handle_message``.

    ``str`` payloads are UTF-8 encoded first, so tests may pass text or bytes.
    """
    encoded = payload.encode("utf-8") if isinstance(payload, str) else payload
    message = ReceiveMessage(topic, encoded, qos, retain)
    hass.data["mqtt"]._mqtt_handle_message(message)
2015-08-10 23:11:46 -07:00
2017-05-01 23:29:01 -07:00
# Thread-safe, blocking counterpart of async_fire_mqtt_message.
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
2017-02-07 18:13:24 +01:00
2017-05-01 23:29:01 -07:00
@ha.callback
def async_fire_time_changed(
    hass: HomeAssistant, datetime_: datetime, fire_all: bool = False
) -> None:
    """Fire a time changed event and run timers due by ``datetime_``.

    After firing ``EVENT_TIME_CHANGED``, a pending, non-cancelled
    ``asyncio.TimerHandle`` on ``hass.loop`` is run immediately and then
    cancelled if either:

    * its delay is no later than how far ``datetime_`` lies in the future, or
    * ``fire_all`` is set.

    While it runs, ``homeassistant.helpers.event.time_tracker_utcnow`` is
    patched to return ``datetime_`` in UTC.
    """
    utc_datetime = date_util.as_utc(datetime_)
    hass.bus.async_fire(EVENT_TIME_CHANGED, {"now": utc_datetime})

    # Take both clock readings once, so every handle is compared against the
    # same reference instant rather than clocks that advance as timers run.
    mock_seconds_into_future = datetime_.timestamp() - time.time()
    loop_now = hass.loop.time()

    for task in list(hass.loop._scheduled):
        if not isinstance(task, asyncio.TimerHandle):
            continue
        if task.cancelled():
            continue

        future_seconds = task.when() - loop_now

        if fire_all or mock_seconds_into_future >= future_seconds:
            with patch(
                "homeassistant.helpers.event.time_tracker_utcnow",
                return_value=utc_datetime,
            ):
                task._run()
                task.cancel()
2017-05-01 23:29:01 -07:00
# Thread-safe, blocking counterpart of async_fire_time_changed.
fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)
2015-08-03 17:57:12 +02:00
def load_fixture(filename):
    """Return the UTF-8 text of ``filename`` from the ``fixtures`` dir next to this module."""
    fixture_path = pathlib.Path(__file__).parent / "fixtures" / filename
    return fixture_path.read_text(encoding="utf-8")
2015-04-30 21:03:01 -07:00
def mock_state_change_event(hass, new_state, old_state=None):
    """Fire a synthetic state changed event for ``new_state``.

    ``old_state`` is added to the event data only when it is truthy. The
    event is fired with ``new_state.context``.
    """
    payload = {
        "entity_id": new_state.entity_id,
        "new_state": new_state,
        **({"old_state": old_state} if old_state else {}),
    }
    hass.bus.fire(EVENT_STATE_CHANGED, payload, context=new_state.context)
2015-04-30 21:03:01 -07:00
2017-05-01 23:29:01 -07:00
@ha.callback
def mock_component(hass, component):
    """Mark ``component`` as set up by adding it to ``hass.config.components``.

    A warning is logged if the component was already present.
    """
    if component in hass.config.components:
        # Previously an AssertionError was built but never raised (a no-op).
        # Raising now could break callers that mock the same component more
        # than once, so the duplicate is surfaced as a warning instead.
        _LOGGER.warning("Integration %s is already setup", component)

    hass.config.components.add(component)
def mock_registry(hass, mock_entries=None):
    """Install an EntityRegistry holding ``mock_entries`` on ``hass``.

    Falsy ``mock_entries`` become an empty ``OrderedDict``. The registry's
    index is rebuilt before it is stored in ``hass.data``, and the registry
    is returned.
    """
    registry = entity_registry.EntityRegistry(hass)
    registry.entities = mock_entries if mock_entries else OrderedDict()
    registry._rebuild_index()

    hass.data[entity_registry.DATA_REGISTRY] = registry
    return registry
2019-01-29 00:52:42 +01:00
def mock_area_registry(hass, mock_entries=None):
    """Install an AreaRegistry holding ``mock_entries`` on ``hass`` and return it.

    Falsy ``mock_entries`` become an empty ``OrderedDict``.
    """
    registry = area_registry.AreaRegistry(hass)
    registry.areas = mock_entries if mock_entries else OrderedDict()

    hass.data[area_registry.DATA_REGISTRY] = registry
    return registry
def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None):
    """Install a DeviceRegistry on ``hass`` and return it.

    ``mock_entries`` become the active devices and ``mock_deleted_entries``
    the deleted ones. Each falls back to an empty ``OrderedDict`` when falsy.
    """
    registry = device_registry.DeviceRegistry(hass)
    registry.devices = mock_entries if mock_entries else OrderedDict()
    registry.deleted_devices = (
        mock_deleted_entries if mock_deleted_entries else OrderedDict()
    )
    registry._rebuild_index()

    hass.data[device_registry.DATA_REGISTRY] = registry
    return registry
2018-10-08 16:35:38 +02:00
class MockGroup(auth_models.Group):
    """Mock a group in Home Assistant."""

    def __init__(self, id=None, name="Mock Group", policy=system_policies.ADMIN_POLICY):
        """Create the group; ``id`` is only forwarded when one is given."""
        id_kwarg = {} if id is None else {"id": id}
        super().__init__(name=name, policy=policy, **id_kwarg)

    def add_to_hass(self, hass):
        """Register this group with ``hass.auth`` and return it."""
        return self.add_to_auth_manager(hass.auth)

    def add_to_auth_manager(self, auth_mgr):
        """Register this group in ``auth_mgr``'s store and return it."""
        ensure_auth_manager_loaded(auth_mgr)
        groups = auth_mgr._store._groups
        groups[self.id] = self
        return self
2018-07-13 11:43:08 +02:00
class MockUser(auth_models.User):
    """Mock a user in Home Assistant."""

    def __init__(
        self,
        id=None,
        is_owner=False,
        is_active=True,
        name="Mock User",
        system_generated=False,
        groups=None,
    ):
        """Initialize mock user; ``id`` is only forwarded when one is given."""
        id_kwarg = {} if id is None else {"id": id}
        super().__init__(
            is_owner=is_owner,
            is_active=is_active,
            name=name,
            system_generated=system_generated,
            groups=groups or [],
            perm_lookup=None,
            **id_kwarg,
        )

    def add_to_hass(self, hass):
        """Register this user with ``hass.auth`` and return it."""
        return self.add_to_auth_manager(hass.auth)

    def add_to_auth_manager(self, auth_mgr):
        """Register this user in ``auth_mgr``'s store and return it."""
        ensure_auth_manager_loaded(auth_mgr)
        users = auth_mgr._store._users
        users[self.id] = self
        return self

    def mock_policy(self, policy):
        """Replace this user's permissions with ones built from ``policy``."""
        permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
        self._permissions = permissions
2018-11-25 18:04:48 +01:00
2018-05-01 12:20:41 -04:00
2018-07-13 15:31:20 +02:00
async def register_auth_provider(hass, config):
    """Build an auth provider from ``config`` and register it on ``hass.auth``.

    Raises ``ValueError`` if a provider with the same ``(type, id)`` is
    already registered. Returns the new provider.
    """
    provider = await auth_providers.auth_provider_from_config(
        hass, hass.auth._store, config
    )
    assert provider is not None, "Invalid config specified"

    registered = hass.auth._providers
    provider_key = (provider.type, provider.id)
    if provider_key in registered:
        raise ValueError("Provider already registered")

    registered[provider_key] = provider
    return provider
2018-05-01 12:20:41 -04:00
@ha.callback
def ensure_auth_manager_loaded(auth_mgr):
    """Make ``auth_mgr``'s store look loaded.

    Defaults are applied only while the store's ``_users`` is still ``None``.
    """
    store = auth_mgr._store
    if store._users is not None:
        return
    store._set_defaults()
2018-05-01 12:20:41 -04:00
2018-07-20 11:45:20 +03:00
class MockModule:
    """Representation of a fake module."""

    # pylint: disable=invalid-name
    def __init__(
        self,
        domain=None,
        dependencies=None,
        setup=None,
        requirements=None,
        config_schema=None,
        platform_schema=None,
        platform_schema_base=None,
        async_setup=None,
        async_setup_entry=None,
        async_unload_entry=None,
        async_migrate_entry=None,
        async_remove_entry=None,
        partial_manifest=None,
    ):
        """Initialize the mock module.

        Optional schemas and entry hooks only become attributes when given.
        Without ``setup`` or ``async_setup``, ``async_setup`` is an
        ``AsyncMock`` returning ``True``.
        """
        self.__name__ = f"homeassistant.components.{domain}"
        self.__file__ = f"homeassistant/components/{domain}"
        self.DOMAIN = domain
        self.DEPENDENCIES = dependencies or []
        self.REQUIREMENTS = requirements or []
        # Overlay to be used when generating manifest from this module
        self._partial_manifest = partial_manifest

        for attr_name, attr_value in (
            ("CONFIG_SCHEMA", config_schema),
            ("PLATFORM_SCHEMA", platform_schema),
            ("PLATFORM_SCHEMA_BASE", platform_schema_base),
        ):
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

        if setup:
            # We run this in executor, wrap it in function
            self.setup = lambda *args: setup(*args)

        if async_setup is not None:
            self.async_setup = async_setup

        if setup is None and async_setup is None:
            self.async_setup = AsyncMock(return_value=True)

        for attr_name, attr_value in (
            ("async_setup_entry", async_setup_entry),
            ("async_unload_entry", async_unload_entry),
            ("async_migrate_entry", async_migrate_entry),
            ("async_remove_entry", async_remove_entry),
        ):
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

    def mock_manifest(self):
        """Return a manifest for this module, overlaid with the partial manifest."""
        legacy = loader.manifest_from_legacy_module(self.DOMAIN, self)
        overlay = self._partial_manifest or {}
        return {**legacy, **overlay}
2016-01-30 18:55:52 -08:00
2018-07-20 11:45:20 +03:00
class MockPlatform:
    """Provide a fake platform."""

    __name__ = "homeassistant.components.light.bla"
    __file__ = "homeassistant/components/blah/light"

    # pylint: disable=invalid-name
    def __init__(
        self,
        setup_platform=None,
        dependencies=None,
        platform_schema=None,
        async_setup_platform=None,
        async_setup_entry=None,
        scan_interval=None,
    ):
        """Initialize the platform.

        Optional attributes are only set when given. Without
        ``setup_platform`` or ``async_setup_platform``,
        ``async_setup_platform`` is an ``AsyncMock`` returning ``None``.
        """
        self.DEPENDENCIES = dependencies or []

        for attr_name, attr_value in (
            ("PLATFORM_SCHEMA", platform_schema),
            ("SCAN_INTERVAL", scan_interval),
        ):
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

        if setup_platform is not None:
            # We run this in executor, wrap it in function
            self.setup_platform = lambda *args: setup_platform(*args)

        for attr_name, attr_value in (
            ("async_setup_platform", async_setup_platform),
            ("async_setup_entry", async_setup_entry),
        ):
            if attr_value is not None:
                setattr(self, attr_name, attr_value)

        if setup_platform is None and async_setup_platform is None:
            self.async_setup_platform = AsyncMock(return_value=None)
class MockEntityPlatform(entity_platform.EntityPlatform):
    """Mock class with some mock defaults."""

    def __init__(
        self,
        hass,
        logger=None,
        domain="test_domain",
        platform_name="test_platform",
        platform=None,
        scan_interval=timedelta(seconds=15),
        entity_namespace=None,
    ):
        """Initialize an EntityPlatform with test-friendly defaults.

        Without an explicit ``logger``, the
        ``homeassistant.helpers.entity_platform`` logger is used.
        """
        platform_logger = (
            logging.getLogger("homeassistant.helpers.entity_platform")
            if logger is None
            else logger
        )

        # A Mock platform reports a Mock PARALLEL_UPDATES, which would make
        # the constructor blow up, so substitute 0.
        is_mock_platform = isinstance(platform, Mock)
        if is_mock_platform and isinstance(platform.PARALLEL_UPDATES, Mock):
            platform.PARALLEL_UPDATES = 0

        super().__init__(
            hass=hass,
            logger=platform_logger,
            domain=domain,
            platform_name=platform_name,
            platform=platform,
            scan_interval=scan_interval,
            entity_namespace=entity_namespace,
        )
class MockToggleEntity(entity.ToggleEntity):
    """Provide a mock toggle device that records every property read and call."""

    def __init__(self, name, state, unique_id=None):
        """Initialize the mock entity.

        NOTE(review): ``unique_id`` is accepted but neither stored nor exposed.
        """
        self._name = name or DEVICE_DEFAULT_NAME
        self._state = state
        self.calls = []

    def _track_call(self, method, kwargs=None):
        """Append ``(method, kwargs)`` to ``self.calls``."""
        self.calls.append((method, {} if kwargs is None else kwargs))

    @property
    def name(self):
        """Return the name of the entity if any."""
        self._track_call("name")
        return self._name

    @property
    def state(self):
        """Return the state of the entity if any."""
        self._track_call("state")
        return self._state

    @property
    def is_on(self):
        """Return true if entity is on."""
        self._track_call("is_on")
        return self._state == STATE_ON

    def turn_on(self, **kwargs):
        """Turn the entity on."""
        self._track_call("turn_on", kwargs)
        self._state = STATE_ON

    def turn_off(self, **kwargs):
        """Turn the entity off."""
        self._track_call("turn_off", kwargs)
        self._state = STATE_OFF

    def last_call(self, method=None):
        """Return the most recent recorded call, optionally only for ``method``.

        Returns ``None`` when no matching call has been recorded.
        """
        if not self.calls:
            return None
        if method is None:
            return self.calls[-1]
        return next(
            (call for call in reversed(self.calls) if call[0] == method), None
        )
2016-08-23 06:42:05 +02:00
2018-02-16 14:07:38 -08:00
class MockConfigEntry(config_entries.ConfigEntry):
    """Helper for creating config entries that adds some defaults."""

    def __init__(
        self,
        *,
        domain="test",
        data=None,
        version=1,
        entry_id=None,
        source=config_entries.SOURCE_USER,
        title="Mock Title",
        state=None,
        options=None,
        pref_disable_new_entities=None,
        pref_disable_polling=None,
        unique_id=None,
        disabled_by=None,
        reason=None,
    ):
        """Initialize a mock config entry.

        Defaults:

        * ``entry_id``: a random hex UUID;
        * ``data``: an empty dict when falsy;
        * ``options``: a fresh empty dict when omitted.

        ``source`` and ``state`` are only passed to ``ConfigEntry`` when not
        ``None``. ``reason`` is set on the entry afterwards when given.
        """
        kwargs = {
            "entry_id": entry_id or uuid_util.random_uuid_hex(),
            "domain": domain,
            "data": data or {},
            "pref_disable_new_entities": pref_disable_new_entities,
            "pref_disable_polling": pref_disable_polling,
            # Create the dict per call: the former ``options={}`` default was a
            # single dict shared by every MockConfigEntry built without options.
            "options": {} if options is None else options,
            "version": version,
            "title": title,
            "unique_id": unique_id,
            "disabled_by": disabled_by,
        }
        if source is not None:
            kwargs["source"] = source
        if state is not None:
            kwargs["state"] = state
        super().__init__(**kwargs)
        if reason is not None:
            self.reason = reason

    def add_to_hass(self, hass):
        """Test helper to add entry to hass."""
        hass.config_entries._entries[self.entry_id] = self

    def add_to_manager(self, manager):
        """Test helper to add entry to entry manager."""
        manager._entries[self.entry_id] = self
2018-02-16 14:07:38 -08:00
2016-08-23 06:42:05 +02:00
def patch_yaml_files(files_dict, endswith=True):
    """Patch load_yaml with a dictionary of yaml files.

    ``files_dict`` maps file names to YAML text. The returned patcher replaces
    ``open`` in the yaml loader module. A requested name is resolved in this
    order:

    * an exact key match returns the mocked content;
    * otherwise, when ``endswith`` is true, the longest key the name ends
      with is used;
    * names containing ``homeassistant/components`` are opened from disk;
    * anything else raises ``FileNotFoundError``.
    """
    # Match using endswith, trying the longest (most specific) key first, so
    # "a/secrets.yaml" beats "secrets.yaml" for ".../a/secrets.yaml".
    # (Previously sorted shortest-first, contradicting this intent.)
    matchlist = sorted(files_dict, key=len, reverse=True) if endswith else []

    def _mock_file(content, name):
        """Return a StringIO over ``content`` whose ``name`` is ``name``."""
        res = StringIO(content)
        setattr(res, "name", name)
        return res

    def mock_open_f(fname, **_):
        """Mock open() in the yaml module, used by load_yaml."""
        if isinstance(fname, pathlib.Path):
            fname = str(fname)

        # Return the mocked file on full match
        if fname in files_dict:
            _LOGGER.debug("patch_yaml_files match %s", fname)
            return _mock_file(files_dict[fname], fname)

        # Match using endswith
        for ends in matchlist:
            if fname.endswith(ends):
                _LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname)
                return _mock_file(files_dict[ends], fname)

        # Fallback for hass.components (i.e. services.yaml)
        if "homeassistant/components" in fname:
            _LOGGER.debug("patch_yaml_files using real file: %s", fname)
            return open(fname, encoding="utf-8")

        # Not found
        raise FileNotFoundError(f"File not found: {fname}")

    return patch.object(yaml_loader, "open", mock_open_f, create=True)
2018-09-23 14:35:07 -07:00
def mock_coro(return_value=None, exception=None):
    """Return an already-completed ``asyncio.Future`` (awaitable, not a coroutine).

    The future raises ``exception`` when awaited if one is given; otherwise
    it resolves to ``return_value``.
    """
    future = asyncio.Future()
    if exception is None:
        future.set_result(return_value)
    else:
        future.set_exception(exception)
    return future
2017-02-13 21:34:36 -08:00
@contextmanager
def assert_setup_component(count, domain=None):
    """Collect valid configuration from setup_component.

    - count: The amount of valid platforms that should be setup
    - domain: The domain to count is optional. It can be automatically
      determined most of the time

    Use as a context manager around setup.setup_component
        with assert_setup_component(0) as result_config:
            setup_component(hass, domain, start_config)
            # using result_config is optional
    """
    captured = {}

    async def mock_psc(hass, config_input, integration):
        """Mock the prepare_setup_component to capture config."""
        domain_input = integration.domain
        processed = await async_process_component_config(
            hass, config_input, integration
        )
        captured[domain_input] = (
            processed.get(domain_input) if processed is not None else None
        )
        _LOGGER.debug(
            "Configuration for %s, Validated: %s, Original %s",
            domain_input,
            captured[domain_input],
            config_input.get(domain_input),
        )
        return processed

    with patch("homeassistant.config.async_process_component_config", mock_psc):
        yield captured

    if domain is None:
        assert (
            len(captured) == 1
        ), f"assert_setup_component requires DOMAIN: {list(captured)}"
        domain = next(iter(captured))

    validated = captured.get(domain)
    validated_len = len(validated) if validated is not None else 0
    assert (
        validated_len == count
    ), f"setup_component failed, expected {count} got {validated_len}: {validated}"
2017-02-26 14:38:06 -08:00
def init_recorder_component(hass, add_config=None):
    """Set up the recorder synchronously against an in-memory SQLite database.

    ``add_config`` entries are copied into the recorder config. The database
    URL is always forced to ``sqlite://``, and schema migration is patched
    out during setup.
    """
    recorder_config = dict(add_config) if add_config else {}
    recorder_config[recorder.CONF_DB_URL] = "sqlite://"  # In memory DB

    with patch("homeassistant.components.recorder.migration.migrate_schema"):
        assert setup_component(
            hass, recorder.DOMAIN, {recorder.DOMAIN: recorder_config}
        )

    assert recorder.DOMAIN in hass.config.components
    _LOGGER.info("In-memory recorder successfully started")
async def async_init_recorder_component(hass, add_config=None):
    """Set up the recorder from async code against an in-memory SQLite database.

    ``add_config`` entries are copied into the recorder config. The database
    URL is always forced to ``sqlite://``, and schema migration is patched
    out during setup.
    """
    recorder_config = dict(add_config) if add_config else {}
    recorder_config[recorder.CONF_DB_URL] = "sqlite://"  # In memory DB

    with patch("homeassistant.components.recorder.migration.migrate_schema"):
        assert await async_setup_component(
            hass, recorder.DOMAIN, {recorder.DOMAIN: recorder_config}
        )

    assert recorder.DOMAIN in hass.config.components
    _LOGGER.info("In-memory recorder successfully started")
def mock_restore_cache(hass, states):
    """Seed restore-state data for ``states`` under ``DATA_RESTORE_STATE_TASK``.

    Each state's attributes are round-tripped through JSON using
    ``JSONEncoder``. Every stored entry shares one ``utcnow()`` timestamp.
    The entity ids in ``states`` must be unique.
    """
    key = restore_state.DATA_RESTORE_STATE_TASK
    data = restore_state.RestoreStateData(hass)
    now = date_util.utcnow()

    def _stored(state):
        """Return ``state`` as a StoredState with JSON-normalised attributes."""
        as_dict = state.as_dict()
        as_dict["attributes"] = json.loads(
            json.dumps(as_dict["attributes"], cls=JSONEncoder)
        )
        return restore_state.StoredState(State.from_dict(as_dict), now)

    data.last_states = {state.entity_id: _stored(state) for state in states}
    _LOGGER.debug("Restore cache: %s", data.last_states)
    assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}"

    hass.data[key] = data
2017-05-13 21:25:54 -07:00
2018-02-08 03:16:51 -08:00
class MockEntity(entity.Entity):
    """Mock Entity class whose properties come from constructor keywords."""

    def __init__(self, **values):
        """Store ``values``; properties fall back to Entity when a key is absent."""
        self._values = values

        if "entity_id" in values:
            self.entity_id = values["entity_id"]

    @property
    def name(self):
        """Return the name of the entity."""
        return self._handle("name")

    @property
    def should_poll(self):
        """Return the state of the polling."""
        return self._handle("should_poll")

    @property
    def unique_id(self):
        """Return the unique ID of the entity."""
        return self._handle("unique_id")

    @property
    def state(self):
        """Return the state of the entity."""
        return self._handle("state")

    @property
    def available(self):
        """Return True if entity is available."""
        return self._handle("available")

    @property
    def device_info(self):
        """Info how it links to a device."""
        return self._handle("device_info")

    @property
    def device_class(self):
        """Info how device should be classified."""
        return self._handle("device_class")

    @property
    def unit_of_measurement(self):
        """Info on the units the entity state is in."""
        return self._handle("unit_of_measurement")

    @property
    def capability_attributes(self):
        """Info about capabilities."""
        return self._handle("capability_attributes")

    @property
    def supported_features(self):
        """Info about supported features."""
        return self._handle("supported_features")

    @property
    def entity_registry_enabled_default(self):
        """Return if the entity should be enabled when first added to the entity registry."""
        return self._handle("entity_registry_enabled_default")

    def _handle(self, attr):
        """Return ``attr`` from the supplied values, else the Entity default."""
        overrides = self._values
        if attr not in overrides:
            return getattr(super(), attr)
        return overrides[attr]
2018-06-28 22:14:26 -04:00
@contextmanager
def mock_storage(data=None):
    """Patch ``homeassistant.helpers.storage.Store`` to use an in-memory dict.

    ``data`` maps a store key to ``{'version': version, 'data': data}``. A new
    empty dict is used when it is omitted. The dict is yielded so tests can
    inspect what was written.

    Written data is round-tripped through JSON (with the store's encoder) so
    the test fails if it cannot be serialized.
    """
    if data is None:
        data = {}

    # Keep the real loader: the fake seeds ``store._data`` and then delegates,
    # so that version migration still runs.
    real_async_load = storage.Store._async_load

    async def fake_async_load(store):
        """Seed ``store._data`` from the mock dict, then run the real load."""
        if store._data is None:
            if store.key not in data:
                # Nothing stored under this key.
                return None

            seeded = data[store.key]
            if not ("data" in seeded and "version" in seeded):
                _LOGGER.error('Mock data needs "version" and "data"')
                raise ValueError('Mock data needs "version" and "data"')
            store._data = seeded

        result = await real_async_load(store)
        _LOGGER.info("Loading data for %s: %s", store.key, result)
        return result

    def fake_write_data(store, path, data_to_write):
        """Record a JSON round-tripped copy of the written data."""
        _LOGGER.info("Writing data to %s: %s", store.key, data_to_write)
        serialized = json.dumps(data_to_write, cls=store._encoder)
        data[store.key] = json.loads(serialized)

    async def fake_remove(store):
        """Forget whatever was stored under the store's key."""
        data.pop(store.key, None)

    load_patch = patch(
        "homeassistant.helpers.storage.Store._async_load",
        side_effect=fake_async_load,
        autospec=True,
    )
    write_patch = patch(
        "homeassistant.helpers.storage.Store._write_data",
        side_effect=fake_write_data,
        autospec=True,
    )
    remove_patch = patch(
        "homeassistant.helpers.storage.Store.async_remove",
        side_effect=fake_remove,
        autospec=True,
    )

    with load_patch, write_patch, remove_patch:
        yield data
async def flush_store(store):
    """Write out any pending delayed write of ``store`` immediately.

    Does nothing when ``store._data`` is None (nothing pending).
    """
    if store._data is not None:
        # Drop the scheduled listeners before writing so they don't fire later.
        store._async_cleanup_final_write_listener()
        store._async_cleanup_delay_listener()
        await store._async_handle_write_data()
async def get_system_health_info(hass, domain):
    """Await and return the system health info registered for ``domain``."""
    registration = hass.data["system_health"][domain]
    info = await registration.info_callback(hass)
    return info
def mock_integration(hass, module, built_in=True):
    """Register ``module`` as a mock integration and return the Integration.

    The integration's package lives under ``loader.PACKAGE_BUILTIN``, or under
    ``loader.PACKAGE_CUSTOM_COMPONENTS`` when ``built_in`` is False. Importing
    any platform of it raises ImportError. The integration and the module are
    both stored in the loader caches in ``hass.data``.
    """
    package_root = (
        loader.PACKAGE_BUILTIN if built_in else loader.PACKAGE_CUSTOM_COMPONENTS
    )
    mocked = loader.Integration(
        hass,
        f"{package_root}.{module.DOMAIN}",
        None,
        module.mock_manifest(),
    )

    def refuse_platform_import(platform_name):
        """Raise ImportError as if the platform module did not exist."""
        raise ImportError(
            f"Mocked unable to import platform '{platform_name}'",
            name=f"{mocked.pkg_path}.{platform_name}",
        )

    mocked._import_platform = refuse_platform_import

    _LOGGER.info("Adding mock integration: %s", module.DOMAIN)
    integrations = hass.data.setdefault(loader.DATA_INTEGRATIONS, {})
    integrations[module.DOMAIN] = mocked
    components = hass.data.setdefault(loader.DATA_COMPONENTS, {})
    components[module.DOMAIN] = module

    return mocked
def mock_entity_platform(hass, platform_path, module):
    """Mock an entity platform.

    ``platform_path`` is given as ``<entity domain>.<integration>`` (for
    example ``light.hue``). It is registered as the platform
    ``<integration>.<entity domain>`` (``hue.light``).
    """
    entity_domain, integration_domain = platform_path.split(".")
    swapped_path = f"{integration_domain}.{entity_domain}"
    mock_platform(hass, swapped_path, module)
def mock_platform(hass, platform_path, module=None):
    """Register ``module`` (or a new ``Mock`` when falsy) as a platform.

    ``platform_path`` has the form ``<integration>.<platform>``, for example
    ``hue.config_flow``. If no integration is cached for ``<integration>`` yet,
    a mock one is created first.
    """
    integration_domain, _platform = platform_path.split(".")
    integrations = hass.data.setdefault(loader.DATA_INTEGRATIONS, {})
    components = hass.data.setdefault(loader.DATA_COMPONENTS, {})

    if integration_domain not in integrations:
        mock_integration(hass, MockModule(integration_domain))

    _LOGGER.info("Adding mock integration platform: %s", platform_path)
    components[platform_path] = module if module else Mock()
def async_capture_events(hass, event_name):
    """Listen for ``event_name`` on the bus and return the captured events.

    The returned list is appended to by the listener as events arrive.
    """
    captured = []

    @ha.callback
    def _record(event):
        captured.append(event)

    hass.bus.async_listen(event_name, _record)
    return captured
@ha.callback
def async_mock_signal(hass, signal):
    """Record every dispatch of ``signal`` and return the list of recordings.

    Each entry in the returned list is the tuple of positional arguments the
    handler received for one dispatch.
    """
    received = []

    @ha.callback
    def _on_signal(*args):
        """Record the arguments of one dispatch."""
        received.append(args)

    hass.helpers.dispatcher.async_dispatcher_connect(signal, _on_signal)
    return received
class hashdict(dict):
    """Hashable, immutable dict, suitable for use as a key into other dicts.

    >>> h1 = hashdict({"apples": 1, "bananas":2})
    >>> h2 = hashdict({"bananas": 3, "mangoes": 5})
    >>> h1+h2
    hashdict(apples=1, bananas=3, mangoes=5)
    >>> d1 = {}
    >>> d1[h1] = "salad"
    >>> d1[h1]
    'salad'
    >>> d1[h2]
    Traceback (most recent call last):
    ...
    KeyError: hashdict(bananas=3, mangoes=5)

    Values must be hashable and keys mutually orderable, because the hash is
    taken over the sorted items. Every in-place mutator (including ``|=``)
    raises TypeError: mutating a dict that is already used as a key would
    corrupt the mapping that holds it.

    based on answers from
    http://stackoverflow.com/questions/1151658/python-hashable-dicts
    """

    def __key(self):
        # Sorting makes the key independent of insertion order.
        return tuple(sorted(self.items()))

    def __repr__(self):  # noqa: D105 no docstring
        # Matches the format shown in the class docstring.
        args = ", ".join(f"{key!s}={value!r}" for key, value in self.__key())
        return f"{self.__class__.__name__}({args})"

    def __hash__(self):  # noqa: D105 no docstring
        return hash(self.__key())

    def __setitem__(self, key, value):  # noqa: D105 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def __delitem__(self, key):  # noqa: D105 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def __ior__(self, other):  # noqa: D105 no docstring
        # dict.__ior__ (Python 3.9+, PEP 584) updates in place without going
        # through the overridden mutators, so it must be blocked explicitly.
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def clear(self):  # noqa: D102 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def pop(self, *args, **kwargs):  # noqa: D102 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def popitem(self, *args, **kwargs):  # noqa: D102 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def setdefault(self, *args, **kwargs):  # noqa: D102 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    def update(self, *args, **kwargs):  # noqa: D102 no docstring
        raise TypeError(f"{self.__class__.__name__} does not support item assignment")

    # update is not ok because it mutates the object
    # __add__ is ok because it creates a new object
    # while the new object is under construction, it's ok to mutate it
    def __add__(self, right):  # noqa: D105 no docstring
        result = hashdict(self)
        dict.update(result, right)
        return result
def assert_lists_same(a, b):
    """Assert that ``a`` and ``b`` contain the same dicts, ignoring order.

    Each item is wrapped in ``hashdict`` so it can be counted, so repeated
    items must occur the same number of times in both lists.
    """
    counts_a = collections.Counter(hashdict(item) for item in a)
    counts_b = collections.Counter(hashdict(item) for item in b)
    assert counts_a == counts_b