mirror of
https://github.com/home-assistant/core.git
synced 2025-08-07 06:35:10 +02:00
Merge branch 'dev' into jbouwh-mqtt-device-discovery
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -20,6 +21,11 @@ from .models import (
|
|||||||
ConversationInput,
|
ConversationInput,
|
||||||
ConversationResult,
|
ConversationResult,
|
||||||
)
|
)
|
||||||
|
from .trace import (
|
||||||
|
ConversationTraceEvent,
|
||||||
|
ConversationTraceEventType,
|
||||||
|
async_conversation_trace,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -84,15 +90,23 @@ async def async_converse(
|
|||||||
language = hass.config.language
|
language = hass.config.language
|
||||||
|
|
||||||
_LOGGER.debug("Processing in %s: %s", language, text)
|
_LOGGER.debug("Processing in %s: %s", language, text)
|
||||||
return await method(
|
conversation_input = ConversationInput(
|
||||||
ConversationInput(
|
|
||||||
text=text,
|
text=text,
|
||||||
context=context,
|
context=context,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
language=language,
|
language=language,
|
||||||
)
|
)
|
||||||
|
with async_conversation_trace() as trace:
|
||||||
|
trace.add_event(
|
||||||
|
ConversationTraceEvent(
|
||||||
|
ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
|
dataclasses.asdict(conversation_input),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
result = await method(conversation_input)
|
||||||
|
trace.set_result(**result.as_dict())
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class AgentManager:
|
class AgentManager:
|
||||||
|
118
homeassistant/components/conversation/trace.py
Normal file
118
homeassistant/components/conversation/trace.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""Debug traces for conversation."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
import enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||||
|
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||||
|
|
||||||
|
STORED_TRACES = 3
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationTraceEventType(enum.StrEnum):
|
||||||
|
"""Type of an event emitted during a conversation."""
|
||||||
|
|
||||||
|
ASYNC_PROCESS = "async_process"
|
||||||
|
"""The conversation is started from user input."""
|
||||||
|
|
||||||
|
AGENT_DETAIL = "agent_detail"
|
||||||
|
"""Event detail added by a conversation agent."""
|
||||||
|
|
||||||
|
LLM_TOOL_CALL = "llm_tool_call"
|
||||||
|
"""An LLM Tool call"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ConversationTraceEvent:
|
||||||
|
"""Event emitted during a conversation."""
|
||||||
|
|
||||||
|
event_type: ConversationTraceEventType
|
||||||
|
data: dict[str, Any] | None = None
|
||||||
|
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationTrace:
|
||||||
|
"""Stores debug data related to a conversation."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize ConversationTrace."""
|
||||||
|
self._trace_id = ulid_util.ulid_now()
|
||||||
|
self._events: list[ConversationTraceEvent] = []
|
||||||
|
self._error: Exception | None = None
|
||||||
|
self._result: dict[str, Any] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trace_id(self) -> str:
|
||||||
|
"""Identifier for this trace."""
|
||||||
|
return self._trace_id
|
||||||
|
|
||||||
|
def add_event(self, event: ConversationTraceEvent) -> None:
|
||||||
|
"""Add an event to the trace."""
|
||||||
|
self._events.append(event)
|
||||||
|
|
||||||
|
def set_error(self, ex: Exception) -> None:
|
||||||
|
"""Set error."""
|
||||||
|
self._error = ex
|
||||||
|
|
||||||
|
def set_result(self, **kwargs: Any) -> None:
|
||||||
|
"""Set result."""
|
||||||
|
self._result = {**kwargs}
|
||||||
|
|
||||||
|
def as_dict(self) -> dict[str, Any]:
|
||||||
|
"""Return dictionary version of this ConversationTrace."""
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"id": self._trace_id,
|
||||||
|
"events": [asdict(event) for event in self._events],
|
||||||
|
}
|
||||||
|
if self._error is not None:
|
||||||
|
result["error"] = str(self._error) or self._error.__class__.__name__
|
||||||
|
if self._result is not None:
|
||||||
|
result["result"] = self._result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
_current_trace: ContextVar[ConversationTrace | None] = ContextVar(
|
||||||
|
"current_trace", default=None
|
||||||
|
)
|
||||||
|
_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict(
|
||||||
|
size_limit=STORED_TRACES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def async_conversation_trace_append(
|
||||||
|
event_type: ConversationTraceEventType, event_data: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Append a ConversationTraceEvent to the current active trace."""
|
||||||
|
trace = _current_trace.get()
|
||||||
|
if not trace:
|
||||||
|
return
|
||||||
|
trace.add_event(ConversationTraceEvent(event_type, event_data))
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def async_conversation_trace() -> Generator[ConversationTrace, None]:
|
||||||
|
"""Create a new active ConversationTrace."""
|
||||||
|
trace = ConversationTrace()
|
||||||
|
token = _current_trace.set(trace)
|
||||||
|
_recent_traces[trace.trace_id] = trace
|
||||||
|
try:
|
||||||
|
yield trace
|
||||||
|
except Exception as ex:
|
||||||
|
trace.set_error(ex)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
_current_trace.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
def async_get_traces() -> list[ConversationTrace]:
|
||||||
|
"""Get the most recent traces."""
|
||||||
|
return list(_recent_traces.values())
|
||||||
|
|
||||||
|
|
||||||
|
def async_clear_traces() -> None:
|
||||||
|
"""Clear all traces."""
|
||||||
|
_recent_traces.clear()
|
@@ -5,5 +5,5 @@
|
|||||||
"documentation": "https://www.home-assistant.io/integrations/envisalink",
|
"documentation": "https://www.home-assistant.io/integrations/envisalink",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"loggers": ["pyenvisalink"],
|
"loggers": ["pyenvisalink"],
|
||||||
"requirements": ["pyenvisalink==4.6"]
|
"requirements": ["pyenvisalink==4.7"]
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
"""Support for Arduino-compatible Microcontrollers through Firmata."""
|
"""Support for Arduino-compatible Microcontrollers through Firmata."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -212,16 +211,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
|
|||||||
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||||
"""Shutdown and close a Firmata board for a config entry."""
|
"""Shutdown and close a Firmata board for a config entry."""
|
||||||
_LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME])
|
_LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME])
|
||||||
|
results: list[bool] = []
|
||||||
unload_entries = []
|
if platforms := [
|
||||||
for conf, platform in CONF_PLATFORM_MAP.items():
|
platform
|
||||||
if conf in config_entry.data:
|
for conf, platform in CONF_PLATFORM_MAP.items()
|
||||||
unload_entries.append(
|
if conf in config_entry.data
|
||||||
hass.config_entries.async_forward_entry_unload(config_entry, platform)
|
]:
|
||||||
|
results.append(
|
||||||
|
await hass.config_entries.async_unload_platforms(config_entry, platforms)
|
||||||
)
|
)
|
||||||
results = []
|
|
||||||
if unload_entries:
|
|
||||||
results = await asyncio.gather(*unload_entries)
|
|
||||||
results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset())
|
results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset())
|
||||||
|
|
||||||
return False not in results
|
return False not in results
|
||||||
|
@@ -11,12 +11,13 @@ from .const import (
|
|||||||
CONF_DAMPING_EVENING,
|
CONF_DAMPING_EVENING,
|
||||||
CONF_DAMPING_MORNING,
|
CONF_DAMPING_MORNING,
|
||||||
CONF_MODULES_POWER,
|
CONF_MODULES_POWER,
|
||||||
DOMAIN,
|
|
||||||
)
|
)
|
||||||
from .coordinator import ForecastSolarDataUpdateCoordinator
|
from .coordinator import ForecastSolarDataUpdateCoordinator
|
||||||
|
|
||||||
PLATFORMS = [Platform.SENSOR]
|
PLATFORMS = [Platform.SENSOR]
|
||||||
|
|
||||||
|
type ForecastSolarConfigEntry = ConfigEntry[ForecastSolarDataUpdateCoordinator]
|
||||||
|
|
||||||
|
|
||||||
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Migrate old config entry."""
|
"""Migrate old config entry."""
|
||||||
@@ -36,12 +37,14 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant, entry: ForecastSolarConfigEntry
|
||||||
|
) -> bool:
|
||||||
"""Set up Forecast.Solar from a config entry."""
|
"""Set up Forecast.Solar from a config entry."""
|
||||||
coordinator = ForecastSolarDataUpdateCoordinator(hass, entry)
|
coordinator = ForecastSolarDataUpdateCoordinator(hass, entry)
|
||||||
await coordinator.async_config_entry_first_refresh()
|
await coordinator.async_config_entry_first_refresh()
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
|
entry.runtime_data = coordinator
|
||||||
|
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
|
|
||||||
@@ -52,11 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload a config entry."""
|
"""Unload a config entry."""
|
||||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||||
if unload_ok:
|
|
||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
|
||||||
|
|
||||||
return unload_ok
|
|
||||||
|
|
||||||
|
|
||||||
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||||
|
@@ -4,15 +4,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from forecast_solar import Estimate
|
|
||||||
|
|
||||||
from homeassistant.components.diagnostics import async_redact_data
|
from homeassistant.components.diagnostics import async_redact_data
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
|
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
|
||||||
|
|
||||||
from .const import DOMAIN
|
from . import ForecastSolarConfigEntry
|
||||||
|
|
||||||
TO_REDACT = {
|
TO_REDACT = {
|
||||||
CONF_API_KEY,
|
CONF_API_KEY,
|
||||||
@@ -22,10 +18,10 @@ TO_REDACT = {
|
|||||||
|
|
||||||
|
|
||||||
async def async_get_config_entry_diagnostics(
|
async def async_get_config_entry_diagnostics(
|
||||||
hass: HomeAssistant, entry: ConfigEntry
|
hass: HomeAssistant, entry: ForecastSolarConfigEntry
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return diagnostics for a config entry."""
|
"""Return diagnostics for a config entry."""
|
||||||
coordinator: DataUpdateCoordinator[Estimate] = hass.data[DOMAIN][entry.entry_id]
|
coordinator = entry.runtime_data
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"entry": {
|
"entry": {
|
||||||
|
@@ -4,19 +4,21 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .coordinator import ForecastSolarDataUpdateCoordinator
|
||||||
|
|
||||||
|
|
||||||
async def async_get_solar_forecast(
|
async def async_get_solar_forecast(
|
||||||
hass: HomeAssistant, config_entry_id: str
|
hass: HomeAssistant, config_entry_id: str
|
||||||
) -> dict[str, dict[str, float | int]] | None:
|
) -> dict[str, dict[str, float | int]] | None:
|
||||||
"""Get solar forecast for a config entry ID."""
|
"""Get solar forecast for a config entry ID."""
|
||||||
if (coordinator := hass.data[DOMAIN].get(config_entry_id)) is None:
|
if (
|
||||||
|
entry := hass.config_entries.async_get_entry(config_entry_id)
|
||||||
|
) is None or not isinstance(entry.runtime_data, ForecastSolarDataUpdateCoordinator):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"wh_hours": {
|
"wh_hours": {
|
||||||
timestamp.isoformat(): val
|
timestamp.isoformat(): val
|
||||||
for timestamp, val in coordinator.data.wh_period.items()
|
for timestamp, val in entry.runtime_data.data.wh_period.items()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -16,7 +16,6 @@ from homeassistant.components.sensor import (
|
|||||||
SensorEntityDescription,
|
SensorEntityDescription,
|
||||||
SensorStateClass,
|
SensorStateClass,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import UnitOfEnergy, UnitOfPower
|
from homeassistant.const import UnitOfEnergy, UnitOfPower
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||||
@@ -24,6 +23,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||||||
from homeassistant.helpers.typing import StateType
|
from homeassistant.helpers.typing import StateType
|
||||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
|
from . import ForecastSolarConfigEntry
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .coordinator import ForecastSolarDataUpdateCoordinator
|
from .coordinator import ForecastSolarDataUpdateCoordinator
|
||||||
|
|
||||||
@@ -133,10 +133,12 @@ SENSORS: tuple[ForecastSolarSensorEntityDescription, ...] = (
|
|||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
|
hass: HomeAssistant,
|
||||||
|
entry: ForecastSolarConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Defer sensor setup to the shared sensor module."""
|
"""Defer sensor setup to the shared sensor module."""
|
||||||
coordinator: ForecastSolarDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id]
|
coordinator = entry.runtime_data
|
||||||
|
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
ForecastSolarSensorEntity(
|
ForecastSolarSensorEntity(
|
||||||
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
from typing import Final
|
from typing import Any, Final
|
||||||
|
|
||||||
from homeassistant.components.button import (
|
from homeassistant.components.button import (
|
||||||
ButtonDeviceClass,
|
ButtonDeviceClass,
|
||||||
@@ -30,7 +30,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
class FritzButtonDescription(ButtonEntityDescription):
|
class FritzButtonDescription(ButtonEntityDescription):
|
||||||
"""Class to describe a Button entity."""
|
"""Class to describe a Button entity."""
|
||||||
|
|
||||||
press_action: Callable
|
press_action: Callable[[AvmWrapper], Any]
|
||||||
|
|
||||||
|
|
||||||
BUTTONS: Final = [
|
BUTTONS: Final = [
|
||||||
|
@@ -57,9 +57,6 @@ ERROR_UPNP_NOT_CONFIGURED = "upnp_not_configured"
|
|||||||
ERROR_UNKNOWN = "unknown_error"
|
ERROR_UNKNOWN = "unknown_error"
|
||||||
|
|
||||||
FRITZ_SERVICES = "fritz_services"
|
FRITZ_SERVICES = "fritz_services"
|
||||||
SERVICE_REBOOT = "reboot"
|
|
||||||
SERVICE_RECONNECT = "reconnect"
|
|
||||||
SERVICE_CLEANUP = "cleanup"
|
|
||||||
SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password"
|
SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password"
|
||||||
|
|
||||||
SWITCH_TYPE_DEFLECTION = "CallDeflection"
|
SWITCH_TYPE_DEFLECTION = "CallDeflection"
|
||||||
|
@@ -46,9 +46,6 @@ from .const import (
|
|||||||
DEFAULT_USERNAME,
|
DEFAULT_USERNAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
FRITZ_EXCEPTIONS,
|
FRITZ_EXCEPTIONS,
|
||||||
SERVICE_CLEANUP,
|
|
||||||
SERVICE_REBOOT,
|
|
||||||
SERVICE_RECONNECT,
|
|
||||||
SERVICE_SET_GUEST_WIFI_PW,
|
SERVICE_SET_GUEST_WIFI_PW,
|
||||||
MeshRoles,
|
MeshRoles,
|
||||||
)
|
)
|
||||||
@@ -730,30 +727,6 @@ class FritzBoxTools(DataUpdateCoordinator[UpdateCoordinatorDataType]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if service_call.service == SERVICE_REBOOT:
|
|
||||||
_LOGGER.warning(
|
|
||||||
'Service "fritz.reboot" is deprecated, please use the corresponding'
|
|
||||||
" button entity instead"
|
|
||||||
)
|
|
||||||
await self.async_trigger_reboot()
|
|
||||||
return
|
|
||||||
|
|
||||||
if service_call.service == SERVICE_RECONNECT:
|
|
||||||
_LOGGER.warning(
|
|
||||||
'Service "fritz.reconnect" is deprecated, please use the'
|
|
||||||
" corresponding button entity instead"
|
|
||||||
)
|
|
||||||
await self.async_trigger_reconnect()
|
|
||||||
return
|
|
||||||
|
|
||||||
if service_call.service == SERVICE_CLEANUP:
|
|
||||||
_LOGGER.warning(
|
|
||||||
'Service "fritz.cleanup" is deprecated, please use the'
|
|
||||||
" corresponding button entity instead"
|
|
||||||
)
|
|
||||||
await self.async_trigger_cleanup(config_entry)
|
|
||||||
return
|
|
||||||
|
|
||||||
if service_call.service == SERVICE_SET_GUEST_WIFI_PW:
|
if service_call.service == SERVICE_SET_GUEST_WIFI_PW:
|
||||||
await self.async_trigger_set_guest_password(
|
await self.async_trigger_set_guest_password(
|
||||||
service_call.data.get("password"),
|
service_call.data.get("password"),
|
||||||
|
@@ -11,14 +11,7 @@ from homeassistant.core import HomeAssistant, ServiceCall
|
|||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers.service import async_extract_config_entry_ids
|
from homeassistant.helpers.service import async_extract_config_entry_ids
|
||||||
|
|
||||||
from .const import (
|
from .const import DOMAIN, FRITZ_SERVICES, SERVICE_SET_GUEST_WIFI_PW
|
||||||
DOMAIN,
|
|
||||||
FRITZ_SERVICES,
|
|
||||||
SERVICE_CLEANUP,
|
|
||||||
SERVICE_REBOOT,
|
|
||||||
SERVICE_RECONNECT,
|
|
||||||
SERVICE_SET_GUEST_WIFI_PW,
|
|
||||||
)
|
|
||||||
from .coordinator import AvmWrapper
|
from .coordinator import AvmWrapper
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@@ -32,9 +25,6 @@ SERVICE_SCHEMA_SET_GUEST_WIFI_PW = vol.Schema(
|
|||||||
)
|
)
|
||||||
|
|
||||||
SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [
|
SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [
|
||||||
(SERVICE_CLEANUP, None),
|
|
||||||
(SERVICE_REBOOT, None),
|
|
||||||
(SERVICE_RECONNECT, None),
|
|
||||||
(SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW),
|
(SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@@ -1,31 +1,3 @@
|
|||||||
reconnect:
|
|
||||||
fields:
|
|
||||||
device_id:
|
|
||||||
required: true
|
|
||||||
selector:
|
|
||||||
device:
|
|
||||||
integration: fritz
|
|
||||||
entity:
|
|
||||||
device_class: connectivity
|
|
||||||
reboot:
|
|
||||||
fields:
|
|
||||||
device_id:
|
|
||||||
required: true
|
|
||||||
selector:
|
|
||||||
device:
|
|
||||||
integration: fritz
|
|
||||||
entity:
|
|
||||||
device_class: connectivity
|
|
||||||
|
|
||||||
cleanup:
|
|
||||||
fields:
|
|
||||||
device_id:
|
|
||||||
required: true
|
|
||||||
selector:
|
|
||||||
device:
|
|
||||||
integration: fritz
|
|
||||||
entity:
|
|
||||||
device_class: connectivity
|
|
||||||
set_guest_wifi_password:
|
set_guest_wifi_password:
|
||||||
fields:
|
fields:
|
||||||
device_id:
|
device_id:
|
||||||
|
@@ -144,42 +144,12 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"services": {
|
"services": {
|
||||||
"reconnect": {
|
|
||||||
"name": "[%key:component::fritz::entity::button::reconnect::name%]",
|
|
||||||
"description": "Reconnects your FRITZ!Box internet connection.",
|
|
||||||
"fields": {
|
|
||||||
"device_id": {
|
|
||||||
"name": "Fritz!Box Device",
|
|
||||||
"description": "Select the Fritz!Box to reconnect."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"reboot": {
|
|
||||||
"name": "Reboot",
|
|
||||||
"description": "Reboots your FRITZ!Box.",
|
|
||||||
"fields": {
|
|
||||||
"device_id": {
|
|
||||||
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
|
|
||||||
"description": "Select the Fritz!Box to reboot."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cleanup": {
|
|
||||||
"name": "Remove stale device tracker entities",
|
|
||||||
"description": "Remove FRITZ!Box stale device_tracker entities.",
|
|
||||||
"fields": {
|
|
||||||
"device_id": {
|
|
||||||
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
|
|
||||||
"description": "Select the Fritz!Box to check."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"set_guest_wifi_password": {
|
"set_guest_wifi_password": {
|
||||||
"name": "Set guest Wi-Fi password",
|
"name": "Set guest Wi-Fi password",
|
||||||
"description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.",
|
"description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.",
|
||||||
"fields": {
|
"fields": {
|
||||||
"device_id": {
|
"device_id": {
|
||||||
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
|
"name": "Fritz!Box Device",
|
||||||
"description": "Select the Fritz!Box to configure."
|
"description": "Select the Fritz!Box to configure."
|
||||||
},
|
},
|
||||||
"password": {
|
"password": {
|
||||||
|
46
homeassistant/components/fronius/diagnostics.py
Normal file
46
homeassistant/components/fronius/diagnostics.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Diagnostics support for Fronius."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components.diagnostics import async_redact_data
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from . import FroniusConfigEntry
|
||||||
|
|
||||||
|
TO_REDACT = {"unique_id", "unique_identifier", "serial"}
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_config_entry_diagnostics(
|
||||||
|
hass: HomeAssistant, config_entry: FroniusConfigEntry
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return diagnostics for a config entry."""
|
||||||
|
diag: dict[str, Any] = {}
|
||||||
|
solar_net = config_entry.runtime_data
|
||||||
|
fronius = solar_net.fronius
|
||||||
|
|
||||||
|
diag["config_entry"] = config_entry.as_dict()
|
||||||
|
diag["inverter_info"] = await fronius.inverter_info()
|
||||||
|
|
||||||
|
diag["coordinators"] = {"inverters": {}}
|
||||||
|
for inv in solar_net.inverter_coordinators:
|
||||||
|
diag["coordinators"]["inverters"] |= inv.data
|
||||||
|
|
||||||
|
diag["coordinators"]["logger"] = (
|
||||||
|
solar_net.logger_coordinator.data if solar_net.logger_coordinator else None
|
||||||
|
)
|
||||||
|
diag["coordinators"]["meter"] = (
|
||||||
|
solar_net.meter_coordinator.data if solar_net.meter_coordinator else None
|
||||||
|
)
|
||||||
|
diag["coordinators"]["ohmpilot"] = (
|
||||||
|
solar_net.ohmpilot_coordinator.data if solar_net.ohmpilot_coordinator else None
|
||||||
|
)
|
||||||
|
diag["coordinators"]["power_flow"] = (
|
||||||
|
solar_net.power_flow_coordinator.data
|
||||||
|
if solar_net.power_flow_coordinator
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
diag["coordinators"]["storage"] = (
|
||||||
|
solar_net.storage_coordinator.data if solar_net.storage_coordinator else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return async_redact_data(diag, TO_REDACT)
|
@@ -181,8 +181,7 @@ async def google_generative_ai_config_option_schema(
|
|||||||
schema = {
|
schema = {
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
|
||||||
default=DEFAULT_PROMPT,
|
|
||||||
): TemplateSelector(),
|
): TemplateSelector(),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_LLM_HASS_API,
|
CONF_LLM_HASS_API,
|
||||||
|
@@ -22,4 +22,4 @@ CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
|
|||||||
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
||||||
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
||||||
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
|
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE"
|
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
|
||||||
|
@@ -5,13 +5,14 @@ from __future__ import annotations
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import google.ai.generativelanguage as glm
|
import google.ai.generativelanguage as glm
|
||||||
from google.api_core.exceptions import ClientError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import google.generativeai.types as genai_types
|
import google.generativeai.types as genai_types
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@@ -205,15 +206,6 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
messages = [{}, {}]
|
messages = [{}, {}]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prompt = template.Template(
|
|
||||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
|
||||||
).async_render(
|
|
||||||
{
|
|
||||||
"ha_name": self.hass.config.location_name,
|
|
||||||
},
|
|
||||||
parse_result=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if llm_api:
|
if llm_api:
|
||||||
empty_tool_input = llm.ToolInput(
|
empty_tool_input = llm.ToolInput(
|
||||||
tool_name="",
|
tool_name="",
|
||||||
@@ -226,8 +218,23 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
device_id=user_input.device_id,
|
device_id=user_input.device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = (
|
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
|
||||||
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
|
|
||||||
|
else:
|
||||||
|
api_prompt = llm.PROMPT_NO_API_CONFIGURED
|
||||||
|
|
||||||
|
prompt = "\n".join(
|
||||||
|
(
|
||||||
|
template.Template(
|
||||||
|
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||||
|
).async_render(
|
||||||
|
{
|
||||||
|
"ha_name": self.hass.config.location_name,
|
||||||
|
},
|
||||||
|
parse_result=False,
|
||||||
|
),
|
||||||
|
api_prompt,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
except TemplateError as err:
|
except TemplateError as err:
|
||||||
@@ -244,6 +251,9 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
messages[1] = {"role": "model", "parts": "Ok"}
|
messages[1] = {"role": "model", "parts": "Ok"}
|
||||||
|
|
||||||
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
|
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
|
||||||
|
trace.async_conversation_trace_append(
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||||
|
)
|
||||||
|
|
||||||
chat = model.start_chat(history=messages)
|
chat = model.start_chat(history=messages)
|
||||||
chat_request = user_input.text
|
chat_request = user_input.text
|
||||||
@@ -252,15 +262,25 @@ class GoogleGenerativeAIConversationEntity(
|
|||||||
try:
|
try:
|
||||||
chat_response = await chat.send_message_async(chat_request)
|
chat_response = await chat.send_message_async(chat_request)
|
||||||
except (
|
except (
|
||||||
ClientError,
|
GoogleAPICallError,
|
||||||
ValueError,
|
ValueError,
|
||||||
genai_types.BlockedPromptException,
|
genai_types.BlockedPromptException,
|
||||||
genai_types.StopCandidateException,
|
genai_types.StopCandidateException,
|
||||||
) as err:
|
) as err:
|
||||||
LOGGER.error("Error sending message: %s", err)
|
LOGGER.error("Error sending message: %s %s", type(err), err)
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
err, genai_types.StopCandidateException
|
||||||
|
) and "finish_reason: SAFETY\n" in str(err):
|
||||||
|
error = "The message got blocked by your safety settings"
|
||||||
|
else:
|
||||||
|
error = (
|
||||||
|
f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
||||||
|
)
|
||||||
|
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
intent.IntentResponseErrorCode.UNKNOWN,
|
intent.IntentResponseErrorCode.UNKNOWN,
|
||||||
f"Sorry, I had a problem talking to Google Generative AI: {err}",
|
error,
|
||||||
)
|
)
|
||||||
return conversation.ConversationResult(
|
return conversation.ConversationResult(
|
||||||
response=intent_response, conversation_id=conversation_id
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"domain": "integration",
|
"domain": "integration",
|
||||||
"name": "Integration - Riemann sum integral",
|
"name": "Integral",
|
||||||
"after_dependencies": ["counter"],
|
"after_dependencies": ["counter"],
|
||||||
"codeowners": ["@dgomes"],
|
"codeowners": ["@dgomes"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"title": "Integration - Riemann sum integral sensor",
|
"title": "Integral sensor",
|
||||||
"config": {
|
"config": {
|
||||||
"step": {
|
"step": {
|
||||||
"user": {
|
"user": {
|
||||||
|
@@ -21,7 +21,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|||||||
CONF_VALIDATOR = "validator"
|
CONF_VALIDATOR = "validator"
|
||||||
CONF_SECRET = "secret"
|
CONF_SECRET = "secret"
|
||||||
URL = "/api/meraki"
|
URL = "/api/meraki"
|
||||||
VERSION = "2.0"
|
ACCEPTED_VERSIONS = ["2.0", "2.1"]
|
||||||
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@@ -74,7 +74,7 @@ class MerakiView(HomeAssistantView):
|
|||||||
if data["secret"] != self.secret:
|
if data["secret"] != self.secret:
|
||||||
_LOGGER.error("Invalid Secret received from Meraki")
|
_LOGGER.error("Invalid Secret received from Meraki")
|
||||||
return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY)
|
return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY)
|
||||||
if data["version"] != VERSION:
|
if data["version"] not in ACCEPTED_VERSIONS:
|
||||||
_LOGGER.error("Invalid API version: %s", data["version"])
|
_LOGGER.error("Invalid API version: %s", data["version"])
|
||||||
return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY)
|
return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY)
|
||||||
_LOGGER.debug("Valid Secret")
|
_LOGGER.debug("Valid Secret")
|
||||||
|
@@ -6,6 +6,6 @@
|
|||||||
"documentation": "https://www.home-assistant.io/integrations/minecraft_server",
|
"documentation": "https://www.home-assistant.io/integrations/minecraft_server",
|
||||||
"iot_class": "local_polling",
|
"iot_class": "local_polling",
|
||||||
"loggers": ["dnspython", "mcstatus"],
|
"loggers": ["dnspython", "mcstatus"],
|
||||||
"quality_scale": "gold",
|
"quality_scale": "platinum",
|
||||||
"requirements": ["mcstatus==11.1.1"]
|
"requirements": ["mcstatus==11.1.1"]
|
||||||
}
|
}
|
||||||
|
@@ -39,6 +39,7 @@ from .client import ( # noqa: F401
|
|||||||
MQTT,
|
MQTT,
|
||||||
async_publish,
|
async_publish,
|
||||||
async_subscribe,
|
async_subscribe,
|
||||||
|
async_subscribe_internal,
|
||||||
publish,
|
publish,
|
||||||
subscribe,
|
subscribe,
|
||||||
)
|
)
|
||||||
@@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
def collect_msg(msg: ReceiveMessage) -> None:
|
def collect_msg(msg: ReceiveMessage) -> None:
|
||||||
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
|
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
|
||||||
|
|
||||||
unsub = await async_subscribe(hass, call.data["topic"], collect_msg)
|
unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
|
||||||
|
|
||||||
def write_dump() -> None:
|
def write_dump() -> None:
|
||||||
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
|
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
|
||||||
@@ -459,7 +460,7 @@ async def websocket_subscribe(
|
|||||||
|
|
||||||
# Perform UTF-8 decoding directly in callback routine
|
# Perform UTF-8 decoding directly in callback routine
|
||||||
qos: int = msg.get("qos", DEFAULT_QOS)
|
qos: int = msg.get("qos", DEFAULT_QOS)
|
||||||
connection.subscriptions[msg["id"]] = await async_subscribe(
|
connection.subscriptions[msg["id"]] = async_subscribe_internal(
|
||||||
hass, msg["topic"], forward_messages, encoding=None, qos=qos
|
hass, msg["topic"], forward_messages, encoding=None, qos=qos
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,24 +523,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
mqtt_client = mqtt_data.client
|
mqtt_client = mqtt_data.client
|
||||||
|
|
||||||
# Unload publish and dump services.
|
# Unload publish and dump services.
|
||||||
hass.services.async_remove(
|
hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
|
||||||
DOMAIN,
|
hass.services.async_remove(DOMAIN, SERVICE_DUMP)
|
||||||
SERVICE_PUBLISH,
|
|
||||||
)
|
|
||||||
hass.services.async_remove(
|
|
||||||
DOMAIN,
|
|
||||||
SERVICE_DUMP,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stop the discovery
|
# Stop the discovery
|
||||||
await discovery.async_stop(hass)
|
await discovery.async_stop(hass)
|
||||||
# Unload the platforms
|
# Unload the platforms
|
||||||
await asyncio.gather(
|
await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
|
||||||
*(
|
|
||||||
hass.config_entries.async_forward_entry_unload(entry, component)
|
|
||||||
for component in mqtt_data.platforms_loaded
|
|
||||||
)
|
|
||||||
)
|
|
||||||
mqtt_data.platforms_loaded = set()
|
mqtt_data.platforms_loaded = set()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
# Unsubscribe reload dispatchers
|
# Unsubscribe reload dispatchers
|
||||||
|
@@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_alarm_disarm(self, code: str | None = None) -> None:
|
async def async_alarm_disarm(self, code: str | None = None) -> None:
|
||||||
"""Send disarm command.
|
"""Send disarm command.
|
||||||
|
@@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _value_is_expired(self, *_: Any) -> None:
|
def _value_is_expired(self, *_: Any) -> None:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|||||||
from . import subscription
|
from . import subscription
|
||||||
from .config import MQTT_BASE_SCHEMA
|
from .config import MQTT_BASE_SCHEMA
|
||||||
from .const import CONF_QOS, CONF_TOPIC
|
from .const import CONF_QOS, CONF_TOPIC
|
||||||
from .debug_info import log_messages
|
|
||||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .models import ReceiveMessage
|
from .models import ReceiveMessage
|
||||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
@@ -97,12 +97,8 @@ class MqttCamera(MqttEntity, Camera):
|
|||||||
"""Return the config schema."""
|
"""Return the config schema."""
|
||||||
return DISCOVERY_SCHEMA
|
return DISCOVERY_SCHEMA
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _image_received(self, msg: ReceiveMessage) -> None:
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
if CONF_IMAGE_ENCODING in self._config:
|
if CONF_IMAGE_ENCODING in self._config:
|
||||||
self._last_image = b64decode(msg.payload)
|
self._last_image = b64decode(msg.payload)
|
||||||
@@ -111,13 +107,21 @@ class MqttCamera(MqttEntity, Camera):
|
|||||||
assert isinstance(msg.payload, bytes)
|
assert isinstance(msg.payload, bytes)
|
||||||
self._last_image = msg.payload
|
self._last_image = msg.payload
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
"state_topic": {
|
"state_topic": {
|
||||||
"topic": self._config[CONF_TOPIC],
|
"topic": self._config[CONF_TOPIC],
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._image_received,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": None,
|
"encoding": None,
|
||||||
}
|
}
|
||||||
@@ -126,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_camera_image(
|
async def async_camera_image(
|
||||||
self, width: int | None = None, height: int | None = None
|
self, width: int | None = None, height: int | None = None
|
||||||
|
@@ -77,7 +77,6 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .models import (
|
from .models import (
|
||||||
DATA_MQTT,
|
DATA_MQTT,
|
||||||
AsyncMessageCallbackType,
|
|
||||||
MessageCallbackType,
|
MessageCallbackType,
|
||||||
MqttData,
|
MqttData,
|
||||||
PublishMessage,
|
PublishMessage,
|
||||||
@@ -184,7 +183,7 @@ async def async_publish(
|
|||||||
async def async_subscribe(
|
async def async_subscribe(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
topic: str,
|
topic: str,
|
||||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
qos: int = DEFAULT_QOS,
|
qos: int = DEFAULT_QOS,
|
||||||
encoding: str | None = DEFAULT_ENCODING,
|
encoding: str | None = DEFAULT_ENCODING,
|
||||||
) -> CALLBACK_TYPE:
|
) -> CALLBACK_TYPE:
|
||||||
@@ -192,13 +191,25 @@ async def async_subscribe(
|
|||||||
|
|
||||||
Call the return value to unsubscribe.
|
Call the return value to unsubscribe.
|
||||||
"""
|
"""
|
||||||
if not mqtt_config_entry_enabled(hass):
|
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
|
||||||
raise HomeAssistantError(
|
|
||||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
|
|
||||||
translation_key="mqtt_not_setup_cannot_subscribe",
|
@callback
|
||||||
translation_domain=DOMAIN,
|
def async_subscribe_internal(
|
||||||
translation_placeholders={"topic": topic},
|
hass: HomeAssistant,
|
||||||
)
|
topic: str,
|
||||||
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
|
qos: int = DEFAULT_QOS,
|
||||||
|
encoding: str | None = DEFAULT_ENCODING,
|
||||||
|
) -> CALLBACK_TYPE:
|
||||||
|
"""Subscribe to an MQTT topic.
|
||||||
|
|
||||||
|
This function is internal to the MQTT integration
|
||||||
|
and may change at any time. It should not be considered
|
||||||
|
a stable API.
|
||||||
|
|
||||||
|
Call the return value to unsubscribe.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
mqtt_data = hass.data[DATA_MQTT]
|
mqtt_data = hass.data[DATA_MQTT]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
@@ -209,12 +220,15 @@ async def async_subscribe(
|
|||||||
translation_domain=DOMAIN,
|
translation_domain=DOMAIN,
|
||||||
translation_placeholders={"topic": topic},
|
translation_placeholders={"topic": topic},
|
||||||
) from exc
|
) from exc
|
||||||
return await mqtt_data.client.async_subscribe(
|
client = mqtt_data.client
|
||||||
topic,
|
if not client.connected and not mqtt_config_entry_enabled(hass):
|
||||||
msg_callback,
|
raise HomeAssistantError(
|
||||||
qos,
|
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
|
||||||
encoding,
|
translation_key="mqtt_not_setup_cannot_subscribe",
|
||||||
|
translation_domain=DOMAIN,
|
||||||
|
translation_placeholders={"topic": topic},
|
||||||
)
|
)
|
||||||
|
return client.async_subscribe(topic, msg_callback, qos, encoding)
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@@ -429,10 +443,10 @@ class MQTT:
|
|||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
self.conf = conf
|
self.conf = conf
|
||||||
|
|
||||||
self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict(
|
self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
|
||||||
list
|
set
|
||||||
)
|
)
|
||||||
self._wildcard_subscriptions: list[Subscription] = []
|
self._wildcard_subscriptions: set[Subscription] = set()
|
||||||
# _retained_topics prevents a Subscription from receiving a
|
# _retained_topics prevents a Subscription from receiving a
|
||||||
# retained message more than once per topic. This prevents flooding
|
# retained message more than once per topic. This prevents flooding
|
||||||
# already active subscribers when new subscribers subscribe to a topic
|
# already active subscribers when new subscribers subscribe to a topic
|
||||||
@@ -452,7 +466,7 @@ class MQTT:
|
|||||||
self._should_reconnect: bool = True
|
self._should_reconnect: bool = True
|
||||||
self._available_future: asyncio.Future[bool] | None = None
|
self._available_future: asyncio.Future[bool] | None = None
|
||||||
|
|
||||||
self._max_qos: dict[str, int] = {} # topic, max qos
|
self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
|
||||||
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
||||||
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
||||||
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
|
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
|
||||||
@@ -789,9 +803,9 @@ class MQTT:
|
|||||||
The caller is responsible clearing the cache of _matching_subscriptions.
|
The caller is responsible clearing the cache of _matching_subscriptions.
|
||||||
"""
|
"""
|
||||||
if subscription.is_simple_match:
|
if subscription.is_simple_match:
|
||||||
self._simple_subscriptions[subscription.topic].append(subscription)
|
self._simple_subscriptions[subscription.topic].add(subscription)
|
||||||
else:
|
else:
|
||||||
self._wildcard_subscriptions.append(subscription)
|
self._wildcard_subscriptions.add(subscription)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_untrack_subscription(self, subscription: Subscription) -> None:
|
def _async_untrack_subscription(self, subscription: Subscription) -> None:
|
||||||
@@ -820,8 +834,8 @@ class MQTT:
|
|||||||
"""Queue requested subscriptions."""
|
"""Queue requested subscriptions."""
|
||||||
for subscription in subscriptions:
|
for subscription in subscriptions:
|
||||||
topic, qos = subscription
|
topic, qos = subscription
|
||||||
max_qos = max(qos, self._max_qos.setdefault(topic, qos))
|
if (max_qos := self._max_qos[topic]) < qos:
|
||||||
self._max_qos[topic] = max_qos
|
self._max_qos[topic] = (max_qos := qos)
|
||||||
self._pending_subscriptions[topic] = max_qos
|
self._pending_subscriptions[topic] = max_qos
|
||||||
# Cancel any pending unsubscribe since we are subscribing now
|
# Cancel any pending unsubscribe since we are subscribing now
|
||||||
if topic in self._pending_unsubscribes:
|
if topic in self._pending_unsubscribes:
|
||||||
@@ -832,26 +846,29 @@ class MQTT:
|
|||||||
|
|
||||||
def _exception_message(
|
def _exception_message(
|
||||||
self,
|
self,
|
||||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
msg: ReceiveMessage,
|
msg: ReceiveMessage,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return a string with the exception message."""
|
"""Return a string with the exception message."""
|
||||||
|
# if msg_callback is a partial we return the name of the first argument
|
||||||
|
if isinstance(msg_callback, partial):
|
||||||
|
call_back_name = getattr(msg_callback.args[0], "__name__") # type: ignore[unreachable]
|
||||||
|
else:
|
||||||
|
call_back_name = getattr(msg_callback, "__name__")
|
||||||
return (
|
return (
|
||||||
f"Exception in {msg_callback.__name__} when handling msg on "
|
f"Exception in {call_back_name} when handling msg on "
|
||||||
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
|
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_subscribe(
|
@callback
|
||||||
|
def async_subscribe(
|
||||||
self,
|
self,
|
||||||
topic: str,
|
topic: str,
|
||||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||||
qos: int,
|
qos: int,
|
||||||
encoding: str | None = None,
|
encoding: str | None = None,
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Set up a subscription to a topic with the provided qos.
|
"""Set up a subscription to a topic with the provided qos."""
|
||||||
|
|
||||||
This method is a coroutine.
|
|
||||||
"""
|
|
||||||
if not isinstance(topic, str):
|
if not isinstance(topic, str):
|
||||||
raise HomeAssistantError("Topic needs to be a string!")
|
raise HomeAssistantError("Topic needs to be a string!")
|
||||||
|
|
||||||
@@ -877,8 +894,10 @@ class MQTT:
|
|||||||
if self.connected:
|
if self.connected:
|
||||||
self._async_queue_subscriptions(((topic, qos),))
|
self._async_queue_subscriptions(((topic, qos),))
|
||||||
|
|
||||||
|
return partial(self._async_remove, subscription)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_remove() -> None:
|
def _async_remove(self, subscription: Subscription) -> None:
|
||||||
"""Remove subscription."""
|
"""Remove subscription."""
|
||||||
self._async_untrack_subscription(subscription)
|
self._async_untrack_subscription(subscription)
|
||||||
self._matching_subscriptions.cache_clear()
|
self._matching_subscriptions.cache_clear()
|
||||||
@@ -886,9 +905,7 @@ class MQTT:
|
|||||||
del self._retained_topics[subscription]
|
del self._retained_topics[subscription]
|
||||||
# Only unsubscribe if currently connected
|
# Only unsubscribe if currently connected
|
||||||
if self.connected:
|
if self.connected:
|
||||||
self._async_unsubscribe(topic)
|
self._async_unsubscribe(subscription.topic)
|
||||||
|
|
||||||
return async_remove
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_unsubscribe(self, topic: str) -> None:
|
def _async_unsubscribe(self, topic: str) -> None:
|
||||||
@@ -1257,9 +1274,7 @@ class MQTT:
|
|||||||
|
|
||||||
last_discovery = self._mqtt_data.last_discovery
|
last_discovery = self._mqtt_data.last_discovery
|
||||||
last_subscribe = now if self._pending_subscriptions else self._last_subscribe
|
last_subscribe = now if self._pending_subscriptions else self._last_subscribe
|
||||||
wait_until = max(
|
wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
|
||||||
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
|
||||||
)
|
|
||||||
while now < wait_until:
|
while now < wait_until:
|
||||||
await asyncio.sleep(wait_until - now)
|
await asyncio.sleep(wait_until - now)
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
@@ -1267,9 +1282,7 @@ class MQTT:
|
|||||||
last_subscribe = (
|
last_subscribe = (
|
||||||
now if self._pending_subscriptions else self._last_subscribe
|
now if self._pending_subscriptions else self._last_subscribe
|
||||||
)
|
)
|
||||||
wait_until = max(
|
wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
|
||||||
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:
|
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:
|
||||||
|
@@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
|
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
|
||||||
if self._topic[topic] is not None:
|
if self._topic[topic] is not None:
|
||||||
|
@@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_open_cover(self, **kwargs: Any) -> None:
|
async def async_open_cover(self, **kwargs: Any) -> None:
|
||||||
"""Move the cover up.
|
"""Move the cover up.
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -32,13 +33,7 @@ from homeassistant.helpers.typing import ConfigType
|
|||||||
from . import subscription
|
from . import subscription
|
||||||
from .config import MQTT_BASE_SCHEMA
|
from .config import MQTT_BASE_SCHEMA
|
||||||
from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC
|
from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC
|
||||||
from .debug_info import log_messages
|
from .mixins import CONF_JSON_ATTRS_TOPIC, MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
CONF_JSON_ATTRS_TOPIC,
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
|
from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
|
||||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
from .util import valid_subscribe_topic
|
from .util import valid_subscribe_topic
|
||||||
@@ -119,13 +114,8 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
|||||||
config.get(CONF_VALUE_TEMPLATE), entity=self
|
config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _tracker_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_location_name"})
|
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
if not payload.strip(): # No output from template, ignore
|
if not payload.strip(): # No output from template, ignore
|
||||||
@@ -146,6 +136,9 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
|||||||
assert isinstance(msg.payload, str)
|
assert isinstance(msg.payload, str)
|
||||||
self._location_name = msg.payload
|
self._location_name = msg.payload
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
state_topic: str | None = self._config.get(CONF_STATE_TOPIC)
|
state_topic: str | None = self._config.get(CONF_STATE_TOPIC)
|
||||||
if state_topic is None:
|
if state_topic is None:
|
||||||
return
|
return
|
||||||
@@ -155,7 +148,12 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
|||||||
{
|
{
|
||||||
"state_topic": {
|
"state_topic": {
|
||||||
"topic": state_topic,
|
"topic": state_topic,
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._tracker_message_received,
|
||||||
|
{"_location_name"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -168,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def latitude(self) -> float | None:
|
def latitude(self) -> float | None:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -31,7 +32,6 @@ from .const import (
|
|||||||
PAYLOAD_EMPTY_JSON,
|
PAYLOAD_EMPTY_JSON,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
|
||||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .models import (
|
from .models import (
|
||||||
DATA_MQTT,
|
DATA_MQTT,
|
||||||
@@ -113,13 +113,8 @@ class MqttEvent(MqttEntity, EventEntity):
|
|||||||
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _event_received(self, msg: ReceiveMessage) -> None:
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
if msg.retain:
|
if msg.retain:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
@@ -161,10 +156,7 @@ class MqttEvent(MqttEntity, EventEntity):
|
|||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
(
|
("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
|
||||||
"`event_type` missing in JSON event payload, "
|
|
||||||
" '%s' on topic %s"
|
|
||||||
),
|
|
||||||
payload,
|
payload,
|
||||||
msg.topic,
|
msg.topic,
|
||||||
)
|
)
|
||||||
@@ -194,9 +186,18 @@ class MqttEvent(MqttEntity, EventEntity):
|
|||||||
mqtt_data = self.hass.data[DATA_MQTT]
|
mqtt_data = self.hass.data[DATA_MQTT]
|
||||||
mqtt_data.state_write_requests.write_state_request(self)
|
mqtt_data.state_write_requests.write_state_request(self)
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
topics["state_topic"] = {
|
topics["state_topic"] = {
|
||||||
"topic": self._config[CONF_STATE_TOPIC],
|
"topic": self._config[CONF_STATE_TOPIC],
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._event_received,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -207,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -49,12 +50,7 @@ from .const import (
|
|||||||
CONF_STATE_VALUE_TEMPLATE,
|
CONF_STATE_VALUE_TEMPLATE,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MessageCallbackType,
|
MessageCallbackType,
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
@@ -338,25 +334,8 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
for key, tpl in value_templates.items()
|
for key, tpl in value_templates.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
|
|
||||||
"""Add a topic to subscribe to."""
|
|
||||||
if has_topic := self._topic[topic] is not None:
|
|
||||||
topics[topic] = {
|
|
||||||
"topic": self._topic[topic],
|
|
||||||
"msg_callback": msg_callback,
|
|
||||||
"qos": self._config[CONF_QOS],
|
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
|
||||||
}
|
|
||||||
return has_topic
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
|
||||||
def state_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message."""
|
"""Handle new received MQTT message."""
|
||||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||||
if not payload:
|
if not payload:
|
||||||
@@ -369,12 +348,8 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
elif payload == PAYLOAD_NONE:
|
elif payload == PAYLOAD_NONE:
|
||||||
self._attr_is_on = None
|
self._attr_is_on = None
|
||||||
|
|
||||||
add_subscribe_topic(CONF_STATE_TOPIC, state_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _percentage_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_percentage"})
|
|
||||||
def percentage_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for the percentage."""
|
"""Handle new received MQTT message for the percentage."""
|
||||||
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
|
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
|
||||||
msg.payload
|
msg.payload
|
||||||
@@ -413,12 +388,8 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
return
|
return
|
||||||
self._attr_percentage = percentage
|
self._attr_percentage = percentage
|
||||||
|
|
||||||
add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _preset_mode_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_preset_mode"})
|
|
||||||
def preset_mode_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for preset mode."""
|
"""Handle new received MQTT message for preset mode."""
|
||||||
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
|
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
|
||||||
if preset_mode == self._payload["PRESET_MODE_RESET"]:
|
if preset_mode == self._payload["PRESET_MODE_RESET"]:
|
||||||
@@ -438,12 +409,8 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
|
|
||||||
self._attr_preset_mode = preset_mode
|
self._attr_preset_mode = preset_mode
|
||||||
|
|
||||||
add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _oscillation_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_oscillating"})
|
|
||||||
def oscillation_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for the oscillation."""
|
"""Handle new received MQTT message for the oscillation."""
|
||||||
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
|
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
|
||||||
if not payload:
|
if not payload:
|
||||||
@@ -454,13 +421,8 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
|
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
|
||||||
self._attr_oscillating = False
|
self._attr_oscillating = False
|
||||||
|
|
||||||
if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received):
|
|
||||||
self._attr_oscillating = False
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _direction_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_current_direction"})
|
|
||||||
def direction_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for the direction."""
|
"""Handle new received MQTT message for the direction."""
|
||||||
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
|
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
|
||||||
if not direction:
|
if not direction:
|
||||||
@@ -468,7 +430,46 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
return
|
return
|
||||||
self._attr_current_direction = str(direction)
|
self._attr_current_direction = str(direction)
|
||||||
|
|
||||||
add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received)
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def add_subscribe_topic(
|
||||||
|
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
|
||||||
|
) -> bool:
|
||||||
|
"""Add a topic to subscribe to."""
|
||||||
|
if has_topic := self._topic[topic] is not None:
|
||||||
|
topics[topic] = {
|
||||||
|
"topic": self._topic[topic],
|
||||||
|
"msg_callback": partial(
|
||||||
|
self._message_callback, msg_callback, tracked_attributes
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
|
"qos": self._config[CONF_QOS],
|
||||||
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
|
}
|
||||||
|
return has_topic
|
||||||
|
|
||||||
|
add_subscribe_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
|
||||||
|
add_subscribe_topic(
|
||||||
|
CONF_PERCENTAGE_STATE_TOPIC, self._percentage_received, {"_attr_percentage"}
|
||||||
|
)
|
||||||
|
add_subscribe_topic(
|
||||||
|
CONF_PRESET_MODE_STATE_TOPIC,
|
||||||
|
self._preset_mode_received,
|
||||||
|
{"_attr_preset_mode"},
|
||||||
|
)
|
||||||
|
if add_subscribe_topic(
|
||||||
|
CONF_OSCILLATION_STATE_TOPIC,
|
||||||
|
self._oscillation_received,
|
||||||
|
{"_attr_oscillating"},
|
||||||
|
):
|
||||||
|
self._attr_oscillating = False
|
||||||
|
add_subscribe_topic(
|
||||||
|
CONF_DIRECTION_STATE_TOPIC,
|
||||||
|
self._direction_received,
|
||||||
|
{"_attr_current_direction"},
|
||||||
|
)
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass, self._sub_state, topics
|
self.hass, self._sub_state, topics
|
||||||
@@ -476,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_on(self) -> bool | None:
|
def is_on(self) -> bool | None:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -51,12 +52,7 @@ from .const import (
|
|||||||
CONF_STATE_VALUE_TEMPLATE,
|
CONF_STATE_VALUE_TEMPLATE,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -284,25 +280,23 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
topics: dict[str, dict[str, Any]],
|
topics: dict[str, dict[str, Any]],
|
||||||
topic: str,
|
topic: str,
|
||||||
msg_callback: Callable[[ReceiveMessage], None],
|
msg_callback: Callable[[ReceiveMessage], None],
|
||||||
|
tracked_attributes: set[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a subscription."""
|
"""Add a subscription."""
|
||||||
qos: int = self._config[CONF_QOS]
|
qos: int = self._config[CONF_QOS]
|
||||||
if topic in self._topic and self._topic[topic] is not None:
|
if topic in self._topic and self._topic[topic] is not None:
|
||||||
topics[topic] = {
|
topics[topic] = {
|
||||||
"topic": self._topic[topic],
|
"topic": self._topic[topic],
|
||||||
"msg_callback": msg_callback,
|
"msg_callback": partial(
|
||||||
|
self._message_callback, msg_callback, tracked_attributes
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": qos,
|
"qos": qos,
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, Any] = {}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
|
||||||
def state_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message."""
|
"""Handle new received MQTT message."""
|
||||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||||
if not payload:
|
if not payload:
|
||||||
@@ -315,12 +309,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
elif payload == PAYLOAD_NONE:
|
elif payload == PAYLOAD_NONE:
|
||||||
self._attr_is_on = None
|
self._attr_is_on = None
|
||||||
|
|
||||||
self.add_subscription(topics, CONF_STATE_TOPIC, state_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _action_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_action"})
|
|
||||||
def action_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message."""
|
"""Handle new received MQTT message."""
|
||||||
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
|
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
|
||||||
if not action_payload or action_payload == PAYLOAD_NONE:
|
if not action_payload or action_payload == PAYLOAD_NONE:
|
||||||
@@ -337,12 +327,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.add_subscription(topics, CONF_ACTION_TOPIC, action_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _current_humidity_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_current_humidity"})
|
|
||||||
def current_humidity_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for the current humidity."""
|
"""Handle new received MQTT message for the current humidity."""
|
||||||
rendered_current_humidity_payload = self._value_templates[
|
rendered_current_humidity_payload = self._value_templates[
|
||||||
ATTR_CURRENT_HUMIDITY
|
ATTR_CURRENT_HUMIDITY
|
||||||
@@ -373,14 +359,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
return
|
return
|
||||||
self._attr_current_humidity = current_humidity
|
self._attr_current_humidity = current_humidity
|
||||||
|
|
||||||
self.add_subscription(
|
|
||||||
topics, CONF_CURRENT_HUMIDITY_TOPIC, current_humidity_received
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _target_humidity_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_target_humidity"})
|
|
||||||
def target_humidity_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for the target humidity."""
|
"""Handle new received MQTT message for the target humidity."""
|
||||||
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
|
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
|
||||||
msg.payload
|
msg.payload
|
||||||
@@ -414,14 +394,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
return
|
return
|
||||||
self._attr_target_humidity = target_humidity
|
self._attr_target_humidity = target_humidity
|
||||||
|
|
||||||
self.add_subscription(
|
|
||||||
topics, CONF_TARGET_HUMIDITY_STATE_TOPIC, target_humidity_received
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _mode_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_mode"})
|
|
||||||
def mode_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new received MQTT message for mode."""
|
"""Handle new received MQTT message for mode."""
|
||||||
mode = str(self._value_templates[ATTR_MODE](msg.payload))
|
mode = str(self._value_templates[ATTR_MODE](msg.payload))
|
||||||
if mode == self._payload["MODE_RESET"]:
|
if mode == self._payload["MODE_RESET"]:
|
||||||
@@ -441,7 +415,31 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
|
|
||||||
self._attr_mode = mode
|
self._attr_mode = mode
|
||||||
|
|
||||||
self.add_subscription(topics, CONF_MODE_STATE_TOPIC, mode_received)
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, Any] = {}
|
||||||
|
|
||||||
|
self.add_subscription(
|
||||||
|
topics, CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}
|
||||||
|
)
|
||||||
|
self.add_subscription(
|
||||||
|
topics, CONF_ACTION_TOPIC, self._action_received, {"_attr_action"}
|
||||||
|
)
|
||||||
|
self.add_subscription(
|
||||||
|
topics,
|
||||||
|
CONF_CURRENT_HUMIDITY_TOPIC,
|
||||||
|
self._current_humidity_received,
|
||||||
|
{"_attr_current_humidity"},
|
||||||
|
)
|
||||||
|
self.add_subscription(
|
||||||
|
topics,
|
||||||
|
CONF_TARGET_HUMIDITY_STATE_TOPIC,
|
||||||
|
self._target_humidity_received,
|
||||||
|
{"_attr_target_humidity"},
|
||||||
|
)
|
||||||
|
self.add_subscription(
|
||||||
|
topics, CONF_MODE_STATE_TOPIC, self._mode_received, {"_attr_mode"}
|
||||||
|
)
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass, self._sub_state, topics
|
self.hass, self._sub_state, topics
|
||||||
@@ -449,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||||
"""Turn on the entity.
|
"""Turn on the entity.
|
||||||
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
import binascii
|
import binascii
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
@@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
|
|||||||
from . import subscription
|
from . import subscription
|
||||||
from .config import MQTT_BASE_SCHEMA
|
from .config import MQTT_BASE_SCHEMA
|
||||||
from .const import CONF_ENCODING, CONF_QOS
|
from .const import CONF_ENCODING, CONF_QOS
|
||||||
from .debug_info import log_messages
|
|
||||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .models import (
|
from .models import (
|
||||||
DATA_MQTT,
|
DATA_MQTT,
|
||||||
@@ -143,31 +143,8 @@ class MqttImage(MqttEntity, ImageEntity):
|
|||||||
config.get(CONF_URL_TEMPLATE), entity=self
|
config.get(CONF_URL_TEMPLATE), entity=self
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
topics: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
|
|
||||||
"""Add a topic to subscribe to."""
|
|
||||||
encoding: str | None
|
|
||||||
encoding = (
|
|
||||||
None
|
|
||||||
if CONF_IMAGE_TOPIC in self._config
|
|
||||||
else self._config[CONF_ENCODING] or None
|
|
||||||
)
|
|
||||||
if has_topic := self._topic[topic] is not None:
|
|
||||||
topics[topic] = {
|
|
||||||
"topic": self._topic[topic],
|
|
||||||
"msg_callback": msg_callback,
|
|
||||||
"qos": self._config[CONF_QOS],
|
|
||||||
"encoding": encoding,
|
|
||||||
}
|
|
||||||
return has_topic
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _image_data_received(self, msg: ReceiveMessage) -> None:
|
||||||
def image_data_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
try:
|
try:
|
||||||
if CONF_IMAGE_ENCODING in self._config:
|
if CONF_IMAGE_ENCODING in self._config:
|
||||||
@@ -186,11 +163,8 @@ class MqttImage(MqttEntity, ImageEntity):
|
|||||||
self._attr_image_last_updated = dt_util.utcnow()
|
self._attr_image_last_updated = dt_util.utcnow()
|
||||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||||
|
|
||||||
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
|
||||||
def image_from_url_request_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
try:
|
try:
|
||||||
url = cv.url(self._url_template(msg.payload))
|
url = cv.url(self._url_template(msg.payload))
|
||||||
@@ -208,7 +182,31 @@ class MqttImage(MqttEntity, ImageEntity):
|
|||||||
self._cached_image = None
|
self._cached_image = None
|
||||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||||
|
|
||||||
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
|
topics: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
|
||||||
|
"""Add a topic to subscribe to."""
|
||||||
|
encoding: str | None
|
||||||
|
encoding = (
|
||||||
|
None
|
||||||
|
if CONF_IMAGE_TOPIC in self._config
|
||||||
|
else self._config[CONF_ENCODING] or None
|
||||||
|
)
|
||||||
|
if has_topic := self._topic[topic] is not None:
|
||||||
|
topics[topic] = {
|
||||||
|
"topic": self._topic[topic],
|
||||||
|
"msg_callback": partial(self._message_callback, msg_callback, None),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
|
"qos": self._config[CONF_QOS],
|
||||||
|
"encoding": encoding,
|
||||||
|
}
|
||||||
|
return has_topic
|
||||||
|
|
||||||
|
add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
|
||||||
|
add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass, self._sub_state, topics
|
self.hass, self._sub_state, topics
|
||||||
@@ -216,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_image(self) -> bytes | None:
|
async def async_image(self) -> bytes | None:
|
||||||
"""Return bytes of image."""
|
"""Return bytes of image."""
|
||||||
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@@ -31,12 +32,7 @@ from .const import (
|
|||||||
DEFAULT_OPTIMISTIC,
|
DEFAULT_OPTIMISTIC,
|
||||||
DEFAULT_RETAIN,
|
DEFAULT_RETAIN,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -150,13 +146,8 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
|
|||||||
config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self
|
config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self
|
||||||
).async_render
|
).async_render
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_activity"})
|
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
payload = str(self._value_template(msg.payload))
|
payload = str(self._value_template(msg.payload))
|
||||||
if not payload:
|
if not payload:
|
||||||
@@ -181,17 +172,24 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None:
|
if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None:
|
||||||
# Force into optimistic mode.
|
# Force into optimistic mode.
|
||||||
self._attr_assumed_state = True
|
self._attr_assumed_state = True
|
||||||
else:
|
return
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
CONF_ACTIVITY_STATE_TOPIC: {
|
CONF_ACTIVITY_STATE_TOPIC: {
|
||||||
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
|
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._message_received,
|
||||||
|
{"_attr_activity"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -200,7 +198,7 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._attr_assumed_state and (
|
if self._attr_assumed_state and (
|
||||||
last_state := await self.async_get_last_state()
|
last_state := await self.async_get_last_state()
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@@ -53,8 +54,7 @@ from ..const import (
|
|||||||
CONF_STATE_VALUE_TEMPLATE,
|
CONF_STATE_VALUE_TEMPLATE,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from ..debug_info import log_messages
|
from ..mixins import MqttEntity
|
||||||
from ..mixins import MqttEntity, write_state_on_attr_change
|
|
||||||
from ..models import (
|
from ..models import (
|
||||||
MessageCallbackType,
|
MessageCallbackType,
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
@@ -378,24 +378,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
attr: bool = getattr(self, f"_optimistic_{attribute}")
|
attr: bool = getattr(self, f"_optimistic_{attribute}")
|
||||||
return attr
|
return attr
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None: # noqa: C901
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
def add_topic(topic: str, msg_callback: MessageCallbackType) -> None:
|
|
||||||
"""Add a topic."""
|
|
||||||
if self._topic[topic] is not None:
|
|
||||||
topics[topic] = {
|
|
||||||
"topic": self._topic[topic],
|
|
||||||
"msg_callback": msg_callback,
|
|
||||||
"qos": self._config[CONF_QOS],
|
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
|
||||||
}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
|
||||||
def state_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.NONE
|
msg.payload, PayloadSentinel.NONE
|
||||||
@@ -411,18 +395,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
elif payload == PAYLOAD_NONE:
|
elif payload == PAYLOAD_NONE:
|
||||||
self._attr_is_on = None
|
self._attr_is_on = None
|
||||||
|
|
||||||
if self._topic[CONF_STATE_TOPIC] is not None:
|
|
||||||
topics[CONF_STATE_TOPIC] = {
|
|
||||||
"topic": self._topic[CONF_STATE_TOPIC],
|
|
||||||
"msg_callback": state_received,
|
|
||||||
"qos": self._config[CONF_QOS],
|
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
|
||||||
}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _brightness_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_brightness"})
|
|
||||||
def brightness_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for the brightness."""
|
"""Handle new MQTT messages for the brightness."""
|
||||||
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
msg.payload, PayloadSentinel.DEFAULT
|
||||||
@@ -439,23 +413,18 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
|
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
|
||||||
self._attr_brightness = min(round(percent_bright * 255), 255)
|
self._attr_brightness = min(round(percent_bright * 255), 255)
|
||||||
|
|
||||||
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _rgbx_received(
|
def _rgbx_received(
|
||||||
|
self,
|
||||||
msg: ReceiveMessage,
|
msg: ReceiveMessage,
|
||||||
template: str,
|
template: str,
|
||||||
color_mode: ColorMode,
|
color_mode: ColorMode,
|
||||||
convert_color: Callable[..., tuple[int, ...]],
|
convert_color: Callable[..., tuple[int, ...]],
|
||||||
) -> tuple[int, ...] | None:
|
) -> tuple[int, ...] | None:
|
||||||
"""Handle new MQTT messages for RGBW and RGBWW."""
|
"""Process MQTT messages for RGBW and RGBWW."""
|
||||||
payload = self._value_templates[template](
|
payload = self._value_templates[template](msg.payload, PayloadSentinel.DEFAULT)
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
|
||||||
)
|
|
||||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug("Ignoring empty %s message from '%s'", color_mode, msg.topic)
|
||||||
"Ignoring empty %s message from '%s'", color_mode, msg.topic
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
color = tuple(int(val) for val in str(payload).split(","))
|
color = tuple(int(val) for val in str(payload).split(","))
|
||||||
if self._optimistic_color_mode:
|
if self._optimistic_color_mode:
|
||||||
@@ -478,29 +447,19 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
return color
|
return color
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _rgb_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}
|
|
||||||
)
|
|
||||||
def rgb_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for RGB."""
|
"""Handle new MQTT messages for RGB."""
|
||||||
rgb = _rgbx_received(
|
rgb = self._rgbx_received(
|
||||||
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
|
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
|
||||||
)
|
)
|
||||||
if rgb is None:
|
if rgb is None:
|
||||||
return
|
return
|
||||||
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
|
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
|
||||||
|
|
||||||
add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _rgbw_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}
|
|
||||||
)
|
|
||||||
def rgbw_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for RGBW."""
|
"""Handle new MQTT messages for RGBW."""
|
||||||
rgbw = _rgbx_received(
|
rgbw = self._rgbx_received(
|
||||||
msg,
|
msg,
|
||||||
CONF_RGBW_VALUE_TEMPLATE,
|
CONF_RGBW_VALUE_TEMPLATE,
|
||||||
ColorMode.RGBW,
|
ColorMode.RGBW,
|
||||||
@@ -510,31 +469,21 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
return
|
return
|
||||||
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
|
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
|
||||||
|
|
||||||
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _rgbww_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"}
|
|
||||||
)
|
|
||||||
def rgbww_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for RGBWW."""
|
"""Handle new MQTT messages for RGBWW."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _converter(
|
def _converter(
|
||||||
r: int, g: int, b: int, cw: int, ww: int
|
r: int, g: int, b: int, cw: int, ww: int
|
||||||
) -> tuple[int, int, int]:
|
) -> tuple[int, int, int]:
|
||||||
min_kelvin = color_util.color_temperature_mired_to_kelvin(
|
min_kelvin = color_util.color_temperature_mired_to_kelvin(self.max_mireds)
|
||||||
self.max_mireds
|
max_kelvin = color_util.color_temperature_mired_to_kelvin(self.min_mireds)
|
||||||
)
|
|
||||||
max_kelvin = color_util.color_temperature_mired_to_kelvin(
|
|
||||||
self.min_mireds
|
|
||||||
)
|
|
||||||
return color_util.color_rgbww_to_rgb(
|
return color_util.color_rgbww_to_rgb(
|
||||||
r, g, b, cw, ww, min_kelvin, max_kelvin
|
r, g, b, cw, ww, min_kelvin, max_kelvin
|
||||||
)
|
)
|
||||||
|
|
||||||
rgbww = _rgbx_received(
|
rgbww = self._rgbx_received(
|
||||||
msg,
|
msg,
|
||||||
CONF_RGBWW_VALUE_TEMPLATE,
|
CONF_RGBWW_VALUE_TEMPLATE,
|
||||||
ColorMode.RGBWW,
|
ColorMode.RGBWW,
|
||||||
@@ -544,12 +493,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
return
|
return
|
||||||
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
|
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
|
||||||
|
|
||||||
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _color_mode_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_color_mode"})
|
|
||||||
def color_mode_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for color mode."""
|
"""Handle new MQTT messages for color mode."""
|
||||||
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
msg.payload, PayloadSentinel.DEFAULT
|
||||||
@@ -560,12 +505,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
|
|
||||||
self._attr_color_mode = ColorMode(str(payload))
|
self._attr_color_mode = ColorMode(str(payload))
|
||||||
|
|
||||||
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _color_temp_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"})
|
|
||||||
def color_temp_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for color temperature."""
|
"""Handle new MQTT messages for color temperature."""
|
||||||
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
msg.payload, PayloadSentinel.DEFAULT
|
||||||
@@ -578,12 +519,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
self._attr_color_mode = ColorMode.COLOR_TEMP
|
self._attr_color_mode = ColorMode.COLOR_TEMP
|
||||||
self._attr_color_temp = int(payload)
|
self._attr_color_temp = int(payload)
|
||||||
|
|
||||||
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _effect_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_effect"})
|
|
||||||
def effect_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for effect."""
|
"""Handle new MQTT messages for effect."""
|
||||||
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
msg.payload, PayloadSentinel.DEFAULT
|
||||||
@@ -594,12 +531,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
|
|
||||||
self._attr_effect = str(payload)
|
self._attr_effect = str(payload)
|
||||||
|
|
||||||
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _hs_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"})
|
|
||||||
def hs_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for hs color."""
|
"""Handle new MQTT messages for hs color."""
|
||||||
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
msg.payload, PayloadSentinel.DEFAULT
|
||||||
@@ -615,12 +548,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
|
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
|
||||||
|
|
||||||
add_topic(CONF_HS_STATE_TOPIC, hs_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _xy_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"})
|
|
||||||
def xy_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages for xy color."""
|
"""Handle new MQTT messages for xy color."""
|
||||||
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
|
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
|
||||||
msg.payload, PayloadSentinel.DEFAULT
|
msg.payload, PayloadSentinel.DEFAULT
|
||||||
@@ -634,7 +563,63 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
self._attr_color_mode = ColorMode.XY
|
self._attr_color_mode = ColorMode.XY
|
||||||
self._attr_xy_color = cast(tuple[float, float], xy_color)
|
self._attr_xy_color = cast(tuple[float, float], xy_color)
|
||||||
|
|
||||||
add_topic(CONF_XY_STATE_TOPIC, xy_received)
|
def _prepare_subscribe_topics(self) -> None: # noqa: C901
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
def add_topic(
|
||||||
|
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
|
||||||
|
) -> None:
|
||||||
|
"""Add a topic."""
|
||||||
|
if self._topic[topic] is not None:
|
||||||
|
topics[topic] = {
|
||||||
|
"topic": self._topic[topic],
|
||||||
|
"msg_callback": partial(
|
||||||
|
self._message_callback, msg_callback, tracked_attributes
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
|
"qos": self._config[CONF_QOS],
|
||||||
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
|
}
|
||||||
|
|
||||||
|
add_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
|
||||||
|
add_topic(
|
||||||
|
CONF_BRIGHTNESS_STATE_TOPIC, self._brightness_received, {"_attr_brightness"}
|
||||||
|
)
|
||||||
|
add_topic(
|
||||||
|
CONF_RGB_STATE_TOPIC,
|
||||||
|
self._rgb_received,
|
||||||
|
{"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"},
|
||||||
|
)
|
||||||
|
add_topic(
|
||||||
|
CONF_RGBW_STATE_TOPIC,
|
||||||
|
self._rgbw_received,
|
||||||
|
{"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"},
|
||||||
|
)
|
||||||
|
add_topic(
|
||||||
|
CONF_RGBWW_STATE_TOPIC,
|
||||||
|
self._rgbww_received,
|
||||||
|
{"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"},
|
||||||
|
)
|
||||||
|
add_topic(
|
||||||
|
CONF_COLOR_MODE_STATE_TOPIC, self._color_mode_received, {"_attr_color_mode"}
|
||||||
|
)
|
||||||
|
add_topic(
|
||||||
|
CONF_COLOR_TEMP_STATE_TOPIC,
|
||||||
|
self._color_temp_received,
|
||||||
|
{"_attr_color_mode", "_attr_color_temp"},
|
||||||
|
)
|
||||||
|
add_topic(CONF_EFFECT_STATE_TOPIC, self._effect_received, {"_attr_effect"})
|
||||||
|
add_topic(
|
||||||
|
CONF_HS_STATE_TOPIC,
|
||||||
|
self._hs_received,
|
||||||
|
{"_attr_color_mode", "_attr_hs_color"},
|
||||||
|
)
|
||||||
|
add_topic(
|
||||||
|
CONF_XY_STATE_TOPIC,
|
||||||
|
self._xy_received,
|
||||||
|
{"_attr_color_mode", "_attr_xy_color"},
|
||||||
|
)
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass, self._sub_state, topics
|
self.hass, self._sub_state, topics
|
||||||
@@ -642,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
last_state = await self.async_get_last_state()
|
last_state = await self.async_get_last_state()
|
||||||
|
|
||||||
def restore_state(
|
def restore_state(
|
||||||
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
@@ -66,8 +67,7 @@ from ..const import (
|
|||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
DOMAIN as MQTT_DOMAIN,
|
DOMAIN as MQTT_DOMAIN,
|
||||||
)
|
)
|
||||||
from ..debug_info import log_messages
|
from ..mixins import MqttEntity
|
||||||
from ..mixins import MqttEntity, write_state_on_attr_change
|
|
||||||
from ..models import ReceiveMessage
|
from ..models import ReceiveMessage
|
||||||
from ..schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from ..schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
from ..util import valid_subscribe_topic
|
from ..util import valid_subscribe_topic
|
||||||
@@ -414,27 +414,8 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
self.entity_id,
|
self.entity_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self,
|
|
||||||
{
|
|
||||||
"_attr_brightness",
|
|
||||||
"_attr_color_temp",
|
|
||||||
"_attr_effect",
|
|
||||||
"_attr_hs_color",
|
|
||||||
"_attr_is_on",
|
|
||||||
"_attr_rgb_color",
|
|
||||||
"_attr_rgbw_color",
|
|
||||||
"_attr_rgbww_color",
|
|
||||||
"_attr_xy_color",
|
|
||||||
"color_mode",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def state_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
values = json_loads_object(msg.payload)
|
values = json_loads_object(msg.payload)
|
||||||
|
|
||||||
@@ -509,14 +490,36 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
with suppress(KeyError):
|
with suppress(KeyError):
|
||||||
self._attr_effect = cast(str, values["effect"])
|
self._attr_effect = cast(str, values["effect"])
|
||||||
|
|
||||||
if self._topic[CONF_STATE_TOPIC] is not None:
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
|
#
|
||||||
|
if self._topic[CONF_STATE_TOPIC] is None:
|
||||||
|
return
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
"state_topic": {
|
CONF_STATE_TOPIC: {
|
||||||
"topic": self._topic[CONF_STATE_TOPIC],
|
"topic": self._topic[CONF_STATE_TOPIC],
|
||||||
"msg_callback": state_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._state_received,
|
||||||
|
{
|
||||||
|
"_attr_brightness",
|
||||||
|
"_attr_color_temp",
|
||||||
|
"_attr_effect",
|
||||||
|
"_attr_hs_color",
|
||||||
|
"_attr_is_on",
|
||||||
|
"_attr_rgb_color",
|
||||||
|
"_attr_rgbw_color",
|
||||||
|
"_attr_rgbww_color",
|
||||||
|
"_attr_xy_color",
|
||||||
|
"color_mode",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -525,7 +528,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
last_state = await self.async_get_last_state()
|
last_state = await self.async_get_last_state()
|
||||||
if self._optimistic and last_state:
|
if self._optimistic and last_state:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -44,8 +45,7 @@ from ..const import (
|
|||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from ..debug_info import log_messages
|
from ..mixins import MqttEntity
|
||||||
from ..mixins import MqttEntity, write_state_on_attr_change
|
|
||||||
from ..models import (
|
from ..models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -188,23 +188,8 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
# Support for ct + hs, prioritize hs
|
# Support for ct + hs, prioritize hs
|
||||||
self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP
|
self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self,
|
|
||||||
{
|
|
||||||
"_attr_brightness",
|
|
||||||
"_attr_color_mode",
|
|
||||||
"_attr_color_temp",
|
|
||||||
"_attr_effect",
|
|
||||||
"_attr_hs_color",
|
|
||||||
"_attr_is_on",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def state_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
|
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
|
||||||
if state == STATE_ON:
|
if state == STATE_ON:
|
||||||
@@ -229,9 +214,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning("Invalid brightness value received from %s", msg.topic)
|
||||||
"Invalid brightness value received from %s", msg.topic
|
|
||||||
)
|
|
||||||
|
|
||||||
if CONF_COLOR_TEMP_TEMPLATE in self._config:
|
if CONF_COLOR_TEMP_TEMPLATE in self._config:
|
||||||
try:
|
try:
|
||||||
@@ -272,14 +255,31 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
else:
|
else:
|
||||||
_LOGGER.warning("Unsupported effect value received")
|
_LOGGER.warning("Unsupported effect value received")
|
||||||
|
|
||||||
if self._topics[CONF_STATE_TOPIC] is not None:
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
|
if self._topics[CONF_STATE_TOPIC] is None:
|
||||||
|
return
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
"state_topic": {
|
"state_topic": {
|
||||||
"topic": self._topics[CONF_STATE_TOPIC],
|
"topic": self._topics[CONF_STATE_TOPIC],
|
||||||
"msg_callback": state_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._state_received,
|
||||||
|
{
|
||||||
|
"_attr_brightness",
|
||||||
|
"_attr_color_mode",
|
||||||
|
"_attr_color_temp",
|
||||||
|
"_attr_effect",
|
||||||
|
"_attr_hs_color",
|
||||||
|
"_attr_is_on",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -288,7 +288,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
last_state = await self.async_get_last_state()
|
last_state = await self.async_get_last_state()
|
||||||
if self._optimistic and last_state:
|
if self._optimistic and last_state:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -36,12 +37,7 @@ from .const import (
|
|||||||
CONF_STATE_OPENING,
|
CONF_STATE_OPENING,
|
||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -186,27 +182,8 @@ class MqttLock(MqttEntity, LockEntity):
|
|||||||
|
|
||||||
self._valid_states = [config[state] for state in STATE_CONFIG_KEYS]
|
self._valid_states = [config[state] for state in STATE_CONFIG_KEYS]
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
topics: dict[str, dict[str, Any]] = {}
|
|
||||||
qos: int = self._config[CONF_QOS]
|
|
||||||
encoding: str | None = self._config[CONF_ENCODING] or None
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self,
|
|
||||||
{
|
|
||||||
"_attr_is_jammed",
|
|
||||||
"_attr_is_locked",
|
|
||||||
"_attr_is_locking",
|
|
||||||
"_attr_is_open",
|
|
||||||
"_attr_is_opening",
|
|
||||||
"_attr_is_unlocking",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new lock state messages."""
|
"""Handle new lock state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
if not payload.strip(): # No output from template, ignore
|
if not payload.strip(): # No output from template, ignore
|
||||||
@@ -227,16 +204,36 @@ class MqttLock(MqttEntity, LockEntity):
|
|||||||
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
|
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
|
||||||
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
|
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, dict[str, Any]]
|
||||||
|
qos: int = self._config[CONF_QOS]
|
||||||
|
encoding: str | None = self._config[CONF_ENCODING] or None
|
||||||
|
|
||||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||||
# Force into optimistic mode.
|
# Force into optimistic mode.
|
||||||
self._optimistic = True
|
self._optimistic = True
|
||||||
else:
|
return
|
||||||
topics[CONF_STATE_TOPIC] = {
|
topics = {
|
||||||
|
CONF_STATE_TOPIC: {
|
||||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._message_received,
|
||||||
|
{
|
||||||
|
"_attr_is_jammed",
|
||||||
|
"_attr_is_locked",
|
||||||
|
"_attr_is_locking",
|
||||||
|
"_attr_is_open",
|
||||||
|
"_attr_is_opening",
|
||||||
|
"_attr_is_unlocking",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
CONF_QOS: qos,
|
CONF_QOS: qos,
|
||||||
CONF_ENCODING: encoding,
|
CONF_ENCODING: encoding,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
@@ -246,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_lock(self, **kwargs: Any) -> None:
|
async def async_lock(self, **kwargs: Any) -> None:
|
||||||
"""Lock the device.
|
"""Lock the device.
|
||||||
|
@@ -114,7 +114,7 @@ from .models import (
|
|||||||
from .subscription import (
|
from .subscription import (
|
||||||
EntitySubscription,
|
EntitySubscription,
|
||||||
async_prepare_subscribe_topics,
|
async_prepare_subscribe_topics,
|
||||||
async_subscribe_topics,
|
async_subscribe_topics_internal,
|
||||||
async_unsubscribe_topics,
|
async_unsubscribe_topics,
|
||||||
)
|
)
|
||||||
from .util import mqtt_config_entry_enabled
|
from .util import mqtt_config_entry_enabled
|
||||||
@@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
|
|||||||
"""Subscribe MQTT events."""
|
"""Subscribe MQTT events."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
self._attributes_prepare_subscribe_topics()
|
self._attributes_prepare_subscribe_topics()
|
||||||
await self._attributes_subscribe_topics()
|
self._attributes_subscribe_topics()
|
||||||
|
|
||||||
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
|
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
@@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
|
|||||||
|
|
||||||
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
|
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
await self._attributes_subscribe_topics()
|
self._attributes_subscribe_topics()
|
||||||
|
|
||||||
def _attributes_prepare_subscribe_topics(self) -> None:
|
def _attributes_prepare_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
@@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _attributes_subscribe_topics(self) -> None:
|
@callback
|
||||||
|
def _attributes_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await async_subscribe_topics(self.hass, self._attributes_sub_state)
|
async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self) -> None:
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
"""Unsubscribe when removed."""
|
"""Unsubscribe when removed."""
|
||||||
@@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
|
|||||||
"""Subscribe MQTT events."""
|
"""Subscribe MQTT events."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
self._availability_prepare_subscribe_topics()
|
self._availability_prepare_subscribe_topics()
|
||||||
await self._availability_subscribe_topics()
|
self._availability_subscribe_topics()
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
|
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
|
||||||
)
|
)
|
||||||
@@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
|
|||||||
|
|
||||||
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
|
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||||
"""Handle updated discovery message."""
|
"""Handle updated discovery message."""
|
||||||
await self._availability_subscribe_topics()
|
self._availability_subscribe_topics()
|
||||||
|
|
||||||
def _availability_setup_from_config(self, config: ConfigType) -> None:
|
def _availability_setup_from_config(self, config: ConfigType) -> None:
|
||||||
"""(Re)Setup."""
|
"""(Re)Setup."""
|
||||||
@@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
|
|||||||
self._available[topic] = False
|
self._available[topic] = False
|
||||||
self._available_latest = False
|
self._available_latest = False
|
||||||
|
|
||||||
async def _availability_subscribe_topics(self) -> None:
|
@callback
|
||||||
|
def _availability_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await async_subscribe_topics(self.hass, self._availability_sub_state)
|
async_subscribe_topics_internal(self.hass, self._availability_sub_state)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_mqtt_connect(self) -> None:
|
def async_mqtt_connect(self) -> None:
|
||||||
@@ -1254,12 +1256,14 @@ class MqttEntity(
|
|||||||
def _message_callback(
|
def _message_callback(
|
||||||
self,
|
self,
|
||||||
msg_callback: MessageCallbackType,
|
msg_callback: MessageCallbackType,
|
||||||
attributes: set[str],
|
attributes: set[str] | None,
|
||||||
msg: ReceiveMessage,
|
msg: ReceiveMessage,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process the message callback."""
|
"""Process the message callback."""
|
||||||
|
if attributes is not None:
|
||||||
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
|
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
|
||||||
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes
|
(attribute, getattr(self, attribute, UNDEFINED))
|
||||||
|
for attribute in attributes
|
||||||
)
|
)
|
||||||
mqtt_data = self.hass.data[DATA_MQTT]
|
mqtt_data = self.hass.data[DATA_MQTT]
|
||||||
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
|
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
|
||||||
@@ -1274,7 +1278,7 @@ class MqttEntity(
|
|||||||
_LOGGER.warning(exc)
|
_LOGGER.warning(exc)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._attrs_have_changed(attrs_snapshot):
|
if attributes is not None and self._attrs_have_changed(attrs_snapshot):
|
||||||
mqtt_data.state_write_requests.write_state_request(self)
|
mqtt_data.state_write_requests.write_state_request(self)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
@@ -70,7 +70,6 @@ class ReceiveMessage:
|
|||||||
timestamp: float
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
type AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
|
|
||||||
type MessageCallbackType = Callable[[ReceiveMessage], None]
|
type MessageCallbackType = Callable[[ReceiveMessage], None]
|
||||||
|
|
||||||
|
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@@ -41,12 +42,7 @@ from .const import (
|
|||||||
CONF_RETAIN,
|
CONF_RETAIN,
|
||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -165,13 +161,8 @@ class MqttNumber(MqttEntity, RestoreNumber):
|
|||||||
self._attr_native_step = config[CONF_STEP]
|
self._attr_native_step = config[CONF_STEP]
|
||||||
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
|
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_native_value"})
|
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
num_value: int | float | None
|
num_value: int | float | None
|
||||||
payload = str(self._value_template(msg.payload))
|
payload = str(self._value_template(msg.payload))
|
||||||
@@ -203,17 +194,24 @@ class MqttNumber(MqttEntity, RestoreNumber):
|
|||||||
|
|
||||||
self._attr_native_value = num_value
|
self._attr_native_value = num_value
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||||
# Force into optimistic mode.
|
# Force into optimistic mode.
|
||||||
self._attr_assumed_state = True
|
self._attr_assumed_state = True
|
||||||
else:
|
return
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
"state_topic": {
|
"state_topic": {
|
||||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._message_received,
|
||||||
|
{"_attr_native_value"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -222,7 +220,7 @@ class MqttNumber(MqttEntity, RestoreNumber):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._attr_assumed_state and (
|
if self._attr_assumed_state and (
|
||||||
last_number_data := await self.async_get_last_number_data()
|
last_number_data := await self.async_get_last_number_data()
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@@ -27,12 +28,7 @@ from .const import (
|
|||||||
CONF_RETAIN,
|
CONF_RETAIN,
|
||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -113,13 +109,8 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
|
|||||||
config.get(CONF_VALUE_TEMPLATE), entity=self
|
config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_current_option"})
|
|
||||||
def message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT messages."""
|
"""Handle new MQTT messages."""
|
||||||
payload = str(self._value_template(msg.payload))
|
payload = str(self._value_template(msg.payload))
|
||||||
if not payload.strip(): # No output from template, ignore
|
if not payload.strip(): # No output from template, ignore
|
||||||
@@ -143,17 +134,24 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
|
|||||||
return
|
return
|
||||||
self._attr_current_option = payload
|
self._attr_current_option = payload
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||||
# Force into optimistic mode.
|
# Force into optimistic mode.
|
||||||
self._attr_assumed_state = True
|
self._attr_assumed_state = True
|
||||||
else:
|
return
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
"state_topic": {
|
"state_topic": {
|
||||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||||
"msg_callback": message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._message_received,
|
||||||
|
{"_attr_current_option"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -162,7 +160,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._attr_assumed_state and (
|
if self._attr_assumed_state and (
|
||||||
last_state := await self.async_get_last_state()
|
last_state := await self.async_get_last_state()
|
||||||
|
@@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _value_is_expired(self, *_: datetime) -> None:
|
def _value_is_expired(self, *_: datetime) -> None:
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@@ -48,12 +49,7 @@ from .const import (
|
|||||||
PAYLOAD_EMPTY_JSON,
|
PAYLOAD_EMPTY_JSON,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
MqttValueTemplate,
|
MqttValueTemplate,
|
||||||
@@ -205,13 +201,8 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
entity=self,
|
entity=self,
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_is_on", "_extra_attributes"})
|
|
||||||
def state_message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT state messages."""
|
"""Handle new MQTT state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
||||||
@@ -271,17 +262,24 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
self._extra_attributes = dict(self._extra_attributes)
|
self._extra_attributes = dict(self._extra_attributes)
|
||||||
self._update(process_turn_on_params(self, params))
|
self._update(process_turn_on_params(self, params))
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||||
# Force into optimistic mode.
|
# Force into optimistic mode.
|
||||||
self._optimistic = True
|
self._optimistic = True
|
||||||
else:
|
return
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
CONF_STATE_TOPIC: {
|
CONF_STATE_TOPIC: {
|
||||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||||
"msg_callback": state_message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._state_message_received,
|
||||||
|
{"_attr_is_on", "_extra_attributes"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -290,7 +288,7 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_state_attributes(self) -> dict[str, Any] | None:
|
def extra_state_attributes(self) -> dict[str, Any] | None:
|
||||||
|
@@ -2,14 +2,15 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable, Coroutine
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
from .. import mqtt
|
|
||||||
from . import debug_info
|
from . import debug_info
|
||||||
|
from .client import async_subscribe_internal
|
||||||
from .const import DEFAULT_QOS
|
from .const import DEFAULT_QOS
|
||||||
from .models import MessageCallbackType
|
from .models import MessageCallbackType
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ class EntitySubscription:
|
|||||||
hass: HomeAssistant
|
hass: HomeAssistant
|
||||||
topic: str | None
|
topic: str | None
|
||||||
message_callback: MessageCallbackType
|
message_callback: MessageCallbackType
|
||||||
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None
|
should_subscribe: bool | None
|
||||||
unsubscribe_callback: Callable[[], None] | None
|
unsubscribe_callback: Callable[[], None] | None
|
||||||
qos: int = 0
|
qos: int = 0
|
||||||
encoding: str = "utf-8"
|
encoding: str = "utf-8"
|
||||||
@@ -53,15 +54,16 @@ class EntitySubscription:
|
|||||||
self.hass, self.message_callback, self.topic, self.entity_id
|
self.hass, self.message_callback, self.topic, self.entity_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.subscribe_task = mqtt.async_subscribe(
|
self.should_subscribe = True
|
||||||
hass, self.topic, self.message_callback, self.qos, self.encoding
|
|
||||||
)
|
|
||||||
|
|
||||||
async def subscribe(self) -> None:
|
@callback
|
||||||
|
def subscribe(self) -> None:
|
||||||
"""Subscribe to a topic."""
|
"""Subscribe to a topic."""
|
||||||
if not self.subscribe_task:
|
if not self.should_subscribe or not self.topic:
|
||||||
return
|
return
|
||||||
self.unsubscribe_callback = await self.subscribe_task
|
self.unsubscribe_callback = async_subscribe_internal(
|
||||||
|
self.hass, self.topic, self.message_callback, self.qos, self.encoding
|
||||||
|
)
|
||||||
|
|
||||||
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
|
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
|
||||||
"""Check if we should re-subscribe to the topic using the old state."""
|
"""Check if we should re-subscribe to the topic using the old state."""
|
||||||
@@ -79,6 +81,7 @@ class EntitySubscription:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_prepare_subscribe_topics(
|
def async_prepare_subscribe_topics(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
new_state: dict[str, EntitySubscription] | None,
|
new_state: dict[str, EntitySubscription] | None,
|
||||||
@@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
|
|||||||
qos=value.get("qos", DEFAULT_QOS),
|
qos=value.get("qos", DEFAULT_QOS),
|
||||||
encoding=value.get("encoding", "utf-8"),
|
encoding=value.get("encoding", "utf-8"),
|
||||||
hass=hass,
|
hass=hass,
|
||||||
subscribe_task=None,
|
should_subscribe=None,
|
||||||
entity_id=value.get("entity_id", None),
|
entity_id=value.get("entity_id", None),
|
||||||
)
|
)
|
||||||
# Get the current subscription state
|
# Get the current subscription state
|
||||||
@@ -135,12 +138,29 @@ async def async_subscribe_topics(
|
|||||||
sub_state: dict[str, EntitySubscription],
|
sub_state: dict[str, EntitySubscription],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""(Re)Subscribe to a set of MQTT topics."""
|
"""(Re)Subscribe to a set of MQTT topics."""
|
||||||
for sub in sub_state.values():
|
async_subscribe_topics_internal(hass, sub_state)
|
||||||
await sub.subscribe()
|
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_subscribe_topics_internal(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
sub_state: dict[str, EntitySubscription],
|
||||||
|
) -> None:
|
||||||
|
"""(Re)Subscribe to a set of MQTT topics.
|
||||||
|
|
||||||
|
This function is internal to the MQTT integration and should not be called
|
||||||
|
from outside the integration.
|
||||||
|
"""
|
||||||
|
for sub in sub_state.values():
|
||||||
|
sub.subscribe()
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
def async_unsubscribe_topics(
|
def async_unsubscribe_topics(
|
||||||
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
||||||
) -> dict[str, EntitySubscription]:
|
) -> dict[str, EntitySubscription]:
|
||||||
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
||||||
return async_prepare_subscribe_topics(hass, sub_state, {})
|
|
||||||
|
|
||||||
|
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@@ -36,12 +37,7 @@ from .const import (
|
|||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import MqttValueTemplate, ReceiveMessage
|
from .models import MqttValueTemplate, ReceiveMessage
|
||||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
|
|
||||||
@@ -118,13 +114,8 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
|
||||||
def state_message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT state messages."""
|
"""Handle new MQTT state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
if payload == self._state_on:
|
if payload == self._state_on:
|
||||||
@@ -134,17 +125,24 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
elif payload == PAYLOAD_NONE:
|
elif payload == PAYLOAD_NONE:
|
||||||
self._attr_is_on = None
|
self._attr_is_on = None
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||||
# Force into optimistic mode.
|
# Force into optimistic mode.
|
||||||
self._optimistic = True
|
self._optimistic = True
|
||||||
else:
|
return
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._sub_state,
|
self._sub_state,
|
||||||
{
|
{
|
||||||
CONF_STATE_TOPIC: {
|
CONF_STATE_TOPIC: {
|
||||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||||
"msg_callback": state_message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._state_message_received,
|
||||||
|
{"_attr_is_on"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -153,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
if self._optimistic and (last_state := await self.async_get_last_state()):
|
if self._optimistic and (last_state := await self.async_get_last_state()):
|
||||||
self._attr_is_on = last_state.state == STATE_ON
|
self._attr_is_on = last_state.state == STATE_ON
|
||||||
|
@@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_tear_down(self) -> None:
|
async def async_tear_down(self) -> None:
|
||||||
"""Cleanup tag scanner."""
|
"""Cleanup tag scanner."""
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -34,12 +35,7 @@ from .const import (
|
|||||||
CONF_RETAIN,
|
CONF_RETAIN,
|
||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import (
|
from .models import (
|
||||||
MessageCallbackType,
|
MessageCallbackType,
|
||||||
MqttCommandTemplate,
|
MqttCommandTemplate,
|
||||||
@@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity):
|
|||||||
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
|
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
|
||||||
self._attr_assumed_state = bool(self._optimistic)
|
self._attr_assumed_state = bool(self._optimistic)
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def add_subscription(
|
|
||||||
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
|
|
||||||
) -> None:
|
|
||||||
if self._config.get(topic) is not None:
|
|
||||||
topics[topic] = {
|
|
||||||
"topic": self._config[topic],
|
|
||||||
"msg_callback": msg_callback,
|
|
||||||
"qos": self._config[CONF_QOS],
|
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
|
||||||
}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_native_value"})
|
|
||||||
def handle_state_message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle receiving state message via MQTT."""
|
"""Handle receiving state message via MQTT."""
|
||||||
payload = str(self._value_template(msg.payload))
|
payload = str(self._value_template(msg.payload))
|
||||||
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
|
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
|
||||||
return
|
return
|
||||||
self._attr_native_value = payload
|
self._attr_native_value = payload
|
||||||
|
|
||||||
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def add_subscription(
|
||||||
|
topics: dict[str, Any],
|
||||||
|
topic: str,
|
||||||
|
msg_callback: MessageCallbackType,
|
||||||
|
tracked_attributes: set[str],
|
||||||
|
) -> None:
|
||||||
|
if self._config.get(topic) is not None:
|
||||||
|
topics[topic] = {
|
||||||
|
"topic": self._config[topic],
|
||||||
|
"msg_callback": partial(
|
||||||
|
self._message_callback, msg_callback, tracked_attributes
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
|
"qos": self._config[CONF_QOS],
|
||||||
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
|
}
|
||||||
|
|
||||||
|
add_subscription(
|
||||||
|
topics,
|
||||||
|
CONF_STATE_TOPIC,
|
||||||
|
self._handle_state_message_received,
|
||||||
|
{"_attr_native_value"},
|
||||||
|
)
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
self.hass, self._sub_state, topics
|
self.hass, self._sub_state, topics
|
||||||
@@ -193,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_set_value(self, value: str) -> None:
|
async def async_set_value(self, value: str) -> None:
|
||||||
"""Change the text."""
|
"""Change the text."""
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TypedDict, cast
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
@@ -32,12 +33,7 @@ from .const import (
|
|||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
PAYLOAD_EMPTY_JSON,
|
PAYLOAD_EMPTY_JSON,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
|
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
|
||||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
from .util import valid_publish_topic, valid_subscribe_topic
|
from .util import valid_publish_topic, valid_subscribe_topic
|
||||||
@@ -141,35 +137,8 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||||||
).async_render_with_possible_json_value,
|
).async_render_with_possible_json_value,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def add_subscription(
|
|
||||||
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
|
|
||||||
) -> None:
|
|
||||||
if self._config.get(topic) is not None:
|
|
||||||
topics[topic] = {
|
|
||||||
"topic": self._config[topic],
|
|
||||||
"msg_callback": msg_callback,
|
|
||||||
"qos": self._config[CONF_QOS],
|
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
|
||||||
}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self,
|
|
||||||
{
|
|
||||||
"_attr_installed_version",
|
|
||||||
"_attr_latest_version",
|
|
||||||
"_attr_title",
|
|
||||||
"_attr_release_summary",
|
|
||||||
"_attr_release_url",
|
|
||||||
"_entity_picture",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def handle_state_message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle receiving state message via MQTT."""
|
"""Handle receiving state message via MQTT."""
|
||||||
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
|
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
|
||||||
|
|
||||||
@@ -233,20 +202,53 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||||||
if "entity_picture" in json_payload:
|
if "entity_picture" in json_payload:
|
||||||
self._entity_picture = json_payload["entity_picture"]
|
self._entity_picture = json_payload["entity_picture"]
|
||||||
|
|
||||||
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(self, {"_attr_latest_version"})
|
|
||||||
def handle_latest_version_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle receiving latest version via MQTT."""
|
"""Handle receiving latest version via MQTT."""
|
||||||
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
|
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
|
||||||
|
|
||||||
if isinstance(latest_version, str) and latest_version != "":
|
if isinstance(latest_version, str) and latest_version != "":
|
||||||
self._attr_latest_version = latest_version
|
self._attr_latest_version = latest_version
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def add_subscription(
|
||||||
|
topics: dict[str, Any],
|
||||||
|
topic: str,
|
||||||
|
msg_callback: MessageCallbackType,
|
||||||
|
tracked_attributes: set[str],
|
||||||
|
) -> None:
|
||||||
|
if self._config.get(topic) is not None:
|
||||||
|
topics[topic] = {
|
||||||
|
"topic": self._config[topic],
|
||||||
|
"msg_callback": partial(
|
||||||
|
self._message_callback, msg_callback, tracked_attributes
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
|
"qos": self._config[CONF_QOS],
|
||||||
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
|
}
|
||||||
|
|
||||||
add_subscription(
|
add_subscription(
|
||||||
topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received
|
topics,
|
||||||
|
CONF_STATE_TOPIC,
|
||||||
|
self._handle_state_message_received,
|
||||||
|
{
|
||||||
|
"_attr_installed_version",
|
||||||
|
"_attr_latest_version",
|
||||||
|
"_attr_title",
|
||||||
|
"_attr_release_summary",
|
||||||
|
"_attr_release_url",
|
||||||
|
"_entity_picture",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
add_subscription(
|
||||||
|
topics,
|
||||||
|
CONF_LATEST_VERSION_TOPIC,
|
||||||
|
self._handle_latest_version_received,
|
||||||
|
{"_attr_latest_version"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||||
@@ -255,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_install(
|
async def async_install(
|
||||||
self, version: str | None, backup: bool, **kwargs: Any
|
self, version: str | None, backup: bool, **kwargs: Any
|
||||||
|
@@ -8,6 +8,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
@@ -49,12 +50,7 @@ from .const import (
|
|||||||
CONF_STATE_TOPIC,
|
CONF_STATE_TOPIC,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import ReceiveMessage
|
from .models import ReceiveMessage
|
||||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
from .util import valid_publish_topic
|
from .util import valid_publish_topic
|
||||||
@@ -322,16 +318,8 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
|
|||||||
self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0)
|
self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0)
|
||||||
self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0)))
|
self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0)))
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics: dict[str, Any] = {}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self, {"_attr_battery_level", "_attr_fan_speed", "_attr_state"}
|
|
||||||
)
|
|
||||||
def state_message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle state MQTT message."""
|
"""Handle state MQTT message."""
|
||||||
payload = json_loads_object(msg.payload)
|
payload = json_loads_object(msg.payload)
|
||||||
if STATE in payload and (
|
if STATE in payload and (
|
||||||
@@ -343,10 +331,19 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
|
|||||||
del payload[STATE]
|
del payload[STATE]
|
||||||
self._update_state_attributes(payload)
|
self._update_state_attributes(payload)
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics: dict[str, Any] = {}
|
||||||
|
|
||||||
if state_topic := self._config.get(CONF_STATE_TOPIC):
|
if state_topic := self._config.get(CONF_STATE_TOPIC):
|
||||||
topics["state_position_topic"] = {
|
topics["state_position_topic"] = {
|
||||||
"topic": state_topic,
|
"topic": state_topic,
|
||||||
"msg_callback": state_message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._state_message_received,
|
||||||
|
{"_attr_battery_level", "_attr_fan_speed", "_attr_state"},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -356,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
|
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
|
||||||
"""Publish a command."""
|
"""Publish a command."""
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -61,12 +62,7 @@ from .const import (
|
|||||||
DEFAULT_RETAIN,
|
DEFAULT_RETAIN,
|
||||||
PAYLOAD_NONE,
|
PAYLOAD_NONE,
|
||||||
)
|
)
|
||||||
from .debug_info import log_messages
|
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||||
from .mixins import (
|
|
||||||
MqttEntity,
|
|
||||||
async_setup_entity_entry_helper,
|
|
||||||
write_state_on_attr_change,
|
|
||||||
)
|
|
||||||
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
|
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
|
||||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||||
from .util import valid_publish_topic, valid_subscribe_topic
|
from .util import valid_publish_topic, valid_subscribe_topic
|
||||||
@@ -302,22 +298,8 @@ class MqttValve(MqttEntity, ValveEntity):
|
|||||||
return
|
return
|
||||||
self._update_state(state)
|
self._update_state(state)
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self) -> None:
|
|
||||||
"""(Re)Subscribe to topics."""
|
|
||||||
topics = {}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||||
@write_state_on_attr_change(
|
|
||||||
self,
|
|
||||||
{
|
|
||||||
"_attr_current_valve_position",
|
|
||||||
"_attr_is_closed",
|
|
||||||
"_attr_is_closing",
|
|
||||||
"_attr_is_opening",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
def state_message_received(msg: ReceiveMessage) -> None:
|
|
||||||
"""Handle new MQTT state messages."""
|
"""Handle new MQTT state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
payload_dict: Any = None
|
payload_dict: Any = None
|
||||||
@@ -351,16 +333,28 @@ class MqttValve(MqttEntity, ValveEntity):
|
|||||||
state_payload = payload_dict.get("state")
|
state_payload = payload_dict.get("state")
|
||||||
|
|
||||||
if self._config[CONF_REPORTS_POSITION]:
|
if self._config[CONF_REPORTS_POSITION]:
|
||||||
self._process_position_valve_update(
|
self._process_position_valve_update(msg, position_payload, state_payload)
|
||||||
msg, position_payload, state_payload
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self._process_binary_valve_update(msg, state_payload)
|
self._process_binary_valve_update(msg, state_payload)
|
||||||
|
|
||||||
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
|
"""(Re)Subscribe to topics."""
|
||||||
|
topics = {}
|
||||||
|
|
||||||
if self._config.get(CONF_STATE_TOPIC):
|
if self._config.get(CONF_STATE_TOPIC):
|
||||||
topics["state_topic"] = {
|
topics["state_topic"] = {
|
||||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||||
"msg_callback": state_message_received,
|
"msg_callback": partial(
|
||||||
|
self._message_callback,
|
||||||
|
self._state_message_received,
|
||||||
|
{
|
||||||
|
"_attr_current_valve_position",
|
||||||
|
"_attr_is_closed",
|
||||||
|
"_attr_is_closing",
|
||||||
|
"_attr_is_opening",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"entity_id": self.entity_id,
|
||||||
"qos": self._config[CONF_QOS],
|
"qos": self._config[CONF_QOS],
|
||||||
"encoding": self._config[CONF_ENCODING] or None,
|
"encoding": self._config[CONF_ENCODING] or None,
|
||||||
}
|
}
|
||||||
@@ -371,7 +365,7 @@ class MqttValve(MqttEntity, ValveEntity):
|
|||||||
|
|
||||||
async def _subscribe_topics(self) -> None:
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||||
|
|
||||||
async def async_open_valve(self) -> None:
|
async def async_open_valve(self) -> None:
|
||||||
"""Move the valve up.
|
"""Move the valve up.
|
||||||
|
@@ -9,6 +9,7 @@ from typing import Literal
|
|||||||
import ollama
|
import ollama
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import MATCH_ALL
|
from homeassistant.const import MATCH_ALL
|
||||||
@@ -138,6 +139,11 @@ class OllamaConversationEntity(
|
|||||||
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trace.async_conversation_trace_append(
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
{"messages": message_history.messages},
|
||||||
|
)
|
||||||
|
|
||||||
# Get response
|
# Get response
|
||||||
try:
|
try:
|
||||||
response = await client.chat(
|
response = await client.chat(
|
||||||
|
@@ -31,14 +31,15 @@ from .const import (
|
|||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
|
CONF_RECOMMENDED,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
DEFAULT_CHAT_MODEL,
|
|
||||||
DEFAULT_MAX_TOKENS,
|
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
DEFAULT_TEMPERATURE,
|
|
||||||
DEFAULT_TOP_P,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_TEMPERATURE,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RECOMMENDED_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||||
|
CONF_PROMPT: DEFAULT_PROMPT,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||||
"""Validate the user input allows us to connect.
|
"""Validate the user input allows us to connect.
|
||||||
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title="ChatGPT",
|
title="ChatGPT",
|
||||||
data=user_input,
|
data=user_input,
|
||||||
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
options=RECOMMENDED_OPTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
|
|||||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||||
"""Initialize options flow."""
|
"""Initialize options flow."""
|
||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
|
self.last_rendered_recommended = config_entry.options.get(
|
||||||
|
CONF_RECOMMENDED, False
|
||||||
|
)
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> ConfigFlowResult:
|
) -> ConfigFlowResult:
|
||||||
"""Manage the options."""
|
"""Manage the options."""
|
||||||
|
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
|
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||||
if user_input[CONF_LLM_HASS_API] == "none":
|
if user_input[CONF_LLM_HASS_API] == "none":
|
||||||
user_input.pop(CONF_LLM_HASS_API)
|
user_input.pop(CONF_LLM_HASS_API)
|
||||||
return self.async_create_entry(title="", data=user_input)
|
return self.async_create_entry(title="", data=user_input)
|
||||||
schema = openai_config_option_schema(self.hass, self.config_entry.options)
|
|
||||||
|
# Re-render the options again, now with the recommended options shown/hidden
|
||||||
|
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
||||||
|
|
||||||
|
options = {
|
||||||
|
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
|
||||||
|
CONF_PROMPT: user_input[CONF_PROMPT],
|
||||||
|
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = openai_config_option_schema(self.hass, options)
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="init",
|
step_id="init",
|
||||||
data_schema=vol.Schema(schema),
|
data_schema=vol.Schema(schema),
|
||||||
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):
|
|||||||
|
|
||||||
def openai_config_option_schema(
|
def openai_config_option_schema(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
options: MappingProxyType[str, Any],
|
options: dict[str, Any] | MappingProxyType[str, Any],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Return a schema for OpenAI completion options."""
|
"""Return a schema for OpenAI completion options."""
|
||||||
apis: list[SelectOptionDict] = [
|
hass_apis: list[SelectOptionDict] = [
|
||||||
SelectOptionDict(
|
SelectOptionDict(
|
||||||
label="No control",
|
label="No control",
|
||||||
value="none",
|
value="none",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
apis.extend(
|
hass_apis.extend(
|
||||||
SelectOptionDict(
|
SelectOptionDict(
|
||||||
label=api.name,
|
label=api.name,
|
||||||
value=api.id,
|
value=api.id,
|
||||||
@@ -144,38 +167,46 @@ def openai_config_option_schema(
|
|||||||
for api in llm.async_get_apis(hass)
|
for api in llm.async_get_apis(hass)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
schema = {
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
|
||||||
default=DEFAULT_PROMPT,
|
|
||||||
): TemplateSelector(),
|
): TemplateSelector(),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_LLM_HASS_API,
|
CONF_LLM_HASS_API,
|
||||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||||
default="none",
|
default="none",
|
||||||
): SelectSelector(SelectSelectorConfig(options=apis)),
|
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||||
|
vol.Required(
|
||||||
|
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||||
|
): bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.get(CONF_RECOMMENDED):
|
||||||
|
return schema
|
||||||
|
|
||||||
|
schema.update(
|
||||||
|
{
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
description={
|
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
||||||
# New key in HA 2023.4
|
default=RECOMMENDED_CHAT_MODEL,
|
||||||
"suggested_value": options.get(CONF_CHAT_MODEL)
|
|
||||||
},
|
|
||||||
default=DEFAULT_CHAT_MODEL,
|
|
||||||
): str,
|
): str,
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||||
default=DEFAULT_MAX_TOKENS,
|
default=RECOMMENDED_MAX_TOKENS,
|
||||||
): int,
|
): int,
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
description={"suggested_value": options.get(CONF_TOP_P)},
|
description={"suggested_value": options.get(CONF_TOP_P)},
|
||||||
default=DEFAULT_TOP_P,
|
default=RECOMMENDED_TOP_P,
|
||||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||||
vol.Optional(
|
vol.Optional(
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
||||||
default=DEFAULT_TEMPERATURE,
|
default=RECOMMENDED_TEMPERATURE,
|
||||||
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
|
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
return schema
|
||||||
|
@@ -4,13 +4,15 @@ import logging
|
|||||||
|
|
||||||
DOMAIN = "openai_conversation"
|
DOMAIN = "openai_conversation"
|
||||||
LOGGER = logging.getLogger(__package__)
|
LOGGER = logging.getLogger(__package__)
|
||||||
|
|
||||||
|
CONF_RECOMMENDED = "recommended"
|
||||||
CONF_PROMPT = "prompt"
|
CONF_PROMPT = "prompt"
|
||||||
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
|
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
|
||||||
CONF_CHAT_MODEL = "chat_model"
|
CONF_CHAT_MODEL = "chat_model"
|
||||||
DEFAULT_CHAT_MODEL = "gpt-4o"
|
RECOMMENDED_CHAT_MODEL = "gpt-4o"
|
||||||
CONF_MAX_TOKENS = "max_tokens"
|
CONF_MAX_TOKENS = "max_tokens"
|
||||||
DEFAULT_MAX_TOKENS = 150
|
RECOMMENDED_MAX_TOKENS = 150
|
||||||
CONF_TOP_P = "top_p"
|
CONF_TOP_P = "top_p"
|
||||||
DEFAULT_TOP_P = 1.0
|
RECOMMENDED_TOP_P = 1.0
|
||||||
CONF_TEMPERATURE = "temperature"
|
CONF_TEMPERATURE = "temperature"
|
||||||
DEFAULT_TEMPERATURE = 1.0
|
RECOMMENDED_TEMPERATURE = 1.0
|
||||||
|
@@ -8,6 +8,7 @@ import voluptuous as vol
|
|||||||
from voluptuous_openapi import convert
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@@ -22,13 +23,13 @@ from .const import (
|
|||||||
CONF_PROMPT,
|
CONF_PROMPT,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
DEFAULT_CHAT_MODEL,
|
|
||||||
DEFAULT_MAX_TOKENS,
|
|
||||||
DEFAULT_PROMPT,
|
DEFAULT_PROMPT,
|
||||||
DEFAULT_TEMPERATURE,
|
|
||||||
DEFAULT_TOP_P,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_TEMPERATURE,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Max number of back and forth with the LLM to generate a response
|
# Max number of back and forth with the LLM to generate a response
|
||||||
@@ -97,15 +98,14 @@ class OpenAIConversationEntity(
|
|||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
) -> conversation.ConversationResult:
|
) -> conversation.ConversationResult:
|
||||||
"""Process a sentence."""
|
"""Process a sentence."""
|
||||||
|
options = self.entry.options
|
||||||
intent_response = intent.IntentResponse(language=user_input.language)
|
intent_response = intent.IntentResponse(language=user_input.language)
|
||||||
llm_api: llm.API | None = None
|
llm_api: llm.API | None = None
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
if options.get(CONF_LLM_HASS_API):
|
||||||
try:
|
try:
|
||||||
llm_api = llm.async_get_api(
|
llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
|
||||||
self.hass, self.entry.options[CONF_LLM_HASS_API]
|
|
||||||
)
|
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
LOGGER.error("Error getting LLM API: %s", err)
|
LOGGER.error("Error getting LLM API: %s", err)
|
||||||
intent_response.async_set_error(
|
intent_response.async_set_error(
|
||||||
@@ -117,26 +117,12 @@ class OpenAIConversationEntity(
|
|||||||
)
|
)
|
||||||
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
||||||
|
|
||||||
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
|
|
||||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
|
||||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
|
||||||
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
|
||||||
|
|
||||||
if user_input.conversation_id in self.history:
|
if user_input.conversation_id in self.history:
|
||||||
conversation_id = user_input.conversation_id
|
conversation_id = user_input.conversation_id
|
||||||
messages = self.history[conversation_id]
|
messages = self.history[conversation_id]
|
||||||
else:
|
else:
|
||||||
conversation_id = ulid.ulid_now()
|
conversation_id = ulid.ulid_now()
|
||||||
try:
|
try:
|
||||||
prompt = template.Template(
|
|
||||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
|
||||||
).async_render(
|
|
||||||
{
|
|
||||||
"ha_name": self.hass.config.location_name,
|
|
||||||
},
|
|
||||||
parse_result=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if llm_api:
|
if llm_api:
|
||||||
empty_tool_input = llm.ToolInput(
|
empty_tool_input = llm.ToolInput(
|
||||||
tool_name="",
|
tool_name="",
|
||||||
@@ -149,10 +135,23 @@ class OpenAIConversationEntity(
|
|||||||
device_id=user_input.device_id,
|
device_id=user_input.device_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = (
|
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
|
||||||
await llm_api.async_get_api_prompt(empty_tool_input)
|
|
||||||
+ "\n"
|
else:
|
||||||
+ prompt
|
api_prompt = llm.PROMPT_NO_API_CONFIGURED
|
||||||
|
|
||||||
|
prompt = "\n".join(
|
||||||
|
(
|
||||||
|
template.Template(
|
||||||
|
options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||||
|
).async_render(
|
||||||
|
{
|
||||||
|
"ha_name": self.hass.config.location_name,
|
||||||
|
},
|
||||||
|
parse_result=False,
|
||||||
|
),
|
||||||
|
api_prompt,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
except TemplateError as err:
|
except TemplateError as err:
|
||||||
@@ -170,7 +169,10 @@ class OpenAIConversationEntity(
|
|||||||
|
|
||||||
messages.append({"role": "user", "content": user_input.text})
|
messages.append({"role": "user", "content": user_input.text})
|
||||||
|
|
||||||
LOGGER.debug("Prompt for %s: %s", model, messages)
|
LOGGER.debug("Prompt: %s", messages)
|
||||||
|
trace.async_conversation_trace_append(
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||||
|
)
|
||||||
|
|
||||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||||
|
|
||||||
@@ -178,12 +180,12 @@ class OpenAIConversationEntity(
|
|||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
try:
|
try:
|
||||||
result = await client.chat.completions.create(
|
result = await client.chat.completions.create(
|
||||||
model=model,
|
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
max_tokens=max_tokens,
|
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||||
top_p=top_p,
|
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
temperature=temperature,
|
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
user=conversation_id,
|
user=conversation_id,
|
||||||
)
|
)
|
||||||
except openai.OpenAIError as err:
|
except openai.OpenAIError as err:
|
||||||
|
@@ -22,7 +22,8 @@
|
|||||||
"max_tokens": "Maximum tokens to return in response",
|
"max_tokens": "Maximum tokens to return in response",
|
||||||
"temperature": "Temperature",
|
"temperature": "Temperature",
|
||||||
"top_p": "Top P",
|
"top_p": "Top P",
|
||||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||||
|
"recommended": "Recommended model settings"
|
||||||
},
|
},
|
||||||
"data_description": {
|
"data_description": {
|
||||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||||
|
@@ -30,7 +30,6 @@ from .util import (
|
|||||||
|
|
||||||
PLATFORMS = [Platform.MEDIA_PLAYER]
|
PLATFORMS = [Platform.MEDIA_PLAYER]
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"async_browse_media",
|
"async_browse_media",
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
@@ -50,7 +49,10 @@ class HomeAssistantSpotifyData:
|
|||||||
session: OAuth2Session
|
session: OAuth2Session
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
type SpotifyConfigEntry = ConfigEntry[HomeAssistantSpotifyData]
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: SpotifyConfigEntry) -> bool:
|
||||||
"""Set up Spotify from a config entry."""
|
"""Set up Spotify from a config entry."""
|
||||||
implementation = await async_get_config_entry_implementation(hass, entry)
|
implementation = await async_get_config_entry_implementation(hass, entry)
|
||||||
session = OAuth2Session(hass, entry, implementation)
|
session = OAuth2Session(hass, entry, implementation)
|
||||||
@@ -100,8 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
)
|
)
|
||||||
await device_coordinator.async_config_entry_first_refresh()
|
await device_coordinator.async_config_entry_first_refresh()
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})
|
entry.runtime_data = HomeAssistantSpotifyData(
|
||||||
hass.data[DOMAIN][entry.entry_id] = HomeAssistantSpotifyData(
|
|
||||||
client=spotify,
|
client=spotify,
|
||||||
current_user=current_user,
|
current_user=current_user,
|
||||||
devices=device_coordinator,
|
devices=device_coordinator,
|
||||||
@@ -117,6 +118,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload Spotify config entry."""
|
"""Unload Spotify config entry."""
|
||||||
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||||
del hass.data[DOMAIN][entry.entry_id]
|
|
||||||
return unload_ok
|
|
||||||
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from spotipy import Spotify
|
from spotipy import Spotify
|
||||||
import yarl
|
import yarl
|
||||||
@@ -22,6 +22,9 @@ from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
|
|||||||
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES
|
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES
|
||||||
from .util import fetch_image_url
|
from .util import fetch_image_url
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import HomeAssistantSpotifyData
|
||||||
|
|
||||||
BROWSE_LIMIT = 48
|
BROWSE_LIMIT = 48
|
||||||
|
|
||||||
|
|
||||||
@@ -140,21 +143,21 @@ async def async_browse_media(
|
|||||||
|
|
||||||
# Check if caller is requesting the root nodes
|
# Check if caller is requesting the root nodes
|
||||||
if media_content_type is None and media_content_id is None:
|
if media_content_type is None and media_content_id is None:
|
||||||
children = []
|
config_entries = hass.config_entries.async_entries(
|
||||||
for config_entry_id in hass.data[DOMAIN]:
|
DOMAIN, include_disabled=False, include_ignore=False
|
||||||
config_entry = hass.config_entries.async_get_entry(config_entry_id)
|
)
|
||||||
assert config_entry is not None
|
children = [
|
||||||
children.append(
|
|
||||||
BrowseMedia(
|
BrowseMedia(
|
||||||
title=config_entry.title,
|
title=config_entry.title,
|
||||||
media_class=MediaClass.APP,
|
media_class=MediaClass.APP,
|
||||||
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry_id}",
|
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry.entry_id}",
|
||||||
media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
|
media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
|
||||||
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
|
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
|
||||||
can_play=False,
|
can_play=False,
|
||||||
can_expand=True,
|
can_expand=True,
|
||||||
)
|
)
|
||||||
)
|
for config_entry in config_entries
|
||||||
|
]
|
||||||
return BrowseMedia(
|
return BrowseMedia(
|
||||||
title="Spotify",
|
title="Spotify",
|
||||||
media_class=MediaClass.APP,
|
media_class=MediaClass.APP,
|
||||||
@@ -171,9 +174,15 @@ async def async_browse_media(
|
|||||||
|
|
||||||
# Check for config entry specifier, and extract Spotify URI
|
# Check for config entry specifier, and extract Spotify URI
|
||||||
parsed_url = yarl.URL(media_content_id)
|
parsed_url = yarl.URL(media_content_id)
|
||||||
if (info := hass.data[DOMAIN].get(parsed_url.host)) is None:
|
|
||||||
|
if (
|
||||||
|
parsed_url.host is None
|
||||||
|
or (entry := hass.config_entries.async_get_entry(parsed_url.host)) is None
|
||||||
|
or not isinstance(entry.runtime_data, HomeAssistantSpotifyData)
|
||||||
|
):
|
||||||
raise BrowseError("Invalid Spotify account specified")
|
raise BrowseError("Invalid Spotify account specified")
|
||||||
media_content_id = parsed_url.name
|
media_content_id = parsed_url.name
|
||||||
|
info = entry.runtime_data
|
||||||
|
|
||||||
result = await async_browse_media_internal(
|
result = await async_browse_media_internal(
|
||||||
hass,
|
hass,
|
||||||
|
@@ -22,7 +22,6 @@ from homeassistant.components.media_player import (
|
|||||||
MediaType,
|
MediaType,
|
||||||
RepeatMode,
|
RepeatMode,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import CONF_ID
|
from homeassistant.const import CONF_ID
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -30,7 +29,7 @@ from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
|||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
from . import HomeAssistantSpotifyData
|
from . import HomeAssistantSpotifyData, SpotifyConfigEntry
|
||||||
from .browse_media import async_browse_media_internal
|
from .browse_media import async_browse_media_internal
|
||||||
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES
|
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES
|
||||||
from .util import fetch_image_url
|
from .util import fetch_image_url
|
||||||
@@ -70,12 +69,12 @@ SPOTIFY_DJ_PLAYLIST = {"uri": "spotify:playlist:37i9dQZF1EYkqdzj48dyYq", "name":
|
|||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
entry: ConfigEntry,
|
entry: SpotifyConfigEntry,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Spotify based on a config entry."""
|
"""Set up Spotify based on a config entry."""
|
||||||
spotify = SpotifyMediaPlayer(
|
spotify = SpotifyMediaPlayer(
|
||||||
hass.data[DOMAIN][entry.entry_id],
|
entry.runtime_data,
|
||||||
entry.data[CONF_ID],
|
entry.data[CONF_ID],
|
||||||
entry.title,
|
entry.title,
|
||||||
)
|
)
|
||||||
|
@@ -4,15 +4,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from aioswitcher.bridge import SwitcherBridge
|
||||||
from aioswitcher.device import SwitcherBase
|
from aioswitcher.device import SwitcherBase
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform
|
||||||
from homeassistant.core import Event, HomeAssistant, callback
|
from homeassistant.core import Event, HomeAssistant, callback
|
||||||
|
|
||||||
from .const import DATA_DEVICE, DOMAIN
|
|
||||||
from .coordinator import SwitcherDataUpdateCoordinator
|
from .coordinator import SwitcherDataUpdateCoordinator
|
||||||
from .utils import async_start_bridge, async_stop_bridge
|
|
||||||
|
|
||||||
PLATFORMS = [
|
PLATFORMS = [
|
||||||
Platform.BUTTON,
|
Platform.BUTTON,
|
||||||
@@ -25,20 +24,20 @@ PLATFORMS = [
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
type SwitcherConfigEntry = ConfigEntry[dict[str, SwitcherDataUpdateCoordinator]]
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
|
||||||
"""Set up Switcher from a config entry."""
|
"""Set up Switcher from a config entry."""
|
||||||
hass.data.setdefault(DOMAIN, {})
|
|
||||||
hass.data[DOMAIN][DATA_DEVICE] = {}
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def on_device_data_callback(device: SwitcherBase) -> None:
|
def on_device_data_callback(device: SwitcherBase) -> None:
|
||||||
"""Use as a callback for device data."""
|
"""Use as a callback for device data."""
|
||||||
|
|
||||||
|
coordinators = entry.runtime_data
|
||||||
|
|
||||||
# Existing device update device data
|
# Existing device update device data
|
||||||
if device.device_id in hass.data[DOMAIN][DATA_DEVICE]:
|
if coordinator := coordinators.get(device.device_id):
|
||||||
coordinator: SwitcherDataUpdateCoordinator = hass.data[DOMAIN][DATA_DEVICE][
|
|
||||||
device.device_id
|
|
||||||
]
|
|
||||||
coordinator.async_set_updated_data(device)
|
coordinator.async_set_updated_data(device)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -52,18 +51,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
device.device_type.hex_rep,
|
device.device_type.hex_rep,
|
||||||
)
|
)
|
||||||
|
|
||||||
coordinator = hass.data[DOMAIN][DATA_DEVICE][device.device_id] = (
|
coordinator = SwitcherDataUpdateCoordinator(hass, entry, device)
|
||||||
SwitcherDataUpdateCoordinator(hass, entry, device)
|
|
||||||
)
|
|
||||||
coordinator.async_setup()
|
coordinator.async_setup()
|
||||||
|
coordinators[device.device_id] = coordinator
|
||||||
|
|
||||||
# Must be ready before dispatcher is called
|
# Must be ready before dispatcher is called
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
|
|
||||||
await async_start_bridge(hass, on_device_data_callback)
|
entry.runtime_data = {}
|
||||||
|
bridge = SwitcherBridge(on_device_data_callback)
|
||||||
|
await bridge.start()
|
||||||
|
|
||||||
async def stop_bridge(event: Event) -> None:
|
async def stop_bridge(event: Event | None = None) -> None:
|
||||||
await async_stop_bridge(hass)
|
await bridge.stop()
|
||||||
|
|
||||||
|
entry.async_on_unload(stop_bridge)
|
||||||
|
|
||||||
entry.async_on_unload(
|
entry.async_on_unload(
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge)
|
||||||
@@ -72,12 +74,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
|
||||||
"""Unload a config entry."""
|
"""Unload a config entry."""
|
||||||
await async_stop_bridge(hass)
|
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||||
|
|
||||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
|
||||||
if unload_ok:
|
|
||||||
hass.data[DOMAIN].pop(DATA_DEVICE)
|
|
||||||
|
|
||||||
return unload_ok
|
|
||||||
|
@@ -15,7 +15,6 @@ from aioswitcher.api.remotes import SwitcherBreezeRemote
|
|||||||
from aioswitcher.device import DeviceCategory
|
from aioswitcher.device import DeviceCategory
|
||||||
|
|
||||||
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
|
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import EntityCategory
|
from homeassistant.const import EntityCategory
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -25,6 +24,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
|||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
|
from . import SwitcherConfigEntry
|
||||||
from .const import SIGNAL_DEVICE_ADD
|
from .const import SIGNAL_DEVICE_ADD
|
||||||
from .coordinator import SwitcherDataUpdateCoordinator
|
from .coordinator import SwitcherDataUpdateCoordinator
|
||||||
from .utils import get_breeze_remote_manager
|
from .utils import get_breeze_remote_manager
|
||||||
@@ -78,7 +78,7 @@ THERMOSTAT_BUTTONS = [
|
|||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: SwitcherConfigEntry,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Switcher button from config entry."""
|
"""Set up Switcher button from config entry."""
|
||||||
|
@@ -25,7 +25,6 @@ from homeassistant.components.climate import (
|
|||||||
ClimateEntityFeature,
|
ClimateEntityFeature,
|
||||||
HVACMode,
|
HVACMode,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature
|
from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -35,6 +34,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
|||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
|
from . import SwitcherConfigEntry
|
||||||
from .const import SIGNAL_DEVICE_ADD
|
from .const import SIGNAL_DEVICE_ADD
|
||||||
from .coordinator import SwitcherDataUpdateCoordinator
|
from .coordinator import SwitcherDataUpdateCoordinator
|
||||||
from .utils import get_breeze_remote_manager
|
from .utils import get_breeze_remote_manager
|
||||||
@@ -61,7 +61,7 @@ HA_TO_DEVICE_FAN = {value: key for key, value in DEVICE_FAN_TO_HA.items()}
|
|||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config_entry: ConfigEntry,
|
config_entry: SwitcherConfigEntry,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Switcher climate from config entry."""
|
"""Set up Switcher climate from config entry."""
|
||||||
|
@@ -2,9 +2,6 @@
|
|||||||
|
|
||||||
DOMAIN = "switcher_kis"
|
DOMAIN = "switcher_kis"
|
||||||
|
|
||||||
DATA_BRIDGE = "bridge"
|
|
||||||
DATA_DEVICE = "device"
|
|
||||||
|
|
||||||
DISCOVERY_TIME_SEC = 12
|
DISCOVERY_TIME_SEC = 12
|
||||||
|
|
||||||
SIGNAL_DEVICE_ADD = "switcher_device_add"
|
SIGNAL_DEVICE_ADD = "switcher_device_add"
|
||||||
|
@@ -6,24 +6,23 @@ from dataclasses import asdict
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.diagnostics import async_redact_data
|
from homeassistant.components.diagnostics import async_redact_data
|
||||||
from homeassistant.config_entries import ConfigEntry
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .const import DATA_DEVICE, DOMAIN
|
from . import SwitcherConfigEntry
|
||||||
|
|
||||||
TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"}
|
TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"}
|
||||||
|
|
||||||
|
|
||||||
async def async_get_config_entry_diagnostics(
|
async def async_get_config_entry_diagnostics(
|
||||||
hass: HomeAssistant, entry: ConfigEntry
|
hass: HomeAssistant, entry: SwitcherConfigEntry
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return diagnostics for a config entry."""
|
"""Return diagnostics for a config entry."""
|
||||||
devices = hass.data[DOMAIN][DATA_DEVICE]
|
coordinators = entry.runtime_data
|
||||||
|
|
||||||
return async_redact_data(
|
return async_redact_data(
|
||||||
{
|
{
|
||||||
"entry": entry.as_dict(),
|
"entry": entry.as_dict(),
|
||||||
"devices": [asdict(devices[d].data) for d in devices],
|
"devices": [asdict(coordinators[d].data) for d in coordinators],
|
||||||
},
|
},
|
||||||
TO_REDACT,
|
TO_REDACT,
|
||||||
)
|
)
|
||||||
|
@@ -3,9 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from aioswitcher.api.remotes import SwitcherBreezeRemoteManager
|
from aioswitcher.api.remotes import SwitcherBreezeRemoteManager
|
||||||
from aioswitcher.bridge import SwitcherBase, SwitcherBridge
|
from aioswitcher.bridge import SwitcherBase, SwitcherBridge
|
||||||
@@ -13,29 +11,11 @@ from aioswitcher.bridge import SwitcherBase, SwitcherBridge
|
|||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import singleton
|
from homeassistant.helpers import singleton
|
||||||
|
|
||||||
from .const import DATA_BRIDGE, DISCOVERY_TIME_SEC, DOMAIN
|
from .const import DISCOVERY_TIME_SEC
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_start_bridge(
|
|
||||||
hass: HomeAssistant, on_device_callback: Callable[[SwitcherBase], Any]
|
|
||||||
) -> None:
|
|
||||||
"""Start switcher UDP bridge."""
|
|
||||||
bridge = hass.data[DOMAIN][DATA_BRIDGE] = SwitcherBridge(on_device_callback)
|
|
||||||
_LOGGER.debug("Starting Switcher bridge")
|
|
||||||
await bridge.start()
|
|
||||||
|
|
||||||
|
|
||||||
async def async_stop_bridge(hass: HomeAssistant) -> None:
|
|
||||||
"""Stop switcher UDP bridge."""
|
|
||||||
bridge: SwitcherBridge = hass.data[DOMAIN].get(DATA_BRIDGE)
|
|
||||||
if bridge is not None:
|
|
||||||
_LOGGER.debug("Stopping Switcher bridge")
|
|
||||||
await bridge.stop()
|
|
||||||
hass.data[DOMAIN].pop(DATA_BRIDGE)
|
|
||||||
|
|
||||||
|
|
||||||
async def async_has_devices(hass: HomeAssistant) -> bool:
|
async def async_has_devices(hass: HomeAssistant) -> bool:
|
||||||
"""Discover Switcher devices."""
|
"""Discover Switcher devices."""
|
||||||
_LOGGER.debug("Starting discovery")
|
_LOGGER.debug("Starting discovery")
|
||||||
|
@@ -30,6 +30,7 @@ PLATFORMS: Final = [
|
|||||||
Platform.BINARY_SENSOR,
|
Platform.BINARY_SENSOR,
|
||||||
Platform.CLIMATE,
|
Platform.CLIMATE,
|
||||||
Platform.COVER,
|
Platform.COVER,
|
||||||
|
Platform.DEVICE_TRACKER,
|
||||||
Platform.LOCK,
|
Platform.LOCK,
|
||||||
Platform.SELECT,
|
Platform.SELECT,
|
||||||
Platform.SENSOR,
|
Platform.SENSOR,
|
||||||
|
85
homeassistant/components/teslemetry/device_tracker.py
Normal file
85
homeassistant/components/teslemetry/device_tracker.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Device tracker platform for Teslemetry integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from homeassistant.components.device_tracker import SourceType
|
||||||
|
from homeassistant.components.device_tracker.config_entry import TrackerEntity
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .entity import TeslemetryVehicleEntity
|
||||||
|
from .models import TeslemetryVehicleData
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
|
||||||
|
) -> None:
|
||||||
|
"""Set up the Teslemetry device tracker platform from a config entry."""
|
||||||
|
|
||||||
|
async_add_entities(
|
||||||
|
klass(vehicle)
|
||||||
|
for klass in (
|
||||||
|
TeslemetryDeviceTrackerLocationEntity,
|
||||||
|
TeslemetryDeviceTrackerRouteEntity,
|
||||||
|
)
|
||||||
|
for vehicle in entry.runtime_data.vehicles
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TeslemetryDeviceTrackerEntity(TeslemetryVehicleEntity, TrackerEntity):
|
||||||
|
"""Base class for Teslemetry tracker entities."""
|
||||||
|
|
||||||
|
lat_key: str
|
||||||
|
lon_key: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vehicle: TeslemetryVehicleData,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the device tracker."""
|
||||||
|
super().__init__(vehicle, self.key)
|
||||||
|
|
||||||
|
def _async_update_attrs(self) -> None:
|
||||||
|
"""Update the attributes of the device tracker."""
|
||||||
|
|
||||||
|
self._attr_available = (
|
||||||
|
self.get(self.lat_key, False) is not None
|
||||||
|
and self.get(self.lon_key, False) is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latitude(self) -> float | None:
|
||||||
|
"""Return latitude value of the device."""
|
||||||
|
return self.get(self.lat_key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def longitude(self) -> float | None:
|
||||||
|
"""Return longitude value of the device."""
|
||||||
|
return self.get(self.lon_key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source_type(self) -> SourceType:
|
||||||
|
"""Return the source type of the device tracker."""
|
||||||
|
return SourceType.GPS
|
||||||
|
|
||||||
|
|
||||||
|
class TeslemetryDeviceTrackerLocationEntity(TeslemetryDeviceTrackerEntity):
|
||||||
|
"""Vehicle location device tracker class."""
|
||||||
|
|
||||||
|
key = "location"
|
||||||
|
lat_key = "drive_state_latitude"
|
||||||
|
lon_key = "drive_state_longitude"
|
||||||
|
|
||||||
|
|
||||||
|
class TeslemetryDeviceTrackerRouteEntity(TeslemetryDeviceTrackerEntity):
|
||||||
|
"""Vehicle navigation device tracker class."""
|
||||||
|
|
||||||
|
key = "route"
|
||||||
|
lat_key = "drive_state_active_route_latitude"
|
||||||
|
lon_key = "drive_state_active_route_longitude"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def location_name(self) -> str | None:
|
||||||
|
"""Return a location name for the current location of the device."""
|
||||||
|
return self.get("drive_state_active_route_destination")
|
@@ -109,6 +109,7 @@
|
|||||||
"off": "mdi:car-seat"
|
"off": "mdi:car-seat"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
"components_customer_preferred_export_rule": {
|
"components_customer_preferred_export_rule": {
|
||||||
"default": "mdi:transmission-tower",
|
"default": "mdi:transmission-tower",
|
||||||
"state": {
|
"state": {
|
||||||
@@ -126,6 +127,14 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"device_tracker": {
|
||||||
|
"location": {
|
||||||
|
"default": "mdi:map-marker"
|
||||||
|
},
|
||||||
|
"route": {
|
||||||
|
"default": "mdi:routes"
|
||||||
|
}
|
||||||
|
},
|
||||||
"cover": {
|
"cover": {
|
||||||
"charge_state_charge_port_door_open": {
|
"charge_state_charge_port_door_open": {
|
||||||
"default": "mdi:ev-plug-ccs2"
|
"default": "mdi:ev-plug-ccs2"
|
||||||
|
@@ -111,6 +111,14 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"device_tracker": {
|
||||||
|
"location": {
|
||||||
|
"name": "Location"
|
||||||
|
},
|
||||||
|
"route": {
|
||||||
|
"name": "Route"
|
||||||
|
}
|
||||||
|
},
|
||||||
"lock": {
|
"lock": {
|
||||||
"charge_state_charge_port_latch": {
|
"charge_state_charge_port_latch": {
|
||||||
"name": "Charge cable lock"
|
"name": "Charge cable lock"
|
||||||
|
@@ -13,7 +13,7 @@
|
|||||||
"velbus-packet",
|
"velbus-packet",
|
||||||
"velbus-protocol"
|
"velbus-protocol"
|
||||||
],
|
],
|
||||||
"requirements": ["velbus-aio==2024.4.1"],
|
"requirements": ["velbus-aio==2024.5.1"],
|
||||||
"usb": [
|
"usb": [
|
||||||
{
|
{
|
||||||
"vid": "10CF",
|
"vid": "10CF",
|
||||||
|
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Awaitable
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -157,16 +156,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||||
"""Unload Withings config entry."""
|
"""Unload vera config entry."""
|
||||||
controller_data: ControllerData = get_controller_data(hass, config_entry)
|
controller_data: ControllerData = get_controller_data(hass, config_entry)
|
||||||
|
await asyncio.gather(
|
||||||
tasks: list[Awaitable] = [
|
*(
|
||||||
hass.config_entries.async_forward_entry_unload(config_entry, platform)
|
hass.config_entries.async_unload_platforms(
|
||||||
for platform in get_configured_platforms(controller_data)
|
config_entry, get_configured_platforms(controller_data)
|
||||||
]
|
),
|
||||||
tasks.append(hass.async_add_executor_job(controller_data.controller.stop))
|
hass.async_add_executor_job(controller_data.controller.stop),
|
||||||
await asyncio.gather(*tasks)
|
)
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
"""Support for Zigbee Home Automation devices."""
|
"""Support for Zigbee Home Automation devices."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
@@ -238,12 +237,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
|
|||||||
websocket_api.async_unload_api(hass)
|
websocket_api.async_unload_api(hass)
|
||||||
|
|
||||||
# our components don't have unload methods so no need to look at return values
|
# our components don't have unload methods so no need to look at return values
|
||||||
await asyncio.gather(
|
await hass.config_entries.async_unload_platforms(config_entry, PLATFORMS)
|
||||||
*(
|
|
||||||
hass.config_entries.async_forward_entry_unload(config_entry, platform)
|
|
||||||
for platform in PLATFORMS
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Coroutine
|
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -958,14 +957,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
"""Unload a config entry."""
|
"""Unload a config entry."""
|
||||||
client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
|
client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
|
||||||
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
|
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
|
||||||
|
platforms = [
|
||||||
tasks: list[Coroutine] = [
|
platform
|
||||||
hass.config_entries.async_forward_entry_unload(entry, platform)
|
|
||||||
for platform, task in driver_events.platform_setup_tasks.items()
|
for platform, task in driver_events.platform_setup_tasks.items()
|
||||||
if not task.cancel()
|
if not task.cancel()
|
||||||
]
|
]
|
||||||
|
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
|
||||||
unload_ok = all(await asyncio.gather(*tasks)) if tasks else True
|
|
||||||
|
|
||||||
if client.connected and client.driver:
|
if client.connected and client.driver:
|
||||||
await async_disable_server_logging_if_needed(hass, entry, client.driver)
|
await async_disable_server_logging_if_needed(hass, entry, client.driver)
|
||||||
|
@@ -3,12 +3,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
||||||
|
from homeassistant.components.conversation.trace import (
|
||||||
|
ConversationTraceEventType,
|
||||||
|
async_conversation_trace_append,
|
||||||
|
)
|
||||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -116,6 +120,10 @@ class API(ABC):
|
|||||||
|
|
||||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||||
"""Call a LLM tool, validate args and return the response."""
|
"""Call a LLM tool, validate args and return the response."""
|
||||||
|
async_conversation_trace_append(
|
||||||
|
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
|
||||||
|
)
|
||||||
|
|
||||||
for tool in self.async_get_tools():
|
for tool in self.async_get_tools():
|
||||||
if tool.name == tool_input.tool_name:
|
if tool.name == tool_input.tool_name:
|
||||||
break
|
break
|
||||||
@@ -191,7 +199,10 @@ class AssistAPI(API):
|
|||||||
|
|
||||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
||||||
"""Return the prompt for the API."""
|
"""Return the prompt for the API."""
|
||||||
prompt = "Call the intent tools to control Home Assistant. Just pass the name to the intent."
|
prompt = (
|
||||||
|
"Call the intent tools to control Home Assistant. "
|
||||||
|
"Just pass the name to the intent."
|
||||||
|
)
|
||||||
if tool_input.device_id:
|
if tool_input.device_id:
|
||||||
device_reg = device_registry.async_get(self.hass)
|
device_reg = device_registry.async_get(self.hass)
|
||||||
device = device_reg.async_get(tool_input.device_id)
|
device = device_reg.async_get(tool_input.device_id)
|
||||||
|
@@ -1821,7 +1821,7 @@ pyegps==0.2.5
|
|||||||
pyenphase==1.20.3
|
pyenphase==1.20.3
|
||||||
|
|
||||||
# homeassistant.components.envisalink
|
# homeassistant.components.envisalink
|
||||||
pyenvisalink==4.6
|
pyenvisalink==4.7
|
||||||
|
|
||||||
# homeassistant.components.ephember
|
# homeassistant.components.ephember
|
||||||
pyephember==0.3.1
|
pyephember==0.3.1
|
||||||
@@ -2817,7 +2817,7 @@ vallox-websocket-api==5.1.1
|
|||||||
vehicle==2.2.1
|
vehicle==2.2.1
|
||||||
|
|
||||||
# homeassistant.components.velbus
|
# homeassistant.components.velbus
|
||||||
velbus-aio==2024.4.1
|
velbus-aio==2024.5.1
|
||||||
|
|
||||||
# homeassistant.components.venstar
|
# homeassistant.components.venstar
|
||||||
venstarcolortouch==0.19
|
venstarcolortouch==0.19
|
||||||
|
@@ -2185,7 +2185,7 @@ vallox-websocket-api==5.1.1
|
|||||||
vehicle==2.2.1
|
vehicle==2.2.1
|
||||||
|
|
||||||
# homeassistant.components.velbus
|
# homeassistant.components.velbus
|
||||||
velbus-aio==2024.4.1
|
velbus-aio==2024.5.1
|
||||||
|
|
||||||
# homeassistant.components.venstar
|
# homeassistant.components.venstar
|
||||||
venstarcolortouch==0.19
|
venstarcolortouch==0.19
|
||||||
|
@@ -117,7 +117,6 @@ NO_IOT_CLASS = [
|
|||||||
# https://github.com/home-assistant/developers.home-assistant/pull/1512
|
# https://github.com/home-assistant/developers.home-assistant/pull/1512
|
||||||
NO_DIAGNOSTICS = [
|
NO_DIAGNOSTICS = [
|
||||||
"dlna_dms",
|
"dlna_dms",
|
||||||
"fronius",
|
|
||||||
"gdacs",
|
"gdacs",
|
||||||
"geonetnz_quakes",
|
"geonetnz_quakes",
|
||||||
"google_assistant_sdk",
|
"google_assistant_sdk",
|
||||||
|
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
from homeassistant.core import Context, HomeAssistant, State
|
from homeassistant.core import Context, HomeAssistant, State
|
||||||
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
import homeassistant.util.dt as dt_util
|
||||||
|
|
||||||
@@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None:
|
|||||||
) as mock_process,
|
) as mock_process,
|
||||||
patch("homeassistant.util.dt.utcnow", return_value=now),
|
patch("homeassistant.util.dt.utcnow", return_value=now),
|
||||||
):
|
):
|
||||||
|
intent_response = intent.IntentResponse(language="en")
|
||||||
|
intent_response.async_set_speech("response text")
|
||||||
|
mock_process.return_value = conversation.ConversationResult(
|
||||||
|
response=intent_response,
|
||||||
|
)
|
||||||
await hass.services.async_call(
|
await hass.services.async_call(
|
||||||
"conversation",
|
"conversation",
|
||||||
"process",
|
"process",
|
||||||
|
80
tests/components/conversation/test_trace.py
Normal file
80
tests/components/conversation/test_trace.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""Test for the conversation traces."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_components(hass: HomeAssistant):
|
||||||
|
"""Initialize relevant components with empty configs."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_converation_trace(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
sl_setup: None,
|
||||||
|
) -> None:
|
||||||
|
"""Test tracing a conversation."""
|
||||||
|
await conversation.async_converse(
|
||||||
|
hass, "add apples to my shopping list", None, Context()
|
||||||
|
)
|
||||||
|
|
||||||
|
traces = trace.async_get_traces()
|
||||||
|
assert traces
|
||||||
|
last_trace = traces[-1].as_dict()
|
||||||
|
assert last_trace.get("events")
|
||||||
|
assert len(last_trace.get("events")) == 1
|
||||||
|
trace_event = last_trace["events"][0]
|
||||||
|
assert (
|
||||||
|
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||||
|
)
|
||||||
|
assert trace_event.get("data")
|
||||||
|
assert trace_event["data"].get("text") == "add apples to my shopping list"
|
||||||
|
assert last_trace.get("result")
|
||||||
|
assert (
|
||||||
|
last_trace["result"]
|
||||||
|
.get("response", {})
|
||||||
|
.get("speech", {})
|
||||||
|
.get("plain", {})
|
||||||
|
.get("speech")
|
||||||
|
== "Added apples"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_converation_trace_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components: None,
|
||||||
|
sl_setup: None,
|
||||||
|
) -> None:
|
||||||
|
"""Test tracing a conversation."""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
|
||||||
|
side_effect=HomeAssistantError("Failed to talk to agent"),
|
||||||
|
),
|
||||||
|
pytest.raises(HomeAssistantError),
|
||||||
|
):
|
||||||
|
await conversation.async_converse(
|
||||||
|
hass, "add apples to my shopping list", None, Context()
|
||||||
|
)
|
||||||
|
|
||||||
|
traces = trace.async_get_traces()
|
||||||
|
assert traces
|
||||||
|
last_trace = traces[-1].as_dict()
|
||||||
|
assert last_trace.get("events")
|
||||||
|
assert len(last_trace.get("events")) == 1
|
||||||
|
trace_event = last_trace["events"][0]
|
||||||
|
assert (
|
||||||
|
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||||
|
)
|
||||||
|
assert last_trace.get("error") == "Failed to talk to agent"
|
@@ -25,6 +25,7 @@ async def setup_fronius_integration(
|
|||||||
"""Create the Fronius integration."""
|
"""Create the Fronius integration."""
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
|
entry_id="f1e2b9837e8adaed6fa682acaa216fd8",
|
||||||
unique_id=unique_id, # has to match mocked logger unique_id
|
unique_id=unique_id, # has to match mocked logger unique_id
|
||||||
data={
|
data={
|
||||||
CONF_HOST: MOCK_HOST,
|
CONF_HOST: MOCK_HOST,
|
||||||
|
370
tests/components/fronius/snapshots/test_diagnostics.ambr
Normal file
370
tests/components/fronius/snapshots/test_diagnostics.ambr
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_diagnostics
|
||||||
|
dict({
|
||||||
|
'config_entry': dict({
|
||||||
|
'data': dict({
|
||||||
|
'host': 'http://fronius',
|
||||||
|
'is_logger': True,
|
||||||
|
}),
|
||||||
|
'disabled_by': None,
|
||||||
|
'domain': 'fronius',
|
||||||
|
'entry_id': 'f1e2b9837e8adaed6fa682acaa216fd8',
|
||||||
|
'minor_version': 1,
|
||||||
|
'options': dict({
|
||||||
|
}),
|
||||||
|
'pref_disable_new_entities': False,
|
||||||
|
'pref_disable_polling': False,
|
||||||
|
'source': 'user',
|
||||||
|
'title': 'Mock Title',
|
||||||
|
'unique_id': '**REDACTED**',
|
||||||
|
'version': 1,
|
||||||
|
}),
|
||||||
|
'coordinators': dict({
|
||||||
|
'inverters': dict({
|
||||||
|
'1': dict({
|
||||||
|
'current_ac': dict({
|
||||||
|
'unit': 'A',
|
||||||
|
'value': 5.19,
|
||||||
|
}),
|
||||||
|
'current_dc': dict({
|
||||||
|
'unit': 'A',
|
||||||
|
'value': 2.19,
|
||||||
|
}),
|
||||||
|
'energy_day': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 1113,
|
||||||
|
}),
|
||||||
|
'energy_total': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 44188000,
|
||||||
|
}),
|
||||||
|
'energy_year': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 25508798,
|
||||||
|
}),
|
||||||
|
'error_code': dict({
|
||||||
|
'value': 0,
|
||||||
|
}),
|
||||||
|
'frequency_ac': dict({
|
||||||
|
'unit': 'Hz',
|
||||||
|
'value': 49.94,
|
||||||
|
}),
|
||||||
|
'led_color': dict({
|
||||||
|
'value': 2,
|
||||||
|
}),
|
||||||
|
'led_state': dict({
|
||||||
|
'value': 0,
|
||||||
|
}),
|
||||||
|
'power_ac': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 1190,
|
||||||
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'Code': 0,
|
||||||
|
'Reason': '',
|
||||||
|
'UserMessage': '',
|
||||||
|
}),
|
||||||
|
'status_code': dict({
|
||||||
|
'value': 7,
|
||||||
|
}),
|
||||||
|
'timestamp': dict({
|
||||||
|
'value': '2021-10-07T10:01:17+02:00',
|
||||||
|
}),
|
||||||
|
'voltage_ac': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 227.9,
|
||||||
|
}),
|
||||||
|
'voltage_dc': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 518,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'logger': dict({
|
||||||
|
'system': dict({
|
||||||
|
'cash_factor': dict({
|
||||||
|
'unit': 'EUR/kWh',
|
||||||
|
'value': 0.07800000160932541,
|
||||||
|
}),
|
||||||
|
'co2_factor': dict({
|
||||||
|
'unit': 'kg/kWh',
|
||||||
|
'value': 0.5299999713897705,
|
||||||
|
}),
|
||||||
|
'delivery_factor': dict({
|
||||||
|
'unit': 'EUR/kWh',
|
||||||
|
'value': 0.15000000596046448,
|
||||||
|
}),
|
||||||
|
'hardware_platform': dict({
|
||||||
|
'value': 'wilma',
|
||||||
|
}),
|
||||||
|
'hardware_version': dict({
|
||||||
|
'value': '2.4E',
|
||||||
|
}),
|
||||||
|
'product_type': dict({
|
||||||
|
'value': 'fronius-datamanager-card',
|
||||||
|
}),
|
||||||
|
'software_version': dict({
|
||||||
|
'value': '3.18.7-1',
|
||||||
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'Code': 0,
|
||||||
|
'Reason': '',
|
||||||
|
'UserMessage': '',
|
||||||
|
}),
|
||||||
|
'time_zone': dict({
|
||||||
|
'value': 'CEST',
|
||||||
|
}),
|
||||||
|
'time_zone_location': dict({
|
||||||
|
'value': 'Vienna',
|
||||||
|
}),
|
||||||
|
'timestamp': dict({
|
||||||
|
'value': '2021-10-06T23:56:32+02:00',
|
||||||
|
}),
|
||||||
|
'unique_identifier': '**REDACTED**',
|
||||||
|
'utc_offset': dict({
|
||||||
|
'value': 7200,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'meter': dict({
|
||||||
|
'0': dict({
|
||||||
|
'current_ac_phase_1': dict({
|
||||||
|
'unit': 'A',
|
||||||
|
'value': 7.755,
|
||||||
|
}),
|
||||||
|
'current_ac_phase_2': dict({
|
||||||
|
'unit': 'A',
|
||||||
|
'value': 6.68,
|
||||||
|
}),
|
||||||
|
'current_ac_phase_3': dict({
|
||||||
|
'unit': 'A',
|
||||||
|
'value': 10.102,
|
||||||
|
}),
|
||||||
|
'enable': dict({
|
||||||
|
'value': 1,
|
||||||
|
}),
|
||||||
|
'energy_reactive_ac_consumed': dict({
|
||||||
|
'unit': 'VArh',
|
||||||
|
'value': 59960790,
|
||||||
|
}),
|
||||||
|
'energy_reactive_ac_produced': dict({
|
||||||
|
'unit': 'VArh',
|
||||||
|
'value': 723160,
|
||||||
|
}),
|
||||||
|
'energy_real_ac_minus': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 35623065,
|
||||||
|
}),
|
||||||
|
'energy_real_ac_plus': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 15303334,
|
||||||
|
}),
|
||||||
|
'energy_real_consumed': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 15303334,
|
||||||
|
}),
|
||||||
|
'energy_real_produced': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 35623065,
|
||||||
|
}),
|
||||||
|
'frequency_phase_average': dict({
|
||||||
|
'unit': 'Hz',
|
||||||
|
'value': 50,
|
||||||
|
}),
|
||||||
|
'manufacturer': dict({
|
||||||
|
'value': 'Fronius',
|
||||||
|
}),
|
||||||
|
'meter_location': dict({
|
||||||
|
'value': 0,
|
||||||
|
}),
|
||||||
|
'model': dict({
|
||||||
|
'value': 'Smart Meter 63A',
|
||||||
|
}),
|
||||||
|
'power_apparent': dict({
|
||||||
|
'unit': 'VA',
|
||||||
|
'value': 5592.57,
|
||||||
|
}),
|
||||||
|
'power_apparent_phase_1': dict({
|
||||||
|
'unit': 'VA',
|
||||||
|
'value': 1772.793,
|
||||||
|
}),
|
||||||
|
'power_apparent_phase_2': dict({
|
||||||
|
'unit': 'VA',
|
||||||
|
'value': 1527.048,
|
||||||
|
}),
|
||||||
|
'power_apparent_phase_3': dict({
|
||||||
|
'unit': 'VA',
|
||||||
|
'value': 2333.562,
|
||||||
|
}),
|
||||||
|
'power_factor': dict({
|
||||||
|
'value': 1,
|
||||||
|
}),
|
||||||
|
'power_factor_phase_1': dict({
|
||||||
|
'value': -0.99,
|
||||||
|
}),
|
||||||
|
'power_factor_phase_2': dict({
|
||||||
|
'value': -0.99,
|
||||||
|
}),
|
||||||
|
'power_factor_phase_3': dict({
|
||||||
|
'value': 0.99,
|
||||||
|
}),
|
||||||
|
'power_reactive': dict({
|
||||||
|
'unit': 'VAr',
|
||||||
|
'value': 2.87,
|
||||||
|
}),
|
||||||
|
'power_reactive_phase_1': dict({
|
||||||
|
'unit': 'VAr',
|
||||||
|
'value': 51.48,
|
||||||
|
}),
|
||||||
|
'power_reactive_phase_2': dict({
|
||||||
|
'unit': 'VAr',
|
||||||
|
'value': 115.63,
|
||||||
|
}),
|
||||||
|
'power_reactive_phase_3': dict({
|
||||||
|
'unit': 'VAr',
|
||||||
|
'value': -164.24,
|
||||||
|
}),
|
||||||
|
'power_real': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 5592.57,
|
||||||
|
}),
|
||||||
|
'power_real_phase_1': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 1765.55,
|
||||||
|
}),
|
||||||
|
'power_real_phase_2': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 1515.8,
|
||||||
|
}),
|
||||||
|
'power_real_phase_3': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 2311.22,
|
||||||
|
}),
|
||||||
|
'serial': '**REDACTED**',
|
||||||
|
'visible': dict({
|
||||||
|
'value': 1,
|
||||||
|
}),
|
||||||
|
'voltage_ac_phase_1': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 228.6,
|
||||||
|
}),
|
||||||
|
'voltage_ac_phase_2': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 228.6,
|
||||||
|
}),
|
||||||
|
'voltage_ac_phase_3': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 231,
|
||||||
|
}),
|
||||||
|
'voltage_ac_phase_to_phase_12': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 395.9,
|
||||||
|
}),
|
||||||
|
'voltage_ac_phase_to_phase_23': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 398,
|
||||||
|
}),
|
||||||
|
'voltage_ac_phase_to_phase_31': dict({
|
||||||
|
'unit': 'V',
|
||||||
|
'value': 398,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'ohmpilot': None,
|
||||||
|
'power_flow': dict({
|
||||||
|
'power_flow': dict({
|
||||||
|
'energy_day': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 1101.7000732421875,
|
||||||
|
}),
|
||||||
|
'energy_total': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 44188000,
|
||||||
|
}),
|
||||||
|
'energy_year': dict({
|
||||||
|
'unit': 'Wh',
|
||||||
|
'value': 25508788,
|
||||||
|
}),
|
||||||
|
'meter_location': dict({
|
||||||
|
'value': 'grid',
|
||||||
|
}),
|
||||||
|
'meter_mode': dict({
|
||||||
|
'value': 'meter',
|
||||||
|
}),
|
||||||
|
'power_battery': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': None,
|
||||||
|
}),
|
||||||
|
'power_grid': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 1703.74,
|
||||||
|
}),
|
||||||
|
'power_load': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': -2814.74,
|
||||||
|
}),
|
||||||
|
'power_photovoltaics': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 1111,
|
||||||
|
}),
|
||||||
|
'relative_autonomy': dict({
|
||||||
|
'unit': '%',
|
||||||
|
'value': 39.4707859340472,
|
||||||
|
}),
|
||||||
|
'relative_self_consumption': dict({
|
||||||
|
'unit': '%',
|
||||||
|
'value': 100,
|
||||||
|
}),
|
||||||
|
'status': dict({
|
||||||
|
'Code': 0,
|
||||||
|
'Reason': '',
|
||||||
|
'UserMessage': '',
|
||||||
|
}),
|
||||||
|
'timestamp': dict({
|
||||||
|
'value': '2021-10-07T10:00:43+02:00',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'storage': None,
|
||||||
|
}),
|
||||||
|
'inverter_info': dict({
|
||||||
|
'inverters': list([
|
||||||
|
dict({
|
||||||
|
'custom_name': dict({
|
||||||
|
'value': 'Symo 20',
|
||||||
|
}),
|
||||||
|
'device_id': dict({
|
||||||
|
'value': '1',
|
||||||
|
}),
|
||||||
|
'device_type': dict({
|
||||||
|
'manufacturer': 'Fronius',
|
||||||
|
'model': 'Symo 20.0-3-M',
|
||||||
|
'value': 121,
|
||||||
|
}),
|
||||||
|
'error_code': dict({
|
||||||
|
'value': 0,
|
||||||
|
}),
|
||||||
|
'pv_power': dict({
|
||||||
|
'unit': 'W',
|
||||||
|
'value': 23100,
|
||||||
|
}),
|
||||||
|
'show': dict({
|
||||||
|
'value': 1,
|
||||||
|
}),
|
||||||
|
'status_code': dict({
|
||||||
|
'value': 7,
|
||||||
|
}),
|
||||||
|
'unique_id': '**REDACTED**',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'status': dict({
|
||||||
|
'Code': 0,
|
||||||
|
'Reason': '',
|
||||||
|
'UserMessage': '',
|
||||||
|
}),
|
||||||
|
'timestamp': dict({
|
||||||
|
'value': '2021-10-07T13:41:00+02:00',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
31
tests/components/fronius/test_diagnostics.py
Normal file
31
tests/components/fronius/test_diagnostics.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Tests for the diagnostics data provided by the KNX integration."""
|
||||||
|
|
||||||
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from . import mock_responses, setup_fronius_integration
|
||||||
|
|
||||||
|
from tests.components.diagnostics import get_diagnostics_for_config_entry
|
||||||
|
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||||
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
|
|
||||||
|
async def test_diagnostics(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test diagnostics."""
|
||||||
|
mock_responses(aioclient_mock)
|
||||||
|
entry = await setup_fronius_integration(hass)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await get_diagnostics_for_config_entry(
|
||||||
|
hass,
|
||||||
|
hass_client,
|
||||||
|
entry,
|
||||||
|
)
|
||||||
|
== snapshot
|
||||||
|
)
|
@@ -1,4 +1,114 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
|
# name: test_chat_history
|
||||||
|
list([
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'tools': None,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
dict({
|
||||||
|
'parts': '''
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': 'Ok',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'1st user request',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'generation_config': dict({
|
||||||
|
'max_output_tokens': 150,
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_k': 64,
|
||||||
|
'top_p': 0.95,
|
||||||
|
}),
|
||||||
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
|
'safety_settings': dict({
|
||||||
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
|
}),
|
||||||
|
'tools': None,
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat',
|
||||||
|
tuple(
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
'history': list([
|
||||||
|
dict({
|
||||||
|
'parts': '''
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': 'Ok',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': '1st user request',
|
||||||
|
'role': 'user',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'parts': '1st model response',
|
||||||
|
'role': 'model',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
tuple(
|
||||||
|
'().start_chat().send_message_async',
|
||||||
|
tuple(
|
||||||
|
'2nd user request',
|
||||||
|
),
|
||||||
|
dict({
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_default_prompt[config_entry_options0-None]
|
# name: test_default_prompt[config_entry_options0-None]
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
@@ -14,10 +124,10 @@
|
|||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-1.5-flash-latest',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'safety_settings': dict({
|
'safety_settings': dict({
|
||||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
@@ -29,7 +139,10 @@
|
|||||||
dict({
|
dict({
|
||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
dict({
|
||||||
'parts': 'Answer in plain text. Keep it simple and to the point.',
|
'parts': '''
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
@@ -64,10 +177,10 @@
|
|||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-1.5-flash-latest',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'safety_settings': dict({
|
'safety_settings': dict({
|
||||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
@@ -79,7 +192,10 @@
|
|||||||
dict({
|
dict({
|
||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
dict({
|
||||||
'parts': 'Answer in plain text. Keep it simple and to the point.',
|
'parts': '''
|
||||||
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||||
|
''',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
@@ -114,10 +230,10 @@
|
|||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-1.5-flash-latest',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'safety_settings': dict({
|
'safety_settings': dict({
|
||||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
@@ -130,8 +246,8 @@
|
|||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
dict({
|
||||||
'parts': '''
|
'parts': '''
|
||||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||||
''',
|
''',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
@@ -167,10 +283,10 @@
|
|||||||
}),
|
}),
|
||||||
'model_name': 'models/gemini-1.5-flash-latest',
|
'model_name': 'models/gemini-1.5-flash-latest',
|
||||||
'safety_settings': dict({
|
'safety_settings': dict({
|
||||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||||
}),
|
}),
|
||||||
'tools': None,
|
'tools': None,
|
||||||
}),
|
}),
|
||||||
@@ -183,8 +299,8 @@
|
|||||||
'history': list([
|
'history': list([
|
||||||
dict({
|
dict({
|
||||||
'parts': '''
|
'parts': '''
|
||||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
|
||||||
Answer in plain text. Keep it simple and to the point.
|
Answer in plain text. Keep it simple and to the point.
|
||||||
|
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||||
''',
|
''',
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
}),
|
}),
|
||||||
|
@@ -2,12 +2,14 @@
|
|||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from google.api_core.exceptions import ClientError
|
from google.api_core.exceptions import GoogleAPICallError
|
||||||
|
import google.generativeai.types as genai_types
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -150,6 +152,57 @@ async def test_default_prompt(
|
|||||||
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chat_history(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the agent keeps track of the chat history."""
|
||||||
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
|
chat_response = MagicMock()
|
||||||
|
mock_chat.send_message_async.return_value = chat_response
|
||||||
|
mock_part = MagicMock()
|
||||||
|
mock_part.function_call = None
|
||||||
|
chat_response.parts = [mock_part]
|
||||||
|
chat_response.text = "1st model response"
|
||||||
|
mock_chat.history = [
|
||||||
|
{"role": "user", "parts": "prompt"},
|
||||||
|
{"role": "model", "parts": "Ok"},
|
||||||
|
{"role": "user", "parts": "1st user request"},
|
||||||
|
{"role": "model", "parts": "1st model response"},
|
||||||
|
]
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"1st user request",
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
|
== "1st model response"
|
||||||
|
)
|
||||||
|
chat_response.text = "2nd model response"
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"2nd user request",
|
||||||
|
result.conversation_id,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||||
|
assert (
|
||||||
|
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||||
|
== "2nd model response"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||||
)
|
)
|
||||||
@@ -233,6 +286,20 @@ async def test_function_call(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test conversating tracing
|
||||||
|
traces = trace.async_get_traces()
|
||||||
|
assert traces
|
||||||
|
last_trace = traces[-1].as_dict()
|
||||||
|
trace_events = last_trace.get("events", [])
|
||||||
|
assert [event["event_type"] for event in trace_events] == [
|
||||||
|
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
||||||
|
]
|
||||||
|
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||||
|
detail_event = trace_events[1]
|
||||||
|
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||||
@@ -325,7 +392,7 @@ async def test_error_handling(
|
|||||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
mock_chat = AsyncMock()
|
mock_chat = AsyncMock()
|
||||||
mock_model.return_value.start_chat.return_value = mock_chat
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
mock_chat.send_message_async.side_effect = ClientError("some error")
|
mock_chat.send_message_async.side_effect = GoogleAPICallError("some error")
|
||||||
result = await conversation.async_converse(
|
result = await conversation.async_converse(
|
||||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
)
|
)
|
||||||
@@ -340,7 +407,28 @@ async def test_error_handling(
|
|||||||
async def test_blocked_response(
|
async def test_blocked_response(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test response was blocked."""
|
"""Test blocked response."""
|
||||||
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
|
mock_chat = AsyncMock()
|
||||||
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
|
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
|
||||||
|
"finish_reason: SAFETY\n"
|
||||||
|
)
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
assert result.response.error_code == "unknown", result
|
||||||
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
|
"The message got blocked by your safety settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_empty_response(
|
||||||
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test empty response."""
|
||||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||||
mock_chat = AsyncMock()
|
mock_chat = AsyncMock()
|
||||||
mock_model.return_value.start_chat.return_value = mock_chat
|
mock_model.return_value.start_chat.return_value = mock_chat
|
||||||
@@ -358,6 +446,32 @@ async def test_blocked_response(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_llm_api(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_init_component,
|
||||||
|
) -> None:
|
||||||
|
"""Test handling of invalid llm api."""
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry,
|
||||||
|
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
"hello",
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
agent_id=mock_config_entry.entry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
|
assert result.response.error_code == "unknown", result
|
||||||
|
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||||
|
"Error preparing LLM API: API invalid_llm_api not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_template_error(
|
async def test_template_error(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@@ -529,16 +529,16 @@ async def test_non_unique_triggers(
|
|||||||
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
|
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
assert calls[0].data["some"] == "press1"
|
all_calls = {calls[0].data["some"], calls[1].data["some"]}
|
||||||
assert calls[1].data["some"] == "press2"
|
assert all_calls == {"press1", "press2"}
|
||||||
|
|
||||||
# Trigger second config references to same trigger
|
# Trigger second config references to same trigger
|
||||||
# and triggers both attached instances.
|
# and triggers both attached instances.
|
||||||
async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press")
|
async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press")
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
assert calls[0].data["some"] == "press1"
|
all_calls = {calls[0].data["some"], calls[1].data["some"]}
|
||||||
assert calls[1].data["some"] == "press2"
|
assert all_calls == {"press1", "press2"}
|
||||||
|
|
||||||
# Removing the first trigger will clean up
|
# Removing the first trigger will clean up
|
||||||
calls.clear()
|
calls.clear()
|
||||||
|
@@ -4,6 +4,7 @@ import asyncio
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
@@ -1050,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
|
|||||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_mqtt_config_entry_disabled(
|
||||||
|
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
|
||||||
|
) -> None:
|
||||||
|
"""Test the subscription of a topic when MQTT config entry is disabled."""
|
||||||
|
mqtt_mock.connected = True
|
||||||
|
|
||||||
|
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||||
|
assert mqtt_config_entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
|
||||||
|
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
|
||||||
|
|
||||||
|
await hass.config_entries.async_set_disabled_by(
|
||||||
|
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
|
||||||
|
)
|
||||||
|
mqtt_mock.connected = False
|
||||||
|
|
||||||
|
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
|
||||||
|
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||||
|
|
||||||
|
|
||||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||||
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
|
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
|
||||||
async def test_subscribe_and_resubscribe(
|
async def test_subscribe_and_resubscribe(
|
||||||
@@ -2912,8 +2934,8 @@ async def test_message_callback_exception_gets_logged(
|
|||||||
await mqtt_mock_entry()
|
await mqtt_mock_entry()
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def bad_handler(*args) -> None:
|
def bad_handler(msg: ReceiveMessage) -> None:
|
||||||
"""Record calls."""
|
"""Handle callback."""
|
||||||
raise ValueError("This is a bad message callback")
|
raise ValueError("This is a bad message callback")
|
||||||
|
|
||||||
await mqtt.async_subscribe(hass, "test-topic", bad_handler)
|
await mqtt.async_subscribe(hass, "test-topic", bad_handler)
|
||||||
@@ -2926,6 +2948,40 @@ async def test_message_callback_exception_gets_logged(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_fail_on_log_exception
|
||||||
|
async def test_message_partial_callback_exception_gets_logged(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test exception raised by message handler."""
|
||||||
|
await mqtt_mock_entry()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def bad_handler(msg: ReceiveMessage) -> None:
|
||||||
|
"""Handle callback."""
|
||||||
|
raise ValueError("This is a bad message callback")
|
||||||
|
|
||||||
|
def parial_handler(
|
||||||
|
msg_callback: MessageCallbackType,
|
||||||
|
attributes: set[str],
|
||||||
|
msg: ReceiveMessage,
|
||||||
|
) -> None:
|
||||||
|
"""Partial callback handler."""
|
||||||
|
msg_callback(msg)
|
||||||
|
|
||||||
|
await mqtt.async_subscribe(
|
||||||
|
hass, "test-topic", partial(parial_handler, bad_handler, {"some_attr"})
|
||||||
|
)
|
||||||
|
async_fire_mqtt_message(hass, "test-topic", "test")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"Exception in bad_handler when handling msg on 'test-topic':"
|
||||||
|
" 'test'" in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_mqtt_ws_subscription(
|
async def test_mqtt_ws_subscription(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
@@ -3787,7 +3843,7 @@ async def test_unload_config_entry(
|
|||||||
async def test_publish_or_subscribe_without_valid_config_entry(
|
async def test_publish_or_subscribe_without_valid_config_entry(
|
||||||
hass: HomeAssistant, record_calls: MessageCallbackType
|
hass: HomeAssistant, record_calls: MessageCallbackType
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test internal publish function with bas use cases."""
|
"""Test internal publish function with bad use cases."""
|
||||||
with pytest.raises(HomeAssistantError):
|
with pytest.raises(HomeAssistantError):
|
||||||
await mqtt.async_publish(
|
await mqtt.async_publish(
|
||||||
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None
|
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None
|
||||||
|
@@ -11,8 +11,12 @@ from .const import DEFAULT_FORECAST, DEFAULT_OBSERVATION
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_simple_nws():
|
def mock_simple_nws():
|
||||||
"""Mock pynws SimpleNWS with default values."""
|
"""Mock pynws SimpleNWS with default values."""
|
||||||
|
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
|
||||||
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws:
|
with (
|
||||||
|
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
|
||||||
|
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
|
||||||
|
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
|
||||||
|
):
|
||||||
instance = mock_nws.return_value
|
instance = mock_nws.return_value
|
||||||
instance.set_station = AsyncMock(return_value=None)
|
instance.set_station = AsyncMock(return_value=None)
|
||||||
instance.update_observation = AsyncMock(return_value=None)
|
instance.update_observation = AsyncMock(return_value=None)
|
||||||
@@ -29,7 +33,12 @@ def mock_simple_nws():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_simple_nws_times_out():
|
def mock_simple_nws_times_out():
|
||||||
"""Mock pynws SimpleNWS that times out."""
|
"""Mock pynws SimpleNWS that times out."""
|
||||||
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws:
|
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
|
||||||
|
with (
|
||||||
|
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
|
||||||
|
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
|
||||||
|
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
|
||||||
|
):
|
||||||
instance = mock_nws.return_value
|
instance = mock_nws.return_value
|
||||||
instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError)
|
instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||||
instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError)
|
instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
"""Tests for the NWS weather component."""
|
"""Tests for the NWS weather component."""
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from freezegun.api import FrozenDateTimeFactory
|
from freezegun.api import FrozenDateTimeFactory
|
||||||
@@ -24,7 +23,6 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.util.dt as dt_util
|
|
||||||
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
|
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
@@ -127,10 +125,6 @@ async def test_data_caching_error_observation(
|
|||||||
caplog,
|
caplog,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test caching of data with errors."""
|
"""Test caching of data with errors."""
|
||||||
with (
|
|
||||||
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
|
|
||||||
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
|
|
||||||
):
|
|
||||||
instance = mock_simple_nws.return_value
|
instance = mock_simple_nws.return_value
|
||||||
|
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
@@ -302,9 +296,6 @@ async def test_error_observation(
|
|||||||
hass: HomeAssistant, mock_simple_nws, no_sensor
|
hass: HomeAssistant, mock_simple_nws, no_sensor
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test error during update observation."""
|
"""Test error during update observation."""
|
||||||
utc_time = dt_util.utcnow()
|
|
||||||
with patch("homeassistant.components.nws.coordinator.utcnow") as mock_utc:
|
|
||||||
mock_utc.return_value = utc_time
|
|
||||||
instance = mock_simple_nws.return_value
|
instance = mock_simple_nws.return_value
|
||||||
# first update fails
|
# first update fails
|
||||||
instance.update_observation.side_effect = aiohttp.ClientError
|
instance.update_observation.side_effect = aiohttp.ClientError
|
||||||
|
@@ -6,6 +6,7 @@ from ollama import Message, ResponseError
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import conversation, ollama
|
from homeassistant.components import conversation, ollama
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
@@ -110,6 +111,19 @@ async def test_chat(
|
|||||||
), result
|
), result
|
||||||
assert result.response.speech["plain"]["speech"] == "test response"
|
assert result.response.speech["plain"]["speech"] == "test response"
|
||||||
|
|
||||||
|
# Test Conversation tracing
|
||||||
|
traces = trace.async_get_traces()
|
||||||
|
assert traces
|
||||||
|
last_trace = traces[-1].as_dict()
|
||||||
|
trace_events = last_trace.get("events", [])
|
||||||
|
assert [event["event_type"] for event in trace_events] == [
|
||||||
|
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
]
|
||||||
|
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||||
|
detail_event = trace_events[1]
|
||||||
|
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
||||||
|
|
||||||
|
|
||||||
async def test_message_history_trimming(
|
async def test_message_history_trimming(
|
||||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||||
|
@@ -9,9 +9,17 @@ import pytest
|
|||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.openai_conversation.const import (
|
from homeassistant.components.openai_conversation.const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
DEFAULT_CHAT_MODEL,
|
CONF_MAX_TOKENS,
|
||||||
|
CONF_PROMPT,
|
||||||
|
CONF_RECOMMENDED,
|
||||||
|
CONF_TEMPERATURE,
|
||||||
|
CONF_TOP_P,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
@@ -75,7 +83,7 @@ async def test_options(
|
|||||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert options["data"]["prompt"] == "Speak like a pirate"
|
assert options["data"]["prompt"] == "Speak like a pirate"
|
||||||
assert options["data"]["max_tokens"] == 200
|
assert options["data"]["max_tokens"] == 200
|
||||||
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
|
assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
|||||||
|
|
||||||
assert result2["type"] is FlowResultType.FORM
|
assert result2["type"] is FlowResultType.FORM
|
||||||
assert result2["errors"] == {"base": error}
|
assert result2["errors"] == {"base": error}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("current_options", "new_options", "expected_options"),
|
||||||
|
[
|
||||||
|
(
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: "none",
|
||||||
|
CONF_PROMPT: "bla",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 0.3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 0.3,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: False,
|
||||||
|
CONF_PROMPT: "Speak like a pirate",
|
||||||
|
CONF_TEMPERATURE: 0.3,
|
||||||
|
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||||
|
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||||
|
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: "assist",
|
||||||
|
CONF_PROMPT: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: "assist",
|
||||||
|
CONF_PROMPT: "",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_options_switching(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry,
|
||||||
|
mock_init_component,
|
||||||
|
current_options,
|
||||||
|
new_options,
|
||||||
|
expected_options,
|
||||||
|
) -> None:
|
||||||
|
"""Test the options form."""
|
||||||
|
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
|
||||||
|
options_flow = await hass.config_entries.options.async_init(
|
||||||
|
mock_config_entry.entry_id
|
||||||
|
)
|
||||||
|
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
|
||||||
|
options_flow = await hass.config_entries.options.async_configure(
|
||||||
|
options_flow["flow_id"],
|
||||||
|
{
|
||||||
|
**current_options,
|
||||||
|
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
options = await hass.config_entries.options.async_configure(
|
||||||
|
options_flow["flow_id"],
|
||||||
|
new_options,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert options["data"] == expected_options
|
||||||
|
@@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.components.conversation import trace
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -200,6 +201,20 @@ async def test_function_call(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test Conversation tracing
|
||||||
|
traces = trace.async_get_traces()
|
||||||
|
assert traces
|
||||||
|
last_trace = traces[-1].as_dict()
|
||||||
|
trace_events = last_trace.get("events", [])
|
||||||
|
assert [event["event_type"] for event in trace_events] == [
|
||||||
|
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||||
|
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||||
|
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
||||||
|
]
|
||||||
|
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||||
|
detail_event = trace_events[1]
|
||||||
|
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
from pyplaato.models.airlock import PlaatoAirlock
|
from pyplaato.models.airlock import PlaatoAirlock
|
||||||
from pyplaato.models.device import PlaatoDeviceType
|
from pyplaato.models.device import PlaatoDeviceType
|
||||||
from pyplaato.models.keg import PlaatoKeg
|
from pyplaato.models.keg import PlaatoKeg
|
||||||
@@ -23,6 +24,7 @@ AIRLOCK_DATA = {}
|
|||||||
KEG_DATA = {}
|
KEG_DATA = {}
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2024-05-24 12:00:00", tz_offset=0)
|
||||||
async def init_integration(
|
async def init_integration(
|
||||||
hass: HomeAssistant, device_type: PlaatoDeviceType
|
hass: HomeAssistant, device_type: PlaatoDeviceType
|
||||||
) -> MockConfigEntry:
|
) -> MockConfigEntry:
|
||||||
|
@@ -492,7 +492,6 @@ async def test_block_set_mode_auth_error(
|
|||||||
{ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT},
|
{ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
@@ -227,7 +227,6 @@ async def test_block_set_value_auth_error(
|
|||||||
{ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30},
|
{ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
@@ -618,7 +618,6 @@ async def test_rpc_sleeping_update_entity_service(
|
|||||||
service_data={ATTR_ENTITY_ID: entity_id},
|
service_data={ATTR_ENTITY_ID: entity_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
# Entity should be available after update_entity service call
|
# Entity should be available after update_entity service call
|
||||||
state = hass.states.get(entity_id)
|
state = hass.states.get(entity_id)
|
||||||
@@ -667,7 +666,6 @@ async def test_block_sleeping_update_entity_service(
|
|||||||
service_data={ATTR_ENTITY_ID: entity_id},
|
service_data={ATTR_ENTITY_ID: entity_id},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
# Entity should be available after update_entity service call
|
# Entity should be available after update_entity service call
|
||||||
state = hass.states.get(entity_id)
|
state = hass.states.get(entity_id)
|
||||||
|
@@ -230,7 +230,6 @@ async def test_block_set_state_auth_error(
|
|||||||
{ATTR_ENTITY_ID: "switch.test_name_channel_1"},
|
{ATTR_ENTITY_ID: "switch.test_name_channel_1"},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
@@ -374,7 +373,6 @@ async def test_rpc_auth_error(
|
|||||||
{ATTR_ENTITY_ID: "switch.test_switch_0"},
|
{ATTR_ENTITY_ID: "switch.test_switch_0"},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
@@ -207,7 +207,6 @@ async def test_block_update_auth_error(
|
|||||||
{ATTR_ENTITY_ID: "update.test_name_firmware_update"},
|
{ATTR_ENTITY_ID: "update.test_name_firmware_update"},
|
||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
@@ -669,7 +668,6 @@ async def test_rpc_update_auth_error(
|
|||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
flows = hass.config_entries.flow.async_progress()
|
flows = hass.config_entries.flow.async_progress()
|
||||||
|
@@ -18,9 +18,15 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_bridge(request):
|
def mock_bridge(request):
|
||||||
"""Return a mocked SwitcherBridge."""
|
"""Return a mocked SwitcherBridge."""
|
||||||
with patch(
|
with (
|
||||||
"homeassistant.components.switcher_kis.utils.SwitcherBridge", autospec=True
|
patch(
|
||||||
) as bridge_mock:
|
"homeassistant.components.switcher_kis.SwitcherBridge", autospec=True
|
||||||
|
) as bridge_mock,
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.switcher_kis.utils.SwitcherBridge",
|
||||||
|
new=bridge_mock,
|
||||||
|
),
|
||||||
|
):
|
||||||
bridge = bridge_mock.return_value
|
bridge = bridge_mock.return_value
|
||||||
|
|
||||||
bridge.devices = []
|
bridge.devices = []
|
||||||
|
@@ -4,11 +4,7 @@ from datetime import timedelta
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.switcher_kis.const import (
|
from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC
|
||||||
DATA_DEVICE,
|
|
||||||
DOMAIN,
|
|
||||||
MAX_UPDATE_INTERVAL_SEC,
|
|
||||||
)
|
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.const import STATE_UNAVAILABLE
|
from homeassistant.const import STATE_UNAVAILABLE
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@@ -24,15 +20,14 @@ async def test_update_fail(
|
|||||||
hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture
|
hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test entities state unavailable when updates fail.."""
|
"""Test entities state unavailable when updates fail.."""
|
||||||
await init_integration(hass)
|
entry = await init_integration(hass)
|
||||||
assert mock_bridge
|
assert mock_bridge
|
||||||
|
|
||||||
mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES)
|
mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert mock_bridge.is_running is True
|
assert mock_bridge.is_running is True
|
||||||
assert len(hass.data[DOMAIN]) == 2
|
assert len(entry.runtime_data) == 2
|
||||||
assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2
|
|
||||||
|
|
||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1)
|
hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1)
|
||||||
@@ -77,11 +72,9 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None:
|
|||||||
|
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
assert mock_bridge.is_running is True
|
assert mock_bridge.is_running is True
|
||||||
assert len(hass.data[DOMAIN]) == 2
|
|
||||||
|
|
||||||
await hass.config_entries.async_unload(entry.entry_id)
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert entry.state is ConfigEntryState.NOT_LOADED
|
assert entry.state is ConfigEntryState.NOT_LOADED
|
||||||
assert mock_bridge.is_running is False
|
assert mock_bridge.is_running is False
|
||||||
assert len(hass.data[DOMAIN]) == 0
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user