mirror of
https://github.com/home-assistant/core.git
synced 2025-08-06 14:15:12 +02:00
Merge branch 'dev' into jbouwh-mqtt-device-discovery
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -20,6 +21,11 @@ from .models import (
|
||||
ConversationInput,
|
||||
ConversationResult,
|
||||
)
|
||||
from .trace import (
|
||||
ConversationTraceEvent,
|
||||
ConversationTraceEventType,
|
||||
async_conversation_trace,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,15 +90,23 @@ async def async_converse(
|
||||
language = hass.config.language
|
||||
|
||||
_LOGGER.debug("Processing in %s: %s", language, text)
|
||||
return await method(
|
||||
ConversationInput(
|
||||
text=text,
|
||||
context=context,
|
||||
conversation_id=conversation_id,
|
||||
device_id=device_id,
|
||||
language=language,
|
||||
)
|
||||
conversation_input = ConversationInput(
|
||||
text=text,
|
||||
context=context,
|
||||
conversation_id=conversation_id,
|
||||
device_id=device_id,
|
||||
language=language,
|
||||
)
|
||||
with async_conversation_trace() as trace:
|
||||
trace.add_event(
|
||||
ConversationTraceEvent(
|
||||
ConversationTraceEventType.ASYNC_PROCESS,
|
||||
dataclasses.asdict(conversation_input),
|
||||
)
|
||||
)
|
||||
result = await method(conversation_input)
|
||||
trace.set_result(**result.as_dict())
|
||||
return result
|
||||
|
||||
|
||||
class AgentManager:
|
||||
|
118
homeassistant/components/conversation/trace.py
Normal file
118
homeassistant/components/conversation/trace.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Debug traces for conversation."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import asdict, dataclass, field
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
STORED_TRACES = 3
|
||||
|
||||
|
||||
class ConversationTraceEventType(enum.StrEnum):
|
||||
"""Type of an event emitted during a conversation."""
|
||||
|
||||
ASYNC_PROCESS = "async_process"
|
||||
"""The conversation is started from user input."""
|
||||
|
||||
AGENT_DETAIL = "agent_detail"
|
||||
"""Event detail added by a conversation agent."""
|
||||
|
||||
LLM_TOOL_CALL = "llm_tool_call"
|
||||
"""An LLM Tool call"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversationTraceEvent:
|
||||
"""Event emitted during a conversation."""
|
||||
|
||||
event_type: ConversationTraceEventType
|
||||
data: dict[str, Any] | None = None
|
||||
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
|
||||
|
||||
|
||||
class ConversationTrace:
|
||||
"""Stores debug data related to a conversation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize ConversationTrace."""
|
||||
self._trace_id = ulid_util.ulid_now()
|
||||
self._events: list[ConversationTraceEvent] = []
|
||||
self._error: Exception | None = None
|
||||
self._result: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def trace_id(self) -> str:
|
||||
"""Identifier for this trace."""
|
||||
return self._trace_id
|
||||
|
||||
def add_event(self, event: ConversationTraceEvent) -> None:
|
||||
"""Add an event to the trace."""
|
||||
self._events.append(event)
|
||||
|
||||
def set_error(self, ex: Exception) -> None:
|
||||
"""Set error."""
|
||||
self._error = ex
|
||||
|
||||
def set_result(self, **kwargs: Any) -> None:
|
||||
"""Set result."""
|
||||
self._result = {**kwargs}
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return dictionary version of this ConversationTrace."""
|
||||
result: dict[str, Any] = {
|
||||
"id": self._trace_id,
|
||||
"events": [asdict(event) for event in self._events],
|
||||
}
|
||||
if self._error is not None:
|
||||
result["error"] = str(self._error) or self._error.__class__.__name__
|
||||
if self._result is not None:
|
||||
result["result"] = self._result
|
||||
return result
|
||||
|
||||
|
||||
_current_trace: ContextVar[ConversationTrace | None] = ContextVar(
|
||||
"current_trace", default=None
|
||||
)
|
||||
_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict(
|
||||
size_limit=STORED_TRACES
|
||||
)
|
||||
|
||||
|
||||
def async_conversation_trace_append(
|
||||
event_type: ConversationTraceEventType, event_data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Append a ConversationTraceEvent to the current active trace."""
|
||||
trace = _current_trace.get()
|
||||
if not trace:
|
||||
return
|
||||
trace.add_event(ConversationTraceEvent(event_type, event_data))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def async_conversation_trace() -> Generator[ConversationTrace, None]:
|
||||
"""Create a new active ConversationTrace."""
|
||||
trace = ConversationTrace()
|
||||
token = _current_trace.set(trace)
|
||||
_recent_traces[trace.trace_id] = trace
|
||||
try:
|
||||
yield trace
|
||||
except Exception as ex:
|
||||
trace.set_error(ex)
|
||||
raise
|
||||
finally:
|
||||
_current_trace.reset(token)
|
||||
|
||||
|
||||
def async_get_traces() -> list[ConversationTrace]:
|
||||
"""Get the most recent traces."""
|
||||
return list(_recent_traces.values())
|
||||
|
||||
|
||||
def async_clear_traces() -> None:
|
||||
"""Clear all traces."""
|
||||
_recent_traces.clear()
|
@@ -5,5 +5,5 @@
|
||||
"documentation": "https://www.home-assistant.io/integrations/envisalink",
|
||||
"iot_class": "local_push",
|
||||
"loggers": ["pyenvisalink"],
|
||||
"requirements": ["pyenvisalink==4.6"]
|
||||
"requirements": ["pyenvisalink==4.7"]
|
||||
}
|
||||
|
@@ -1,6 +1,5 @@
|
||||
"""Support for Arduino-compatible Microcontrollers through Firmata."""
|
||||
|
||||
import asyncio
|
||||
from copy import copy
|
||||
import logging
|
||||
|
||||
@@ -212,16 +211,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
|
||||
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||
"""Shutdown and close a Firmata board for a config entry."""
|
||||
_LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME])
|
||||
|
||||
unload_entries = []
|
||||
for conf, platform in CONF_PLATFORM_MAP.items():
|
||||
if conf in config_entry.data:
|
||||
unload_entries.append(
|
||||
hass.config_entries.async_forward_entry_unload(config_entry, platform)
|
||||
)
|
||||
results = []
|
||||
if unload_entries:
|
||||
results = await asyncio.gather(*unload_entries)
|
||||
results: list[bool] = []
|
||||
if platforms := [
|
||||
platform
|
||||
for conf, platform in CONF_PLATFORM_MAP.items()
|
||||
if conf in config_entry.data
|
||||
]:
|
||||
results.append(
|
||||
await hass.config_entries.async_unload_platforms(config_entry, platforms)
|
||||
)
|
||||
results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset())
|
||||
|
||||
return False not in results
|
||||
|
@@ -11,12 +11,13 @@ from .const import (
|
||||
CONF_DAMPING_EVENING,
|
||||
CONF_DAMPING_MORNING,
|
||||
CONF_MODULES_POWER,
|
||||
DOMAIN,
|
||||
)
|
||||
from .coordinator import ForecastSolarDataUpdateCoordinator
|
||||
|
||||
PLATFORMS = [Platform.SENSOR]
|
||||
|
||||
type ForecastSolarConfigEntry = ConfigEntry[ForecastSolarDataUpdateCoordinator]
|
||||
|
||||
|
||||
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Migrate old config entry."""
|
||||
@@ -36,12 +37,14 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant, entry: ForecastSolarConfigEntry
|
||||
) -> bool:
|
||||
"""Set up Forecast.Solar from a config entry."""
|
||||
coordinator = ForecastSolarDataUpdateCoordinator(hass, entry)
|
||||
await coordinator.async_config_entry_first_refresh()
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
|
||||
entry.runtime_data = coordinator
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
|
||||
@@ -52,11 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
if unload_ok:
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
|
||||
return unload_ok
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
|
||||
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
|
@@ -4,15 +4,11 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from forecast_solar import Estimate
|
||||
|
||||
from homeassistant.components.diagnostics import async_redact_data
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
||||
|
||||
from .const import DOMAIN
|
||||
from . import ForecastSolarConfigEntry
|
||||
|
||||
TO_REDACT = {
|
||||
CONF_API_KEY,
|
||||
@@ -22,10 +18,10 @@ TO_REDACT = {
|
||||
|
||||
|
||||
async def async_get_config_entry_diagnostics(
|
||||
hass: HomeAssistant, entry: ConfigEntry
|
||||
hass: HomeAssistant, entry: ForecastSolarConfigEntry
|
||||
) -> dict[str, Any]:
|
||||
"""Return diagnostics for a config entry."""
|
||||
coordinator: DataUpdateCoordinator[Estimate] = hass.data[DOMAIN][entry.entry_id]
|
||||
coordinator = entry.runtime_data
|
||||
|
||||
return {
|
||||
"entry": {
|
||||
|
@@ -4,19 +4,21 @@ from __future__ import annotations
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
from .coordinator import ForecastSolarDataUpdateCoordinator
|
||||
|
||||
|
||||
async def async_get_solar_forecast(
|
||||
hass: HomeAssistant, config_entry_id: str
|
||||
) -> dict[str, dict[str, float | int]] | None:
|
||||
"""Get solar forecast for a config entry ID."""
|
||||
if (coordinator := hass.data[DOMAIN].get(config_entry_id)) is None:
|
||||
if (
|
||||
entry := hass.config_entries.async_get_entry(config_entry_id)
|
||||
) is None or not isinstance(entry.runtime_data, ForecastSolarDataUpdateCoordinator):
|
||||
return None
|
||||
|
||||
return {
|
||||
"wh_hours": {
|
||||
timestamp.isoformat(): val
|
||||
for timestamp, val in coordinator.data.wh_period.items()
|
||||
for timestamp, val in entry.runtime_data.data.wh_period.items()
|
||||
}
|
||||
}
|
||||
|
@@ -16,7 +16,6 @@ from homeassistant.components.sensor import (
|
||||
SensorEntityDescription,
|
||||
SensorStateClass,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import UnitOfEnergy, UnitOfPower
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
@@ -24,6 +23,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import StateType
|
||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
|
||||
from . import ForecastSolarConfigEntry
|
||||
from .const import DOMAIN
|
||||
from .coordinator import ForecastSolarDataUpdateCoordinator
|
||||
|
||||
@@ -133,10 +133,12 @@ SENSORS: tuple[ForecastSolarSensorEntityDescription, ...] = (
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
|
||||
hass: HomeAssistant,
|
||||
entry: ForecastSolarConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Defer sensor setup to the shared sensor module."""
|
||||
coordinator: ForecastSolarDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id]
|
||||
coordinator = entry.runtime_data
|
||||
|
||||
async_add_entities(
|
||||
ForecastSolarSensorEntity(
|
||||
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Final
|
||||
from typing import Any, Final
|
||||
|
||||
from homeassistant.components.button import (
|
||||
ButtonDeviceClass,
|
||||
@@ -30,7 +30,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
class FritzButtonDescription(ButtonEntityDescription):
|
||||
"""Class to describe a Button entity."""
|
||||
|
||||
press_action: Callable
|
||||
press_action: Callable[[AvmWrapper], Any]
|
||||
|
||||
|
||||
BUTTONS: Final = [
|
||||
|
@@ -57,9 +57,6 @@ ERROR_UPNP_NOT_CONFIGURED = "upnp_not_configured"
|
||||
ERROR_UNKNOWN = "unknown_error"
|
||||
|
||||
FRITZ_SERVICES = "fritz_services"
|
||||
SERVICE_REBOOT = "reboot"
|
||||
SERVICE_RECONNECT = "reconnect"
|
||||
SERVICE_CLEANUP = "cleanup"
|
||||
SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password"
|
||||
|
||||
SWITCH_TYPE_DEFLECTION = "CallDeflection"
|
||||
|
@@ -46,9 +46,6 @@ from .const import (
|
||||
DEFAULT_USERNAME,
|
||||
DOMAIN,
|
||||
FRITZ_EXCEPTIONS,
|
||||
SERVICE_CLEANUP,
|
||||
SERVICE_REBOOT,
|
||||
SERVICE_RECONNECT,
|
||||
SERVICE_SET_GUEST_WIFI_PW,
|
||||
MeshRoles,
|
||||
)
|
||||
@@ -730,30 +727,6 @@ class FritzBoxTools(DataUpdateCoordinator[UpdateCoordinatorDataType]):
|
||||
)
|
||||
|
||||
try:
|
||||
if service_call.service == SERVICE_REBOOT:
|
||||
_LOGGER.warning(
|
||||
'Service "fritz.reboot" is deprecated, please use the corresponding'
|
||||
" button entity instead"
|
||||
)
|
||||
await self.async_trigger_reboot()
|
||||
return
|
||||
|
||||
if service_call.service == SERVICE_RECONNECT:
|
||||
_LOGGER.warning(
|
||||
'Service "fritz.reconnect" is deprecated, please use the'
|
||||
" corresponding button entity instead"
|
||||
)
|
||||
await self.async_trigger_reconnect()
|
||||
return
|
||||
|
||||
if service_call.service == SERVICE_CLEANUP:
|
||||
_LOGGER.warning(
|
||||
'Service "fritz.cleanup" is deprecated, please use the'
|
||||
" corresponding button entity instead"
|
||||
)
|
||||
await self.async_trigger_cleanup(config_entry)
|
||||
return
|
||||
|
||||
if service_call.service == SERVICE_SET_GUEST_WIFI_PW:
|
||||
await self.async_trigger_set_guest_password(
|
||||
service_call.data.get("password"),
|
||||
|
@@ -11,14 +11,7 @@ from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.service import async_extract_config_entry_ids
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
FRITZ_SERVICES,
|
||||
SERVICE_CLEANUP,
|
||||
SERVICE_REBOOT,
|
||||
SERVICE_RECONNECT,
|
||||
SERVICE_SET_GUEST_WIFI_PW,
|
||||
)
|
||||
from .const import DOMAIN, FRITZ_SERVICES, SERVICE_SET_GUEST_WIFI_PW
|
||||
from .coordinator import AvmWrapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -32,9 +25,6 @@ SERVICE_SCHEMA_SET_GUEST_WIFI_PW = vol.Schema(
|
||||
)
|
||||
|
||||
SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [
|
||||
(SERVICE_CLEANUP, None),
|
||||
(SERVICE_REBOOT, None),
|
||||
(SERVICE_RECONNECT, None),
|
||||
(SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW),
|
||||
]
|
||||
|
||||
|
@@ -1,31 +1,3 @@
|
||||
reconnect:
|
||||
fields:
|
||||
device_id:
|
||||
required: true
|
||||
selector:
|
||||
device:
|
||||
integration: fritz
|
||||
entity:
|
||||
device_class: connectivity
|
||||
reboot:
|
||||
fields:
|
||||
device_id:
|
||||
required: true
|
||||
selector:
|
||||
device:
|
||||
integration: fritz
|
||||
entity:
|
||||
device_class: connectivity
|
||||
|
||||
cleanup:
|
||||
fields:
|
||||
device_id:
|
||||
required: true
|
||||
selector:
|
||||
device:
|
||||
integration: fritz
|
||||
entity:
|
||||
device_class: connectivity
|
||||
set_guest_wifi_password:
|
||||
fields:
|
||||
device_id:
|
||||
|
@@ -144,42 +144,12 @@
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"reconnect": {
|
||||
"name": "[%key:component::fritz::entity::button::reconnect::name%]",
|
||||
"description": "Reconnects your FRITZ!Box internet connection.",
|
||||
"fields": {
|
||||
"device_id": {
|
||||
"name": "Fritz!Box Device",
|
||||
"description": "Select the Fritz!Box to reconnect."
|
||||
}
|
||||
}
|
||||
},
|
||||
"reboot": {
|
||||
"name": "Reboot",
|
||||
"description": "Reboots your FRITZ!Box.",
|
||||
"fields": {
|
||||
"device_id": {
|
||||
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
|
||||
"description": "Select the Fritz!Box to reboot."
|
||||
}
|
||||
}
|
||||
},
|
||||
"cleanup": {
|
||||
"name": "Remove stale device tracker entities",
|
||||
"description": "Remove FRITZ!Box stale device_tracker entities.",
|
||||
"fields": {
|
||||
"device_id": {
|
||||
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
|
||||
"description": "Select the Fritz!Box to check."
|
||||
}
|
||||
}
|
||||
},
|
||||
"set_guest_wifi_password": {
|
||||
"name": "Set guest Wi-Fi password",
|
||||
"description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.",
|
||||
"fields": {
|
||||
"device_id": {
|
||||
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
|
||||
"name": "Fritz!Box Device",
|
||||
"description": "Select the Fritz!Box to configure."
|
||||
},
|
||||
"password": {
|
||||
|
46
homeassistant/components/fronius/diagnostics.py
Normal file
46
homeassistant/components/fronius/diagnostics.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Diagnostics support for Fronius."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.diagnostics import async_redact_data
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import FroniusConfigEntry
|
||||
|
||||
TO_REDACT = {"unique_id", "unique_identifier", "serial"}
|
||||
|
||||
|
||||
async def async_get_config_entry_diagnostics(
|
||||
hass: HomeAssistant, config_entry: FroniusConfigEntry
|
||||
) -> dict[str, Any]:
|
||||
"""Return diagnostics for a config entry."""
|
||||
diag: dict[str, Any] = {}
|
||||
solar_net = config_entry.runtime_data
|
||||
fronius = solar_net.fronius
|
||||
|
||||
diag["config_entry"] = config_entry.as_dict()
|
||||
diag["inverter_info"] = await fronius.inverter_info()
|
||||
|
||||
diag["coordinators"] = {"inverters": {}}
|
||||
for inv in solar_net.inverter_coordinators:
|
||||
diag["coordinators"]["inverters"] |= inv.data
|
||||
|
||||
diag["coordinators"]["logger"] = (
|
||||
solar_net.logger_coordinator.data if solar_net.logger_coordinator else None
|
||||
)
|
||||
diag["coordinators"]["meter"] = (
|
||||
solar_net.meter_coordinator.data if solar_net.meter_coordinator else None
|
||||
)
|
||||
diag["coordinators"]["ohmpilot"] = (
|
||||
solar_net.ohmpilot_coordinator.data if solar_net.ohmpilot_coordinator else None
|
||||
)
|
||||
diag["coordinators"]["power_flow"] = (
|
||||
solar_net.power_flow_coordinator.data
|
||||
if solar_net.power_flow_coordinator
|
||||
else None
|
||||
)
|
||||
diag["coordinators"]["storage"] = (
|
||||
solar_net.storage_coordinator.data if solar_net.storage_coordinator else None
|
||||
)
|
||||
|
||||
return async_redact_data(diag, TO_REDACT)
|
@@ -181,8 +181,7 @@ async def google_generative_ai_config_option_schema(
|
||||
schema = {
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
||||
default=DEFAULT_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
|
@@ -22,4 +22,4 @@ CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
|
||||
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
|
||||
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE"
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
|
@@ -5,13 +5,14 @@ from __future__ import annotations
|
||||
from typing import Any, Literal
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
from google.api_core.exceptions import ClientError
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
import google.generativeai as genai
|
||||
import google.generativeai.types as genai_types
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -205,15 +206,6 @@ class GoogleGenerativeAIConversationEntity(
|
||||
messages = [{}, {}]
|
||||
|
||||
try:
|
||||
prompt = template.Template(
|
||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
empty_tool_input = llm.ToolInput(
|
||||
tool_name="",
|
||||
@@ -226,9 +218,24 @@ class GoogleGenerativeAIConversationEntity(
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
|
||||
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
|
||||
|
||||
else:
|
||||
api_prompt = llm.PROMPT_NO_API_CONFIGURED
|
||||
|
||||
prompt = "\n".join(
|
||||
(
|
||||
template.Template(
|
||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
),
|
||||
api_prompt,
|
||||
)
|
||||
)
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
@@ -244,6 +251,9 @@ class GoogleGenerativeAIConversationEntity(
|
||||
messages[1] = {"role": "model", "parts": "Ok"}
|
||||
|
||||
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||
)
|
||||
|
||||
chat = model.start_chat(history=messages)
|
||||
chat_request = user_input.text
|
||||
@@ -252,15 +262,25 @@ class GoogleGenerativeAIConversationEntity(
|
||||
try:
|
||||
chat_response = await chat.send_message_async(chat_request)
|
||||
except (
|
||||
ClientError,
|
||||
GoogleAPICallError,
|
||||
ValueError,
|
||||
genai_types.BlockedPromptException,
|
||||
genai_types.StopCandidateException,
|
||||
) as err:
|
||||
LOGGER.error("Error sending message: %s", err)
|
||||
LOGGER.error("Error sending message: %s %s", type(err), err)
|
||||
|
||||
if isinstance(
|
||||
err, genai_types.StopCandidateException
|
||||
) and "finish_reason: SAFETY\n" in str(err):
|
||||
error = "The message got blocked by your safety settings"
|
||||
else:
|
||||
error = (
|
||||
f"Sorry, I had a problem talking to Google Generative AI: {err}"
|
||||
)
|
||||
|
||||
intent_response.async_set_error(
|
||||
intent.IntentResponseErrorCode.UNKNOWN,
|
||||
f"Sorry, I had a problem talking to Google Generative AI: {err}",
|
||||
error,
|
||||
)
|
||||
return conversation.ConversationResult(
|
||||
response=intent_response, conversation_id=conversation_id
|
||||
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"domain": "integration",
|
||||
"name": "Integration - Riemann sum integral",
|
||||
"name": "Integral",
|
||||
"after_dependencies": ["counter"],
|
||||
"codeowners": ["@dgomes"],
|
||||
"config_flow": true,
|
||||
|
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"title": "Integration - Riemann sum integral sensor",
|
||||
"title": "Integral sensor",
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
|
@@ -21,7 +21,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
CONF_VALIDATOR = "validator"
|
||||
CONF_SECRET = "secret"
|
||||
URL = "/api/meraki"
|
||||
VERSION = "2.0"
|
||||
ACCEPTED_VERSIONS = ["2.0", "2.1"]
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -74,7 +74,7 @@ class MerakiView(HomeAssistantView):
|
||||
if data["secret"] != self.secret:
|
||||
_LOGGER.error("Invalid Secret received from Meraki")
|
||||
return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY)
|
||||
if data["version"] != VERSION:
|
||||
if data["version"] not in ACCEPTED_VERSIONS:
|
||||
_LOGGER.error("Invalid API version: %s", data["version"])
|
||||
return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY)
|
||||
_LOGGER.debug("Valid Secret")
|
||||
|
@@ -6,6 +6,6 @@
|
||||
"documentation": "https://www.home-assistant.io/integrations/minecraft_server",
|
||||
"iot_class": "local_polling",
|
||||
"loggers": ["dnspython", "mcstatus"],
|
||||
"quality_scale": "gold",
|
||||
"quality_scale": "platinum",
|
||||
"requirements": ["mcstatus==11.1.1"]
|
||||
}
|
||||
|
@@ -39,6 +39,7 @@ from .client import ( # noqa: F401
|
||||
MQTT,
|
||||
async_publish,
|
||||
async_subscribe,
|
||||
async_subscribe_internal,
|
||||
publish,
|
||||
subscribe,
|
||||
)
|
||||
@@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
def collect_msg(msg: ReceiveMessage) -> None:
|
||||
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
|
||||
|
||||
unsub = await async_subscribe(hass, call.data["topic"], collect_msg)
|
||||
unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
|
||||
|
||||
def write_dump() -> None:
|
||||
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
|
||||
@@ -459,7 +460,7 @@ async def websocket_subscribe(
|
||||
|
||||
# Perform UTF-8 decoding directly in callback routine
|
||||
qos: int = msg.get("qos", DEFAULT_QOS)
|
||||
connection.subscriptions[msg["id"]] = await async_subscribe(
|
||||
connection.subscriptions[msg["id"]] = async_subscribe_internal(
|
||||
hass, msg["topic"], forward_messages, encoding=None, qos=qos
|
||||
)
|
||||
|
||||
@@ -522,24 +523,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
mqtt_client = mqtt_data.client
|
||||
|
||||
# Unload publish and dump services.
|
||||
hass.services.async_remove(
|
||||
DOMAIN,
|
||||
SERVICE_PUBLISH,
|
||||
)
|
||||
hass.services.async_remove(
|
||||
DOMAIN,
|
||||
SERVICE_DUMP,
|
||||
)
|
||||
hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
|
||||
hass.services.async_remove(DOMAIN, SERVICE_DUMP)
|
||||
|
||||
# Stop the discovery
|
||||
await discovery.async_stop(hass)
|
||||
# Unload the platforms
|
||||
await asyncio.gather(
|
||||
*(
|
||||
hass.config_entries.async_forward_entry_unload(entry, component)
|
||||
for component in mqtt_data.platforms_loaded
|
||||
)
|
||||
)
|
||||
await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
|
||||
mqtt_data.platforms_loaded = set()
|
||||
await asyncio.sleep(0)
|
||||
# Unsubscribe reload dispatchers
|
||||
|
@@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_alarm_disarm(self, code: str | None = None) -> None:
|
||||
"""Send disarm command.
|
||||
|
@@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
@callback
|
||||
def _value_is_expired(self, *_: Any) -> None:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from base64 import b64decode
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from . import subscription
|
||||
from .config import MQTT_BASE_SCHEMA
|
||||
from .const import CONF_QOS, CONF_TOPIC
|
||||
from .debug_info import log_messages
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import ReceiveMessage
|
||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
@@ -97,27 +97,31 @@ class MqttCamera(MqttEntity, Camera):
|
||||
"""Return the config schema."""
|
||||
return DISCOVERY_SCHEMA
|
||||
|
||||
@callback
|
||||
def _image_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
if CONF_IMAGE_ENCODING in self._config:
|
||||
self._last_image = b64decode(msg.payload)
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg.payload, bytes)
|
||||
self._last_image = msg.payload
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
if CONF_IMAGE_ENCODING in self._config:
|
||||
self._last_image = b64decode(msg.payload)
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg.payload, bytes)
|
||||
self._last_image = msg.payload
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._config[CONF_TOPIC],
|
||||
"msg_callback": message_received,
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._image_received,
|
||||
None,
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": None,
|
||||
}
|
||||
@@ -126,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_camera_image(
|
||||
self, width: int | None = None, height: int | None = None
|
||||
|
@@ -77,7 +77,6 @@ from .const import (
|
||||
)
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
AsyncMessageCallbackType,
|
||||
MessageCallbackType,
|
||||
MqttData,
|
||||
PublishMessage,
|
||||
@@ -184,7 +183,7 @@ async def async_publish(
|
||||
async def async_subscribe(
|
||||
hass: HomeAssistant,
|
||||
topic: str,
|
||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
||||
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||
qos: int = DEFAULT_QOS,
|
||||
encoding: str | None = DEFAULT_ENCODING,
|
||||
) -> CALLBACK_TYPE:
|
||||
@@ -192,13 +191,25 @@ async def async_subscribe(
|
||||
|
||||
Call the return value to unsubscribe.
|
||||
"""
|
||||
if not mqtt_config_entry_enabled(hass):
|
||||
raise HomeAssistantError(
|
||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
|
||||
translation_key="mqtt_not_setup_cannot_subscribe",
|
||||
translation_domain=DOMAIN,
|
||||
translation_placeholders={"topic": topic},
|
||||
)
|
||||
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
|
||||
|
||||
|
||||
@callback
|
||||
def async_subscribe_internal(
|
||||
hass: HomeAssistant,
|
||||
topic: str,
|
||||
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||
qos: int = DEFAULT_QOS,
|
||||
encoding: str | None = DEFAULT_ENCODING,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Subscribe to an MQTT topic.
|
||||
|
||||
This function is internal to the MQTT integration
|
||||
and may change at any time. It should not be considered
|
||||
a stable API.
|
||||
|
||||
Call the return value to unsubscribe.
|
||||
"""
|
||||
try:
|
||||
mqtt_data = hass.data[DATA_MQTT]
|
||||
except KeyError as exc:
|
||||
@@ -209,12 +220,15 @@ async def async_subscribe(
|
||||
translation_domain=DOMAIN,
|
||||
translation_placeholders={"topic": topic},
|
||||
) from exc
|
||||
return await mqtt_data.client.async_subscribe(
|
||||
topic,
|
||||
msg_callback,
|
||||
qos,
|
||||
encoding,
|
||||
)
|
||||
client = mqtt_data.client
|
||||
if not client.connected and not mqtt_config_entry_enabled(hass):
|
||||
raise HomeAssistantError(
|
||||
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
|
||||
translation_key="mqtt_not_setup_cannot_subscribe",
|
||||
translation_domain=DOMAIN,
|
||||
translation_placeholders={"topic": topic},
|
||||
)
|
||||
return client.async_subscribe(topic, msg_callback, qos, encoding)
|
||||
|
||||
|
||||
@bind_hass
|
||||
@@ -429,10 +443,10 @@ class MQTT:
|
||||
self.config_entry = config_entry
|
||||
self.conf = conf
|
||||
|
||||
self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict(
|
||||
list
|
||||
self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
|
||||
set
|
||||
)
|
||||
self._wildcard_subscriptions: list[Subscription] = []
|
||||
self._wildcard_subscriptions: set[Subscription] = set()
|
||||
# _retained_topics prevents a Subscription from receiving a
|
||||
# retained message more than once per topic. This prevents flooding
|
||||
# already active subscribers when new subscribers subscribe to a topic
|
||||
@@ -452,7 +466,7 @@ class MQTT:
|
||||
self._should_reconnect: bool = True
|
||||
self._available_future: asyncio.Future[bool] | None = None
|
||||
|
||||
self._max_qos: dict[str, int] = {} # topic, max qos
|
||||
self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
|
||||
self._pending_subscriptions: dict[str, int] = {} # topic, qos
|
||||
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
|
||||
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
|
||||
@@ -789,9 +803,9 @@ class MQTT:
|
||||
The caller is responsible clearing the cache of _matching_subscriptions.
|
||||
"""
|
||||
if subscription.is_simple_match:
|
||||
self._simple_subscriptions[subscription.topic].append(subscription)
|
||||
self._simple_subscriptions[subscription.topic].add(subscription)
|
||||
else:
|
||||
self._wildcard_subscriptions.append(subscription)
|
||||
self._wildcard_subscriptions.add(subscription)
|
||||
|
||||
@callback
|
||||
def _async_untrack_subscription(self, subscription: Subscription) -> None:
|
||||
@@ -820,8 +834,8 @@ class MQTT:
|
||||
"""Queue requested subscriptions."""
|
||||
for subscription in subscriptions:
|
||||
topic, qos = subscription
|
||||
max_qos = max(qos, self._max_qos.setdefault(topic, qos))
|
||||
self._max_qos[topic] = max_qos
|
||||
if (max_qos := self._max_qos[topic]) < qos:
|
||||
self._max_qos[topic] = (max_qos := qos)
|
||||
self._pending_subscriptions[topic] = max_qos
|
||||
# Cancel any pending unsubscribe since we are subscribing now
|
||||
if topic in self._pending_unsubscribes:
|
||||
@@ -832,26 +846,29 @@ class MQTT:
|
||||
|
||||
def _exception_message(
|
||||
self,
|
||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
||||
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||
msg: ReceiveMessage,
|
||||
) -> str:
|
||||
"""Return a string with the exception message."""
|
||||
# if msg_callback is a partial we return the name of the first argument
|
||||
if isinstance(msg_callback, partial):
|
||||
call_back_name = getattr(msg_callback.args[0], "__name__") # type: ignore[unreachable]
|
||||
else:
|
||||
call_back_name = getattr(msg_callback, "__name__")
|
||||
return (
|
||||
f"Exception in {msg_callback.__name__} when handling msg on "
|
||||
f"Exception in {call_back_name} when handling msg on "
|
||||
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
|
||||
)
|
||||
|
||||
async def async_subscribe(
|
||||
@callback
|
||||
def async_subscribe(
|
||||
self,
|
||||
topic: str,
|
||||
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
|
||||
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
|
||||
qos: int,
|
||||
encoding: str | None = None,
|
||||
) -> Callable[[], None]:
|
||||
"""Set up a subscription to a topic with the provided qos.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
"""Set up a subscription to a topic with the provided qos."""
|
||||
if not isinstance(topic, str):
|
||||
raise HomeAssistantError("Topic needs to be a string!")
|
||||
|
||||
@@ -877,18 +894,18 @@ class MQTT:
|
||||
if self.connected:
|
||||
self._async_queue_subscriptions(((topic, qos),))
|
||||
|
||||
@callback
|
||||
def async_remove() -> None:
|
||||
"""Remove subscription."""
|
||||
self._async_untrack_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
if subscription in self._retained_topics:
|
||||
del self._retained_topics[subscription]
|
||||
# Only unsubscribe if currently connected
|
||||
if self.connected:
|
||||
self._async_unsubscribe(topic)
|
||||
return partial(self._async_remove, subscription)
|
||||
|
||||
return async_remove
|
||||
@callback
|
||||
def _async_remove(self, subscription: Subscription) -> None:
|
||||
"""Remove subscription."""
|
||||
self._async_untrack_subscription(subscription)
|
||||
self._matching_subscriptions.cache_clear()
|
||||
if subscription in self._retained_topics:
|
||||
del self._retained_topics[subscription]
|
||||
# Only unsubscribe if currently connected
|
||||
if self.connected:
|
||||
self._async_unsubscribe(subscription.topic)
|
||||
|
||||
@callback
|
||||
def _async_unsubscribe(self, topic: str) -> None:
|
||||
@@ -1257,9 +1274,7 @@ class MQTT:
|
||||
|
||||
last_discovery = self._mqtt_data.last_discovery
|
||||
last_subscribe = now if self._pending_subscriptions else self._last_subscribe
|
||||
wait_until = max(
|
||||
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
||||
)
|
||||
wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
|
||||
while now < wait_until:
|
||||
await asyncio.sleep(wait_until - now)
|
||||
now = time.monotonic()
|
||||
@@ -1267,9 +1282,7 @@ class MQTT:
|
||||
last_subscribe = (
|
||||
now if self._pending_subscriptions else self._last_subscribe
|
||||
)
|
||||
wait_until = max(
|
||||
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
|
||||
)
|
||||
wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
|
||||
|
||||
|
||||
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:
|
||||
|
@@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
|
||||
if self._topic[topic] is not None:
|
||||
|
@@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_open_cover(self, **kwargs: Any) -> None:
|
||||
"""Move the cover up.
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -32,13 +33,7 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from . import subscription
|
||||
from .config import MQTT_BASE_SCHEMA
|
||||
from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
CONF_JSON_ATTRS_TOPIC,
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import CONF_JSON_ATTRS_TOPIC, MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
|
||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
from .util import valid_subscribe_topic
|
||||
@@ -119,33 +114,31 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
||||
config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
@callback
|
||||
def _tracker_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if not payload.strip(): # No output from template, ignore
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if payload == self._config[CONF_PAYLOAD_HOME]:
|
||||
self._location_name = STATE_HOME
|
||||
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
|
||||
self._location_name = STATE_NOT_HOME
|
||||
elif payload == self._config[CONF_PAYLOAD_RESET]:
|
||||
self._location_name = None
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg.payload, str)
|
||||
self._location_name = msg.payload
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_location_name"})
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if not payload.strip(): # No output from template, ignore
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if payload == self._config[CONF_PAYLOAD_HOME]:
|
||||
self._location_name = STATE_HOME
|
||||
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
|
||||
self._location_name = STATE_NOT_HOME
|
||||
elif payload == self._config[CONF_PAYLOAD_RESET]:
|
||||
self._location_name = None
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg.payload, str)
|
||||
self._location_name = msg.payload
|
||||
|
||||
state_topic: str | None = self._config.get(CONF_STATE_TOPIC)
|
||||
if state_topic is None:
|
||||
return
|
||||
@@ -155,7 +148,12 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": state_topic,
|
||||
"msg_callback": message_received,
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._tracker_message_received,
|
||||
{"_location_name"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
}
|
||||
},
|
||||
@@ -168,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
@property
|
||||
def latitude(self) -> float | None:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -31,7 +32,6 @@ from .const import (
|
||||
PAYLOAD_EMPTY_JSON,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
@@ -113,90 +113,91 @@ class MqttEvent(MqttEntity, EventEntity):
|
||||
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
@callback
|
||||
def _event_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
if msg.retain:
|
||||
_LOGGER.debug(
|
||||
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
event_attributes: dict[str, Any] = {}
|
||||
event_type: str
|
||||
try:
|
||||
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
|
||||
except MqttValueTemplateException as exc:
|
||||
_LOGGER.warning(exc)
|
||||
return
|
||||
if (
|
||||
not payload
|
||||
or payload is PayloadSentinel.DEFAULT
|
||||
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
try:
|
||||
event_attributes = json_loads_object(payload)
|
||||
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"JSON event data detected after processing payload '%s' on"
|
||||
" topic %s, type %s, attributes %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
event_type,
|
||||
event_attributes,
|
||||
)
|
||||
except KeyError:
|
||||
_LOGGER.warning(
|
||||
("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"No valid JSON event payload detected, "
|
||||
"value after processing payload"
|
||||
" '%s' on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._trigger_event(event_type, event_attributes)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Invalid event type %s for %s received on topic %s, payload %s",
|
||||
event_type,
|
||||
self.entity_id,
|
||||
msg.topic,
|
||||
payload,
|
||||
)
|
||||
return
|
||||
mqtt_data = self.hass.data[DATA_MQTT]
|
||||
mqtt_data.state_write_requests.write_state_request(self)
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
if msg.retain:
|
||||
_LOGGER.debug(
|
||||
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
event_attributes: dict[str, Any] = {}
|
||||
event_type: str
|
||||
try:
|
||||
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
|
||||
except MqttValueTemplateException as exc:
|
||||
_LOGGER.warning(exc)
|
||||
return
|
||||
if (
|
||||
not payload
|
||||
or payload is PayloadSentinel.DEFAULT
|
||||
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
try:
|
||||
event_attributes = json_loads_object(payload)
|
||||
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"JSON event data detected after processing payload '%s' on"
|
||||
" topic %s, type %s, attributes %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
event_type,
|
||||
event_attributes,
|
||||
)
|
||||
except KeyError:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"`event_type` missing in JSON event payload, "
|
||||
" '%s' on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"No valid JSON event payload detected, "
|
||||
"value after processing payload"
|
||||
" '%s' on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._trigger_event(event_type, event_attributes)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Invalid event type %s for %s received on topic %s, payload %s",
|
||||
event_type,
|
||||
self.entity_id,
|
||||
msg.topic,
|
||||
payload,
|
||||
)
|
||||
return
|
||||
mqtt_data = self.hass.data[DATA_MQTT]
|
||||
mqtt_data.state_write_requests.write_state_request(self)
|
||||
|
||||
topics["state_topic"] = {
|
||||
"topic": self._config[CONF_STATE_TOPIC],
|
||||
"msg_callback": message_received,
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._event_received,
|
||||
None,
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
@@ -207,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
@@ -49,12 +50,7 @@ from .const import (
|
||||
CONF_STATE_VALUE_TEMPLATE,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MessageCallbackType,
|
||||
MqttCommandTemplate,
|
||||
@@ -338,137 +334,142 @@ class MqttFan(MqttEntity, FanEntity):
|
||||
for key, tpl in value_templates.items()
|
||||
}
|
||||
|
||||
@callback
|
||||
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
|
||||
return
|
||||
if payload == self._payload["STATE_ON"]:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._payload["STATE_OFF"]:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
@callback
|
||||
def _percentage_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the percentage."""
|
||||
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
|
||||
msg.payload
|
||||
)
|
||||
if not rendered_percentage_payload:
|
||||
_LOGGER.debug("Ignoring empty speed from '%s'", msg.topic)
|
||||
return
|
||||
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
|
||||
self._attr_percentage = None
|
||||
return
|
||||
try:
|
||||
percentage = ranged_value_to_percentage(
|
||||
self._speed_range, int(rendered_percentage_payload)
|
||||
)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"'%s' received on topic %s. '%s' is not a valid speed within"
|
||||
" the speed range"
|
||||
),
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_percentage_payload,
|
||||
)
|
||||
return
|
||||
if percentage < 0 or percentage > 100:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"'%s' received on topic %s. '%s' is not a valid speed within"
|
||||
" the speed range"
|
||||
),
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_percentage_payload,
|
||||
)
|
||||
return
|
||||
self._attr_percentage = percentage
|
||||
|
||||
@callback
|
||||
def _preset_mode_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for preset mode."""
|
||||
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
|
||||
if preset_mode == self._payload["PRESET_MODE_RESET"]:
|
||||
self._attr_preset_mode = None
|
||||
return
|
||||
if not preset_mode:
|
||||
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
|
||||
return
|
||||
if not self.preset_modes or preset_mode not in self.preset_modes:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid preset mode",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
preset_mode,
|
||||
)
|
||||
return
|
||||
|
||||
self._attr_preset_mode = preset_mode
|
||||
|
||||
@callback
|
||||
def _oscillation_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the oscillation."""
|
||||
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic)
|
||||
return
|
||||
if payload == self._payload["OSCILLATE_ON_PAYLOAD"]:
|
||||
self._attr_oscillating = True
|
||||
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
|
||||
self._attr_oscillating = False
|
||||
|
||||
@callback
|
||||
def _direction_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the direction."""
|
||||
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
|
||||
if not direction:
|
||||
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
|
||||
return
|
||||
self._attr_current_direction = str(direction)
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, Any] = {}
|
||||
|
||||
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
|
||||
def add_subscribe_topic(
|
||||
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
|
||||
) -> bool:
|
||||
"""Add a topic to subscribe to."""
|
||||
if has_topic := self._topic[topic] is not None:
|
||||
topics[topic] = {
|
||||
"topic": self._topic[topic],
|
||||
"msg_callback": msg_callback,
|
||||
"msg_callback": partial(
|
||||
self._message_callback, msg_callback, tracked_attributes
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
return has_topic
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
||||
def state_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
|
||||
return
|
||||
if payload == self._payload["STATE_ON"]:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._payload["STATE_OFF"]:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
add_subscribe_topic(CONF_STATE_TOPIC, state_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_percentage"})
|
||||
def percentage_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the percentage."""
|
||||
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
|
||||
msg.payload
|
||||
)
|
||||
if not rendered_percentage_payload:
|
||||
_LOGGER.debug("Ignoring empty speed from '%s'", msg.topic)
|
||||
return
|
||||
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
|
||||
self._attr_percentage = None
|
||||
return
|
||||
try:
|
||||
percentage = ranged_value_to_percentage(
|
||||
self._speed_range, int(rendered_percentage_payload)
|
||||
)
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"'%s' received on topic %s. '%s' is not a valid speed within"
|
||||
" the speed range"
|
||||
),
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_percentage_payload,
|
||||
)
|
||||
return
|
||||
if percentage < 0 or percentage > 100:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"'%s' received on topic %s. '%s' is not a valid speed within"
|
||||
" the speed range"
|
||||
),
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_percentage_payload,
|
||||
)
|
||||
return
|
||||
self._attr_percentage = percentage
|
||||
|
||||
add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_preset_mode"})
|
||||
def preset_mode_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for preset mode."""
|
||||
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
|
||||
if preset_mode == self._payload["PRESET_MODE_RESET"]:
|
||||
self._attr_preset_mode = None
|
||||
return
|
||||
if not preset_mode:
|
||||
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
|
||||
return
|
||||
if not self.preset_modes or preset_mode not in self.preset_modes:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid preset mode",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
preset_mode,
|
||||
)
|
||||
return
|
||||
|
||||
self._attr_preset_mode = preset_mode
|
||||
|
||||
add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_oscillating"})
|
||||
def oscillation_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the oscillation."""
|
||||
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic)
|
||||
return
|
||||
if payload == self._payload["OSCILLATE_ON_PAYLOAD"]:
|
||||
self._attr_oscillating = True
|
||||
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
|
||||
self._attr_oscillating = False
|
||||
|
||||
if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received):
|
||||
add_subscribe_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
|
||||
add_subscribe_topic(
|
||||
CONF_PERCENTAGE_STATE_TOPIC, self._percentage_received, {"_attr_percentage"}
|
||||
)
|
||||
add_subscribe_topic(
|
||||
CONF_PRESET_MODE_STATE_TOPIC,
|
||||
self._preset_mode_received,
|
||||
{"_attr_preset_mode"},
|
||||
)
|
||||
if add_subscribe_topic(
|
||||
CONF_OSCILLATION_STATE_TOPIC,
|
||||
self._oscillation_received,
|
||||
{"_attr_oscillating"},
|
||||
):
|
||||
self._attr_oscillating = False
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_current_direction"})
|
||||
def direction_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the direction."""
|
||||
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
|
||||
if not direction:
|
||||
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
|
||||
return
|
||||
self._attr_current_direction = str(direction)
|
||||
|
||||
add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received)
|
||||
add_subscribe_topic(
|
||||
CONF_DIRECTION_STATE_TOPIC,
|
||||
self._direction_received,
|
||||
{"_attr_current_direction"},
|
||||
)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass, self._sub_state, topics
|
||||
@@ -476,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool | None:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -51,12 +52,7 @@ from .const import (
|
||||
CONF_STATE_VALUE_TEMPLATE,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -284,164 +280,166 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
||||
topics: dict[str, dict[str, Any]],
|
||||
topic: str,
|
||||
msg_callback: Callable[[ReceiveMessage], None],
|
||||
tracked_attributes: set[str],
|
||||
) -> None:
|
||||
"""Add a subscription."""
|
||||
qos: int = self._config[CONF_QOS]
|
||||
if topic in self._topic and self._topic[topic] is not None:
|
||||
topics[topic] = {
|
||||
"topic": self._topic[topic],
|
||||
"msg_callback": msg_callback,
|
||||
"msg_callback": partial(
|
||||
self._message_callback, msg_callback, tracked_attributes
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": qos,
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
|
||||
@callback
|
||||
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
|
||||
return
|
||||
if payload == self._payload["STATE_ON"]:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._payload["STATE_OFF"]:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
@callback
|
||||
def _action_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
|
||||
if not action_payload or action_payload == PAYLOAD_NONE:
|
||||
_LOGGER.debug("Ignoring empty action from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
self._attr_action = HumidifierAction(str(action_payload))
|
||||
except ValueError:
|
||||
_LOGGER.error(
|
||||
"'%s' received on topic %s. '%s' is not a valid action",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
action_payload,
|
||||
)
|
||||
return
|
||||
|
||||
@callback
|
||||
def _current_humidity_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the current humidity."""
|
||||
rendered_current_humidity_payload = self._value_templates[
|
||||
ATTR_CURRENT_HUMIDITY
|
||||
](msg.payload)
|
||||
if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]:
|
||||
self._attr_current_humidity = None
|
||||
return
|
||||
if not rendered_current_humidity_payload:
|
||||
_LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
current_humidity = round(float(rendered_current_humidity_payload))
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_current_humidity_payload,
|
||||
)
|
||||
return
|
||||
if current_humidity < 0 or current_humidity > 100:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_current_humidity_payload,
|
||||
)
|
||||
return
|
||||
self._attr_current_humidity = current_humidity
|
||||
|
||||
@callback
|
||||
def _target_humidity_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the target humidity."""
|
||||
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
|
||||
msg.payload
|
||||
)
|
||||
if not rendered_target_humidity_payload:
|
||||
_LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic)
|
||||
return
|
||||
if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]:
|
||||
self._attr_target_humidity = None
|
||||
return
|
||||
try:
|
||||
target_humidity = round(float(rendered_target_humidity_payload))
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid target humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_target_humidity_payload,
|
||||
)
|
||||
return
|
||||
if (
|
||||
target_humidity < self._attr_min_humidity
|
||||
or target_humidity > self._attr_max_humidity
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid target humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_target_humidity_payload,
|
||||
)
|
||||
return
|
||||
self._attr_target_humidity = target_humidity
|
||||
|
||||
@callback
|
||||
def _mode_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for mode."""
|
||||
mode = str(self._value_templates[ATTR_MODE](msg.payload))
|
||||
if mode == self._payload["MODE_RESET"]:
|
||||
self._attr_mode = None
|
||||
return
|
||||
if not mode:
|
||||
_LOGGER.debug("Ignoring empty mode from '%s'", msg.topic)
|
||||
return
|
||||
if not self.available_modes or mode not in self.available_modes:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid mode",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
mode,
|
||||
)
|
||||
return
|
||||
|
||||
self._attr_mode = mode
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, Any] = {}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
||||
def state_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
payload = self._value_templates[CONF_STATE](msg.payload)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
|
||||
return
|
||||
if payload == self._payload["STATE_ON"]:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._payload["STATE_OFF"]:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
self.add_subscription(topics, CONF_STATE_TOPIC, state_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_action"})
|
||||
def action_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message."""
|
||||
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
|
||||
if not action_payload or action_payload == PAYLOAD_NONE:
|
||||
_LOGGER.debug("Ignoring empty action from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
self._attr_action = HumidifierAction(str(action_payload))
|
||||
except ValueError:
|
||||
_LOGGER.error(
|
||||
"'%s' received on topic %s. '%s' is not a valid action",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
action_payload,
|
||||
)
|
||||
return
|
||||
|
||||
self.add_subscription(topics, CONF_ACTION_TOPIC, action_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_current_humidity"})
|
||||
def current_humidity_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the current humidity."""
|
||||
rendered_current_humidity_payload = self._value_templates[
|
||||
ATTR_CURRENT_HUMIDITY
|
||||
](msg.payload)
|
||||
if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]:
|
||||
self._attr_current_humidity = None
|
||||
return
|
||||
if not rendered_current_humidity_payload:
|
||||
_LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
current_humidity = round(float(rendered_current_humidity_payload))
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_current_humidity_payload,
|
||||
)
|
||||
return
|
||||
if current_humidity < 0 or current_humidity > 100:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_current_humidity_payload,
|
||||
)
|
||||
return
|
||||
self._attr_current_humidity = current_humidity
|
||||
|
||||
self.add_subscription(
|
||||
topics, CONF_CURRENT_HUMIDITY_TOPIC, current_humidity_received
|
||||
topics, CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}
|
||||
)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_target_humidity"})
|
||||
def target_humidity_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for the target humidity."""
|
||||
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
|
||||
msg.payload
|
||||
)
|
||||
if not rendered_target_humidity_payload:
|
||||
_LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic)
|
||||
return
|
||||
if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]:
|
||||
self._attr_target_humidity = None
|
||||
return
|
||||
try:
|
||||
target_humidity = round(float(rendered_target_humidity_payload))
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid target humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_target_humidity_payload,
|
||||
)
|
||||
return
|
||||
if (
|
||||
target_humidity < self._attr_min_humidity
|
||||
or target_humidity > self._attr_max_humidity
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid target humidity",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
rendered_target_humidity_payload,
|
||||
)
|
||||
return
|
||||
self._attr_target_humidity = target_humidity
|
||||
|
||||
self.add_subscription(
|
||||
topics, CONF_TARGET_HUMIDITY_STATE_TOPIC, target_humidity_received
|
||||
topics, CONF_ACTION_TOPIC, self._action_received, {"_attr_action"}
|
||||
)
|
||||
self.add_subscription(
|
||||
topics,
|
||||
CONF_CURRENT_HUMIDITY_TOPIC,
|
||||
self._current_humidity_received,
|
||||
{"_attr_current_humidity"},
|
||||
)
|
||||
self.add_subscription(
|
||||
topics,
|
||||
CONF_TARGET_HUMIDITY_STATE_TOPIC,
|
||||
self._target_humidity_received,
|
||||
{"_attr_target_humidity"},
|
||||
)
|
||||
self.add_subscription(
|
||||
topics, CONF_MODE_STATE_TOPIC, self._mode_received, {"_attr_mode"}
|
||||
)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_mode"})
|
||||
def mode_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new received MQTT message for mode."""
|
||||
mode = str(self._value_templates[ATTR_MODE](msg.payload))
|
||||
if mode == self._payload["MODE_RESET"]:
|
||||
self._attr_mode = None
|
||||
return
|
||||
if not mode:
|
||||
_LOGGER.debug("Ignoring empty mode from '%s'", msg.topic)
|
||||
return
|
||||
if not self.available_modes or mode not in self.available_modes:
|
||||
_LOGGER.warning(
|
||||
"'%s' received on topic %s. '%s' is not a valid mode",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
mode,
|
||||
)
|
||||
return
|
||||
|
||||
self._attr_mode = mode
|
||||
|
||||
self.add_subscription(topics, CONF_MODE_STATE_TOPIC, mode_received)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass, self._sub_state, topics
|
||||
@@ -449,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn on the entity.
|
||||
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from base64 import b64decode
|
||||
import binascii
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
|
||||
from . import subscription
|
||||
from .config import MQTT_BASE_SCHEMA
|
||||
from .const import CONF_ENCODING, CONF_QOS
|
||||
from .debug_info import log_messages
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
DATA_MQTT,
|
||||
@@ -143,6 +143,45 @@ class MqttImage(MqttEntity, ImageEntity):
|
||||
config.get(CONF_URL_TEMPLATE), entity=self
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
@callback
|
||||
def _image_data_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
try:
|
||||
if CONF_IMAGE_ENCODING in self._config:
|
||||
self._last_image = b64decode(msg.payload)
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg.payload, bytes)
|
||||
self._last_image = msg.payload
|
||||
except (binascii.Error, ValueError, AssertionError) as err:
|
||||
_LOGGER.error(
|
||||
"Error processing image data received at topic %s: %s",
|
||||
msg.topic,
|
||||
err,
|
||||
)
|
||||
self._last_image = None
|
||||
self._attr_image_last_updated = dt_util.utcnow()
|
||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||
|
||||
@callback
|
||||
def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
try:
|
||||
url = cv.url(self._url_template(msg.payload))
|
||||
self._attr_image_url = url
|
||||
except MqttValueTemplateException as exc:
|
||||
_LOGGER.warning(exc)
|
||||
return
|
||||
except vol.Invalid:
|
||||
_LOGGER.error(
|
||||
"Invalid image URL '%s' received at topic %s",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
)
|
||||
self._attr_image_last_updated = dt_util.utcnow()
|
||||
self._cached_image = None
|
||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@@ -159,56 +198,15 @@ class MqttImage(MqttEntity, ImageEntity):
|
||||
if has_topic := self._topic[topic] is not None:
|
||||
topics[topic] = {
|
||||
"topic": self._topic[topic],
|
||||
"msg_callback": msg_callback,
|
||||
"msg_callback": partial(self._message_callback, msg_callback, None),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": encoding,
|
||||
}
|
||||
return has_topic
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def image_data_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
try:
|
||||
if CONF_IMAGE_ENCODING in self._config:
|
||||
self._last_image = b64decode(msg.payload)
|
||||
else:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(msg.payload, bytes)
|
||||
self._last_image = msg.payload
|
||||
except (binascii.Error, ValueError, AssertionError) as err:
|
||||
_LOGGER.error(
|
||||
"Error processing image data received at topic %s: %s",
|
||||
msg.topic,
|
||||
err,
|
||||
)
|
||||
self._last_image = None
|
||||
self._attr_image_last_updated = dt_util.utcnow()
|
||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||
|
||||
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
def image_from_url_request_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
try:
|
||||
url = cv.url(self._url_template(msg.payload))
|
||||
self._attr_image_url = url
|
||||
except MqttValueTemplateException as exc:
|
||||
_LOGGER.warning(exc)
|
||||
return
|
||||
except vol.Invalid:
|
||||
_LOGGER.error(
|
||||
"Invalid image URL '%s' received at topic %s",
|
||||
msg.payload,
|
||||
msg.topic,
|
||||
)
|
||||
self._attr_image_last_updated = dt_util.utcnow()
|
||||
self._cached_image = None
|
||||
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
|
||||
|
||||
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
|
||||
add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
|
||||
add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass, self._sub_state, topics
|
||||
@@ -216,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_image(self) -> bytes | None:
|
||||
"""Return bytes of image."""
|
||||
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import contextlib
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
@@ -31,12 +32,7 @@ from .const import (
|
||||
DEFAULT_OPTIMISTIC,
|
||||
DEFAULT_RETAIN,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -150,57 +146,59 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
|
||||
config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self
|
||||
).async_render
|
||||
|
||||
@callback
|
||||
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if not payload:
|
||||
_LOGGER.debug(
|
||||
"Invalid empty activity payload from topic %s, for entity %s",
|
||||
msg.topic,
|
||||
self.entity_id,
|
||||
)
|
||||
return
|
||||
if payload.lower() == "none":
|
||||
self._attr_activity = None
|
||||
return
|
||||
|
||||
try:
|
||||
self._attr_activity = LawnMowerActivity(payload)
|
||||
except ValueError:
|
||||
_LOGGER.error(
|
||||
"Invalid activity for %s: '%s' (valid activities: %s)",
|
||||
self.entity_id,
|
||||
payload,
|
||||
[option.value for option in LawnMowerActivity],
|
||||
)
|
||||
return
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_activity"})
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if not payload:
|
||||
_LOGGER.debug(
|
||||
"Invalid empty activity payload from topic %s, for entity %s",
|
||||
msg.topic,
|
||||
self.entity_id,
|
||||
)
|
||||
return
|
||||
if payload.lower() == "none":
|
||||
self._attr_activity = None
|
||||
return
|
||||
|
||||
try:
|
||||
self._attr_activity = LawnMowerActivity(payload)
|
||||
except ValueError:
|
||||
_LOGGER.error(
|
||||
"Invalid activity for %s: '%s' (valid activities: %s)",
|
||||
self.entity_id,
|
||||
payload,
|
||||
[option.value for option in LawnMowerActivity],
|
||||
)
|
||||
return
|
||||
|
||||
if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None:
|
||||
# Force into optimistic mode.
|
||||
self._attr_assumed_state = True
|
||||
else:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
CONF_ACTIVITY_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
|
||||
"msg_callback": message_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
CONF_ACTIVITY_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._message_received,
|
||||
{"_attr_activity"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
if self._attr_assumed_state and (
|
||||
last_state := await self.async_get_last_state()
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -53,8 +54,7 @@ from ..const import (
|
||||
CONF_STATE_VALUE_TEMPLATE,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from ..debug_info import log_messages
|
||||
from ..mixins import MqttEntity, write_state_on_attr_change
|
||||
from ..mixins import MqttEntity
|
||||
from ..models import (
|
||||
MessageCallbackType,
|
||||
MqttCommandTemplate,
|
||||
@@ -378,263 +378,248 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
||||
attr: bool = getattr(self, f"_optimistic_{attribute}")
|
||||
return attr
|
||||
|
||||
@callback
|
||||
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.NONE
|
||||
)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
if payload == self._payload["on"]:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._payload["off"]:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
@callback
|
||||
def _brightness_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for the brightness."""
|
||||
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
device_value = float(payload)
|
||||
if device_value == 0:
|
||||
_LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
|
||||
self._attr_brightness = min(round(percent_bright * 255), 255)
|
||||
|
||||
@callback
|
||||
def _rgbx_received(
|
||||
self,
|
||||
msg: ReceiveMessage,
|
||||
template: str,
|
||||
color_mode: ColorMode,
|
||||
convert_color: Callable[..., tuple[int, ...]],
|
||||
) -> tuple[int, ...] | None:
|
||||
"""Process MQTT messages for RGBW and RGBWW."""
|
||||
payload = self._value_templates[template](msg.payload, PayloadSentinel.DEFAULT)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty %s message from '%s'", color_mode, msg.topic)
|
||||
return None
|
||||
color = tuple(int(val) for val in str(payload).split(","))
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = color_mode
|
||||
if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None:
|
||||
rgb = convert_color(*color)
|
||||
brightness = max(rgb)
|
||||
if brightness == 0:
|
||||
_LOGGER.debug(
|
||||
"Ignoring %s message with zero rgb brightness from '%s'",
|
||||
color_mode,
|
||||
msg.topic,
|
||||
)
|
||||
return None
|
||||
self._attr_brightness = brightness
|
||||
# Normalize the color to 100% brightness
|
||||
color = tuple(
|
||||
min(round(channel / brightness * 255), 255) for channel in color
|
||||
)
|
||||
return color
|
||||
|
||||
@callback
|
||||
def _rgb_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for RGB."""
|
||||
rgb = self._rgbx_received(
|
||||
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
|
||||
)
|
||||
if rgb is None:
|
||||
return
|
||||
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
|
||||
|
||||
@callback
|
||||
def _rgbw_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for RGBW."""
|
||||
rgbw = self._rgbx_received(
|
||||
msg,
|
||||
CONF_RGBW_VALUE_TEMPLATE,
|
||||
ColorMode.RGBW,
|
||||
color_util.color_rgbw_to_rgb,
|
||||
)
|
||||
if rgbw is None:
|
||||
return
|
||||
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
|
||||
|
||||
@callback
|
||||
def _rgbww_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for RGBWW."""
|
||||
|
||||
@callback
|
||||
def _converter(
|
||||
r: int, g: int, b: int, cw: int, ww: int
|
||||
) -> tuple[int, int, int]:
|
||||
min_kelvin = color_util.color_temperature_mired_to_kelvin(self.max_mireds)
|
||||
max_kelvin = color_util.color_temperature_mired_to_kelvin(self.min_mireds)
|
||||
return color_util.color_rgbww_to_rgb(
|
||||
r, g, b, cw, ww, min_kelvin, max_kelvin
|
||||
)
|
||||
|
||||
rgbww = self._rgbx_received(
|
||||
msg,
|
||||
CONF_RGBWW_VALUE_TEMPLATE,
|
||||
ColorMode.RGBWW,
|
||||
_converter,
|
||||
)
|
||||
if rgbww is None:
|
||||
return
|
||||
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
|
||||
|
||||
@callback
|
||||
def _color_mode_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for color mode."""
|
||||
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
self._attr_color_mode = ColorMode(str(payload))
|
||||
|
||||
@callback
|
||||
def _color_temp_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for color temperature."""
|
||||
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = ColorMode.COLOR_TEMP
|
||||
self._attr_color_temp = int(payload)
|
||||
|
||||
@callback
|
||||
def _effect_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for effect."""
|
||||
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
self._attr_effect = str(payload)
|
||||
|
||||
@callback
|
||||
def _hs_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for hs color."""
|
||||
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
hs_color = tuple(float(val) for val in str(payload).split(",", 2))
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = ColorMode.HS
|
||||
self._attr_hs_color = cast(tuple[float, float], hs_color)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
|
||||
|
||||
@callback
|
||||
def _xy_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for xy color."""
|
||||
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
xy_color = tuple(float(val) for val in str(payload).split(",", 2))
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = ColorMode.XY
|
||||
self._attr_xy_color = cast(tuple[float, float], xy_color)
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None: # noqa: C901
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def add_topic(topic: str, msg_callback: MessageCallbackType) -> None:
|
||||
def add_topic(
|
||||
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
|
||||
) -> None:
|
||||
"""Add a topic."""
|
||||
if self._topic[topic] is not None:
|
||||
topics[topic] = {
|
||||
"topic": self._topic[topic],
|
||||
"msg_callback": msg_callback,
|
||||
"msg_callback": partial(
|
||||
self._message_callback, msg_callback, tracked_attributes
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
||||
def state_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.NONE
|
||||
)
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
if payload == self._payload["on"]:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._payload["off"]:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
if self._topic[CONF_STATE_TOPIC] is not None:
|
||||
topics[CONF_STATE_TOPIC] = {
|
||||
"topic": self._topic[CONF_STATE_TOPIC],
|
||||
"msg_callback": state_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_brightness"})
|
||||
def brightness_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for the brightness."""
|
||||
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
device_value = float(payload)
|
||||
if device_value == 0:
|
||||
_LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
|
||||
self._attr_brightness = min(round(percent_bright * 255), 255)
|
||||
|
||||
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
|
||||
|
||||
@callback
|
||||
def _rgbx_received(
|
||||
msg: ReceiveMessage,
|
||||
template: str,
|
||||
color_mode: ColorMode,
|
||||
convert_color: Callable[..., tuple[int, ...]],
|
||||
) -> tuple[int, ...] | None:
|
||||
"""Handle new MQTT messages for RGBW and RGBWW."""
|
||||
payload = self._value_templates[template](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty %s message from '%s'", color_mode, msg.topic
|
||||
)
|
||||
return None
|
||||
color = tuple(int(val) for val in str(payload).split(","))
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = color_mode
|
||||
if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None:
|
||||
rgb = convert_color(*color)
|
||||
brightness = max(rgb)
|
||||
if brightness == 0:
|
||||
_LOGGER.debug(
|
||||
"Ignoring %s message with zero rgb brightness from '%s'",
|
||||
color_mode,
|
||||
msg.topic,
|
||||
)
|
||||
return None
|
||||
self._attr_brightness = brightness
|
||||
# Normalize the color to 100% brightness
|
||||
color = tuple(
|
||||
min(round(channel / brightness * 255), 255) for channel in color
|
||||
)
|
||||
return color
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}
|
||||
add_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
|
||||
add_topic(
|
||||
CONF_BRIGHTNESS_STATE_TOPIC, self._brightness_received, {"_attr_brightness"}
|
||||
)
|
||||
def rgb_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for RGB."""
|
||||
rgb = _rgbx_received(
|
||||
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
|
||||
)
|
||||
if rgb is None:
|
||||
return
|
||||
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
|
||||
|
||||
add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}
|
||||
add_topic(
|
||||
CONF_RGB_STATE_TOPIC,
|
||||
self._rgb_received,
|
||||
{"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"},
|
||||
)
|
||||
def rgbw_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for RGBW."""
|
||||
rgbw = _rgbx_received(
|
||||
msg,
|
||||
CONF_RGBW_VALUE_TEMPLATE,
|
||||
ColorMode.RGBW,
|
||||
color_util.color_rgbw_to_rgb,
|
||||
)
|
||||
if rgbw is None:
|
||||
return
|
||||
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
|
||||
|
||||
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"}
|
||||
add_topic(
|
||||
CONF_RGBW_STATE_TOPIC,
|
||||
self._rgbw_received,
|
||||
{"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"},
|
||||
)
|
||||
add_topic(
|
||||
CONF_RGBWW_STATE_TOPIC,
|
||||
self._rgbww_received,
|
||||
{"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"},
|
||||
)
|
||||
add_topic(
|
||||
CONF_COLOR_MODE_STATE_TOPIC, self._color_mode_received, {"_attr_color_mode"}
|
||||
)
|
||||
add_topic(
|
||||
CONF_COLOR_TEMP_STATE_TOPIC,
|
||||
self._color_temp_received,
|
||||
{"_attr_color_mode", "_attr_color_temp"},
|
||||
)
|
||||
add_topic(CONF_EFFECT_STATE_TOPIC, self._effect_received, {"_attr_effect"})
|
||||
add_topic(
|
||||
CONF_HS_STATE_TOPIC,
|
||||
self._hs_received,
|
||||
{"_attr_color_mode", "_attr_hs_color"},
|
||||
)
|
||||
add_topic(
|
||||
CONF_XY_STATE_TOPIC,
|
||||
self._xy_received,
|
||||
{"_attr_color_mode", "_attr_xy_color"},
|
||||
)
|
||||
def rgbww_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for RGBWW."""
|
||||
|
||||
@callback
|
||||
def _converter(
|
||||
r: int, g: int, b: int, cw: int, ww: int
|
||||
) -> tuple[int, int, int]:
|
||||
min_kelvin = color_util.color_temperature_mired_to_kelvin(
|
||||
self.max_mireds
|
||||
)
|
||||
max_kelvin = color_util.color_temperature_mired_to_kelvin(
|
||||
self.min_mireds
|
||||
)
|
||||
return color_util.color_rgbww_to_rgb(
|
||||
r, g, b, cw, ww, min_kelvin, max_kelvin
|
||||
)
|
||||
|
||||
rgbww = _rgbx_received(
|
||||
msg,
|
||||
CONF_RGBWW_VALUE_TEMPLATE,
|
||||
ColorMode.RGBWW,
|
||||
_converter,
|
||||
)
|
||||
if rgbww is None:
|
||||
return
|
||||
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
|
||||
|
||||
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_color_mode"})
|
||||
def color_mode_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for color mode."""
|
||||
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
self._attr_color_mode = ColorMode(str(payload))
|
||||
|
||||
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"})
|
||||
def color_temp_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for color temperature."""
|
||||
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = ColorMode.COLOR_TEMP
|
||||
self._attr_color_temp = int(payload)
|
||||
|
||||
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_effect"})
|
||||
def effect_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for effect."""
|
||||
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
self._attr_effect = str(payload)
|
||||
|
||||
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"})
|
||||
def hs_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for hs color."""
|
||||
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
hs_color = tuple(float(val) for val in str(payload).split(",", 2))
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = ColorMode.HS
|
||||
self._attr_hs_color = cast(tuple[float, float], hs_color)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
|
||||
|
||||
add_topic(CONF_HS_STATE_TOPIC, hs_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"})
|
||||
def xy_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages for xy color."""
|
||||
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
|
||||
msg.payload, PayloadSentinel.DEFAULT
|
||||
)
|
||||
if payload is PayloadSentinel.DEFAULT or not payload:
|
||||
_LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
xy_color = tuple(float(val) for val in str(payload).split(",", 2))
|
||||
if self._optimistic_color_mode:
|
||||
self._attr_color_mode = ColorMode.XY
|
||||
self._attr_xy_color = cast(tuple[float, float], xy_color)
|
||||
|
||||
add_topic(CONF_XY_STATE_TOPIC, xy_received)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass, self._sub_state, topics
|
||||
@@ -642,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
last_state = await self.async_get_last_state()
|
||||
|
||||
def restore_state(
|
||||
|
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
@@ -66,8 +67,7 @@ from ..const import (
|
||||
CONF_STATE_TOPIC,
|
||||
DOMAIN as MQTT_DOMAIN,
|
||||
)
|
||||
from ..debug_info import log_messages
|
||||
from ..mixins import MqttEntity, write_state_on_attr_change
|
||||
from ..mixins import MqttEntity
|
||||
from ..models import ReceiveMessage
|
||||
from ..schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
from ..util import valid_subscribe_topic
|
||||
@@ -414,118 +414,121 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
@callback
|
||||
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
values = json_loads_object(msg.payload)
|
||||
|
||||
if values["state"] == "ON":
|
||||
self._attr_is_on = True
|
||||
elif values["state"] == "OFF":
|
||||
self._attr_is_on = False
|
||||
elif values["state"] is None:
|
||||
self._attr_is_on = None
|
||||
|
||||
if (
|
||||
self._deprecated_color_handling
|
||||
and color_supported(self.supported_color_modes)
|
||||
and "color" in values
|
||||
):
|
||||
# Deprecated color handling
|
||||
if values["color"] is None:
|
||||
self._attr_hs_color = None
|
||||
else:
|
||||
self._update_color(values)
|
||||
|
||||
if not self._deprecated_color_handling and "color_mode" in values:
|
||||
self._update_color(values)
|
||||
|
||||
if brightness_supported(self.supported_color_modes):
|
||||
try:
|
||||
if brightness := values["brightness"]:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(brightness, float)
|
||||
self._attr_brightness = color_util.value_to_brightness(
|
||||
(1, self._config[CONF_BRIGHTNESS_SCALE]), brightness
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Ignoring zero brightness value for entity %s",
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
pass
|
||||
except (TypeError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"Invalid brightness value '%s' received for entity %s",
|
||||
values["brightness"],
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
if (
|
||||
self._deprecated_color_handling
|
||||
and self.supported_color_modes
|
||||
and ColorMode.COLOR_TEMP in self.supported_color_modes
|
||||
):
|
||||
# Deprecated color handling
|
||||
try:
|
||||
if values["color_temp"] is None:
|
||||
self._attr_color_temp = None
|
||||
else:
|
||||
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
|
||||
except KeyError:
|
||||
pass
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Invalid color temp value '%s' received for entity %s",
|
||||
values["color_temp"],
|
||||
self.entity_id,
|
||||
)
|
||||
# Allow to switch back to color_temp
|
||||
if "color" not in values:
|
||||
self._attr_hs_color = None
|
||||
|
||||
if self.supported_features and LightEntityFeature.EFFECT:
|
||||
with suppress(KeyError):
|
||||
self._attr_effect = cast(str, values["effect"])
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self,
|
||||
#
|
||||
if self._topic[CONF_STATE_TOPIC] is None:
|
||||
return
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"_attr_brightness",
|
||||
"_attr_color_temp",
|
||||
"_attr_effect",
|
||||
"_attr_hs_color",
|
||||
"_attr_is_on",
|
||||
"_attr_rgb_color",
|
||||
"_attr_rgbw_color",
|
||||
"_attr_rgbww_color",
|
||||
"_attr_xy_color",
|
||||
"color_mode",
|
||||
CONF_STATE_TOPIC: {
|
||||
"topic": self._topic[CONF_STATE_TOPIC],
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._state_received,
|
||||
{
|
||||
"_attr_brightness",
|
||||
"_attr_color_temp",
|
||||
"_attr_effect",
|
||||
"_attr_hs_color",
|
||||
"_attr_is_on",
|
||||
"_attr_rgb_color",
|
||||
"_attr_rgbw_color",
|
||||
"_attr_rgbww_color",
|
||||
"_attr_xy_color",
|
||||
"color_mode",
|
||||
},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
def state_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
values = json_loads_object(msg.payload)
|
||||
|
||||
if values["state"] == "ON":
|
||||
self._attr_is_on = True
|
||||
elif values["state"] == "OFF":
|
||||
self._attr_is_on = False
|
||||
elif values["state"] is None:
|
||||
self._attr_is_on = None
|
||||
|
||||
if (
|
||||
self._deprecated_color_handling
|
||||
and color_supported(self.supported_color_modes)
|
||||
and "color" in values
|
||||
):
|
||||
# Deprecated color handling
|
||||
if values["color"] is None:
|
||||
self._attr_hs_color = None
|
||||
else:
|
||||
self._update_color(values)
|
||||
|
||||
if not self._deprecated_color_handling and "color_mode" in values:
|
||||
self._update_color(values)
|
||||
|
||||
if brightness_supported(self.supported_color_modes):
|
||||
try:
|
||||
if brightness := values["brightness"]:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(brightness, float)
|
||||
self._attr_brightness = color_util.value_to_brightness(
|
||||
(1, self._config[CONF_BRIGHTNESS_SCALE]), brightness
|
||||
)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Ignoring zero brightness value for entity %s",
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
pass
|
||||
except (TypeError, ValueError):
|
||||
_LOGGER.warning(
|
||||
"Invalid brightness value '%s' received for entity %s",
|
||||
values["brightness"],
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
if (
|
||||
self._deprecated_color_handling
|
||||
and self.supported_color_modes
|
||||
and ColorMode.COLOR_TEMP in self.supported_color_modes
|
||||
):
|
||||
# Deprecated color handling
|
||||
try:
|
||||
if values["color_temp"] is None:
|
||||
self._attr_color_temp = None
|
||||
else:
|
||||
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
|
||||
except KeyError:
|
||||
pass
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Invalid color temp value '%s' received for entity %s",
|
||||
values["color_temp"],
|
||||
self.entity_id,
|
||||
)
|
||||
# Allow to switch back to color_temp
|
||||
if "color" not in values:
|
||||
self._attr_hs_color = None
|
||||
|
||||
if self.supported_features and LightEntityFeature.EFFECT:
|
||||
with suppress(KeyError):
|
||||
self._attr_effect = cast(str, values["effect"])
|
||||
|
||||
if self._topic[CONF_STATE_TOPIC] is not None:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._topic[CONF_STATE_TOPIC],
|
||||
"msg_callback": state_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
last_state = await self.async_get_last_state()
|
||||
if self._optimistic and last_state:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -44,8 +45,7 @@ from ..const import (
|
||||
CONF_STATE_TOPIC,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from ..debug_info import log_messages
|
||||
from ..mixins import MqttEntity, write_state_on_attr_change
|
||||
from ..mixins import MqttEntity
|
||||
from ..models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -188,107 +188,107 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
|
||||
# Support for ct + hs, prioritize hs
|
||||
self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP
|
||||
|
||||
@callback
|
||||
def _state_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
|
||||
if state == STATE_ON:
|
||||
self._attr_is_on = True
|
||||
elif state == STATE_OFF:
|
||||
self._attr_is_on = False
|
||||
elif state == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
else:
|
||||
_LOGGER.warning("Invalid state value received")
|
||||
|
||||
if CONF_BRIGHTNESS_TEMPLATE in self._config:
|
||||
try:
|
||||
if brightness := int(
|
||||
self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload)
|
||||
):
|
||||
self._attr_brightness = brightness
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Ignoring zero brightness value for entity %s",
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
_LOGGER.warning("Invalid brightness value received from %s", msg.topic)
|
||||
|
||||
if CONF_COLOR_TEMP_TEMPLATE in self._config:
|
||||
try:
|
||||
color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE](
|
||||
msg.payload
|
||||
)
|
||||
self._attr_color_temp = (
|
||||
int(color_temp) if color_temp != "None" else None
|
||||
)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Invalid color temperature value received")
|
||||
|
||||
if (
|
||||
CONF_RED_TEMPLATE in self._config
|
||||
and CONF_GREEN_TEMPLATE in self._config
|
||||
and CONF_BLUE_TEMPLATE in self._config
|
||||
):
|
||||
try:
|
||||
red = self._value_templates[CONF_RED_TEMPLATE](msg.payload)
|
||||
green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload)
|
||||
blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload)
|
||||
if red == "None" and green == "None" and blue == "None":
|
||||
self._attr_hs_color = None
|
||||
else:
|
||||
self._attr_hs_color = color_util.color_RGB_to_hs(
|
||||
int(red), int(green), int(blue)
|
||||
)
|
||||
self._update_color_mode()
|
||||
except ValueError:
|
||||
_LOGGER.warning("Invalid color value received")
|
||||
|
||||
if CONF_EFFECT_TEMPLATE in self._config:
|
||||
effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload))
|
||||
if (
|
||||
effect_list := self._config[CONF_EFFECT_LIST]
|
||||
) and effect in effect_list:
|
||||
self._attr_effect = effect
|
||||
else:
|
||||
_LOGGER.warning("Unsupported effect value received")
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self,
|
||||
if self._topics[CONF_STATE_TOPIC] is None:
|
||||
return
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"_attr_brightness",
|
||||
"_attr_color_mode",
|
||||
"_attr_color_temp",
|
||||
"_attr_effect",
|
||||
"_attr_hs_color",
|
||||
"_attr_is_on",
|
||||
"state_topic": {
|
||||
"topic": self._topics[CONF_STATE_TOPIC],
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._state_received,
|
||||
{
|
||||
"_attr_brightness",
|
||||
"_attr_color_mode",
|
||||
"_attr_color_temp",
|
||||
"_attr_effect",
|
||||
"_attr_hs_color",
|
||||
"_attr_is_on",
|
||||
},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
def state_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
|
||||
if state == STATE_ON:
|
||||
self._attr_is_on = True
|
||||
elif state == STATE_OFF:
|
||||
self._attr_is_on = False
|
||||
elif state == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
else:
|
||||
_LOGGER.warning("Invalid state value received")
|
||||
|
||||
if CONF_BRIGHTNESS_TEMPLATE in self._config:
|
||||
try:
|
||||
if brightness := int(
|
||||
self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload)
|
||||
):
|
||||
self._attr_brightness = brightness
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Ignoring zero brightness value for entity %s",
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
_LOGGER.warning(
|
||||
"Invalid brightness value received from %s", msg.topic
|
||||
)
|
||||
|
||||
if CONF_COLOR_TEMP_TEMPLATE in self._config:
|
||||
try:
|
||||
color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE](
|
||||
msg.payload
|
||||
)
|
||||
self._attr_color_temp = (
|
||||
int(color_temp) if color_temp != "None" else None
|
||||
)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Invalid color temperature value received")
|
||||
|
||||
if (
|
||||
CONF_RED_TEMPLATE in self._config
|
||||
and CONF_GREEN_TEMPLATE in self._config
|
||||
and CONF_BLUE_TEMPLATE in self._config
|
||||
):
|
||||
try:
|
||||
red = self._value_templates[CONF_RED_TEMPLATE](msg.payload)
|
||||
green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload)
|
||||
blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload)
|
||||
if red == "None" and green == "None" and blue == "None":
|
||||
self._attr_hs_color = None
|
||||
else:
|
||||
self._attr_hs_color = color_util.color_RGB_to_hs(
|
||||
int(red), int(green), int(blue)
|
||||
)
|
||||
self._update_color_mode()
|
||||
except ValueError:
|
||||
_LOGGER.warning("Invalid color value received")
|
||||
|
||||
if CONF_EFFECT_TEMPLATE in self._config:
|
||||
effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload))
|
||||
if (
|
||||
effect_list := self._config[CONF_EFFECT_LIST]
|
||||
) and effect in effect_list:
|
||||
self._attr_effect = effect
|
||||
else:
|
||||
_LOGGER.warning("Unsupported effect value received")
|
||||
|
||||
if self._topics[CONF_STATE_TOPIC] is not None:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._topics[CONF_STATE_TOPIC],
|
||||
"msg_callback": state_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
last_state = await self.async_get_last_state()
|
||||
if self._optimistic and last_state:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
@@ -36,12 +37,7 @@ from .const import (
|
||||
CONF_STATE_OPENING,
|
||||
CONF_STATE_TOPIC,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -186,57 +182,58 @@ class MqttLock(MqttEntity, LockEntity):
|
||||
|
||||
self._valid_states = [config[state] for state in STATE_CONFIG_KEYS]
|
||||
|
||||
@callback
|
||||
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new lock state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if not payload.strip(): # No output from template, ignore
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if payload == self._config[CONF_PAYLOAD_RESET]:
|
||||
# Reset the state to `unknown`
|
||||
self._attr_is_locked = None
|
||||
elif payload in self._valid_states:
|
||||
self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED]
|
||||
self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING]
|
||||
self._attr_is_open = payload == self._config[CONF_STATE_OPEN]
|
||||
self._attr_is_opening = payload == self._config[CONF_STATE_OPENING]
|
||||
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
|
||||
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
topics: dict[str, dict[str, Any]] = {}
|
||||
topics: dict[str, dict[str, Any]]
|
||||
qos: int = self._config[CONF_QOS]
|
||||
encoding: str | None = self._config[CONF_ENCODING] or None
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self,
|
||||
{
|
||||
"_attr_is_jammed",
|
||||
"_attr_is_locked",
|
||||
"_attr_is_locking",
|
||||
"_attr_is_open",
|
||||
"_attr_is_opening",
|
||||
"_attr_is_unlocking",
|
||||
},
|
||||
)
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new lock state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if not payload.strip(): # No output from template, ignore
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if payload == self._config[CONF_PAYLOAD_RESET]:
|
||||
# Reset the state to `unknown`
|
||||
self._attr_is_locked = None
|
||||
elif payload in self._valid_states:
|
||||
self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED]
|
||||
self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING]
|
||||
self._attr_is_open = payload == self._config[CONF_STATE_OPEN]
|
||||
self._attr_is_opening = payload == self._config[CONF_STATE_OPENING]
|
||||
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
|
||||
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
|
||||
|
||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||
# Force into optimistic mode.
|
||||
self._optimistic = True
|
||||
else:
|
||||
topics[CONF_STATE_TOPIC] = {
|
||||
return
|
||||
topics = {
|
||||
CONF_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": message_received,
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._message_received,
|
||||
{
|
||||
"_attr_is_jammed",
|
||||
"_attr_is_locked",
|
||||
"_attr_is_locking",
|
||||
"_attr_is_open",
|
||||
"_attr_is_opening",
|
||||
"_attr_is_unlocking",
|
||||
},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
CONF_QOS: qos,
|
||||
CONF_ENCODING: encoding,
|
||||
}
|
||||
}
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
@@ -246,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_lock(self, **kwargs: Any) -> None:
|
||||
"""Lock the device.
|
||||
|
@@ -114,7 +114,7 @@ from .models import (
|
||||
from .subscription import (
|
||||
EntitySubscription,
|
||||
async_prepare_subscribe_topics,
|
||||
async_subscribe_topics,
|
||||
async_subscribe_topics_internal,
|
||||
async_unsubscribe_topics,
|
||||
)
|
||||
from .util import mqtt_config_entry_enabled
|
||||
@@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
|
||||
"""Subscribe MQTT events."""
|
||||
await super().async_added_to_hass()
|
||||
self._attributes_prepare_subscribe_topics()
|
||||
await self._attributes_subscribe_topics()
|
||||
self._attributes_subscribe_topics()
|
||||
|
||||
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||
"""Handle updated discovery message."""
|
||||
@@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
|
||||
|
||||
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||
"""Handle updated discovery message."""
|
||||
await self._attributes_subscribe_topics()
|
||||
self._attributes_subscribe_topics()
|
||||
|
||||
def _attributes_prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
@@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
|
||||
},
|
||||
)
|
||||
|
||||
async def _attributes_subscribe_topics(self) -> None:
|
||||
@callback
|
||||
def _attributes_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await async_subscribe_topics(self.hass, self._attributes_sub_state)
|
||||
async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Unsubscribe when removed."""
|
||||
@@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
|
||||
"""Subscribe MQTT events."""
|
||||
await super().async_added_to_hass()
|
||||
self._availability_prepare_subscribe_topics()
|
||||
await self._availability_subscribe_topics()
|
||||
self._availability_subscribe_topics()
|
||||
self.async_on_remove(
|
||||
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
|
||||
)
|
||||
@@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
|
||||
|
||||
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
|
||||
"""Handle updated discovery message."""
|
||||
await self._availability_subscribe_topics()
|
||||
self._availability_subscribe_topics()
|
||||
|
||||
def _availability_setup_from_config(self, config: ConfigType) -> None:
|
||||
"""(Re)Setup."""
|
||||
@@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
|
||||
self._available[topic] = False
|
||||
self._available_latest = False
|
||||
|
||||
async def _availability_subscribe_topics(self) -> None:
|
||||
@callback
|
||||
def _availability_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await async_subscribe_topics(self.hass, self._availability_sub_state)
|
||||
async_subscribe_topics_internal(self.hass, self._availability_sub_state)
|
||||
|
||||
@callback
|
||||
def async_mqtt_connect(self) -> None:
|
||||
@@ -1254,13 +1256,15 @@ class MqttEntity(
|
||||
def _message_callback(
|
||||
self,
|
||||
msg_callback: MessageCallbackType,
|
||||
attributes: set[str],
|
||||
attributes: set[str] | None,
|
||||
msg: ReceiveMessage,
|
||||
) -> None:
|
||||
"""Process the message callback."""
|
||||
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
|
||||
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes
|
||||
)
|
||||
if attributes is not None:
|
||||
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
|
||||
(attribute, getattr(self, attribute, UNDEFINED))
|
||||
for attribute in attributes
|
||||
)
|
||||
mqtt_data = self.hass.data[DATA_MQTT]
|
||||
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
|
||||
msg.subscribed_topic
|
||||
@@ -1274,7 +1278,7 @@ class MqttEntity(
|
||||
_LOGGER.warning(exc)
|
||||
return
|
||||
|
||||
if self._attrs_have_changed(attrs_snapshot):
|
||||
if attributes is not None and self._attrs_have_changed(attrs_snapshot):
|
||||
mqtt_data.state_write_requests.write_state_request(self)
|
||||
|
||||
|
||||
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from ast import literal_eval
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
@@ -70,7 +70,6 @@ class ReceiveMessage:
|
||||
timestamp: float
|
||||
|
||||
|
||||
type AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
|
||||
type MessageCallbackType = Callable[[ReceiveMessage], None]
|
||||
|
||||
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
@@ -41,12 +42,7 @@ from .const import (
|
||||
CONF_RETAIN,
|
||||
CONF_STATE_TOPIC,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -165,64 +161,66 @@ class MqttNumber(MqttEntity, RestoreNumber):
|
||||
self._attr_native_step = config[CONF_STEP]
|
||||
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
|
||||
|
||||
@callback
|
||||
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
num_value: int | float | None
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if not payload.strip():
|
||||
_LOGGER.debug("Ignoring empty state update from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
if payload == self._config[CONF_PAYLOAD_RESET]:
|
||||
num_value = None
|
||||
elif payload.isnumeric():
|
||||
num_value = int(payload)
|
||||
else:
|
||||
num_value = float(payload)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Payload '%s' is not a Number", msg.payload)
|
||||
return
|
||||
|
||||
if num_value is not None and (
|
||||
num_value < self.min_value or num_value > self.max_value
|
||||
):
|
||||
_LOGGER.error(
|
||||
"Invalid value for %s: %s (range %s - %s)",
|
||||
self.entity_id,
|
||||
num_value,
|
||||
self.min_value,
|
||||
self.max_value,
|
||||
)
|
||||
return
|
||||
|
||||
self._attr_native_value = num_value
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_native_value"})
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
num_value: int | float | None
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if not payload.strip():
|
||||
_LOGGER.debug("Ignoring empty state update from '%s'", msg.topic)
|
||||
return
|
||||
try:
|
||||
if payload == self._config[CONF_PAYLOAD_RESET]:
|
||||
num_value = None
|
||||
elif payload.isnumeric():
|
||||
num_value = int(payload)
|
||||
else:
|
||||
num_value = float(payload)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Payload '%s' is not a Number", msg.payload)
|
||||
return
|
||||
|
||||
if num_value is not None and (
|
||||
num_value < self.min_value or num_value > self.max_value
|
||||
):
|
||||
_LOGGER.error(
|
||||
"Invalid value for %s: %s (range %s - %s)",
|
||||
self.entity_id,
|
||||
num_value,
|
||||
self.min_value,
|
||||
self.max_value,
|
||||
)
|
||||
return
|
||||
|
||||
self._attr_native_value = num_value
|
||||
|
||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||
# Force into optimistic mode.
|
||||
self._attr_assumed_state = True
|
||||
else:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": message_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._message_received,
|
||||
{"_attr_native_value"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
if self._attr_assumed_state and (
|
||||
last_number_data := await self.async_get_last_number_data()
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
@@ -27,12 +28,7 @@ from .const import (
|
||||
CONF_RETAIN,
|
||||
CONF_STATE_TOPIC,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -113,56 +109,58 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
|
||||
config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
@callback
|
||||
def _message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if not payload.strip(): # No output from template, ignore
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if payload.lower() == "none":
|
||||
self._attr_current_option = None
|
||||
return
|
||||
|
||||
if payload not in self.options:
|
||||
_LOGGER.error(
|
||||
"Invalid option for %s: '%s' (valid options: %s)",
|
||||
self.entity_id,
|
||||
payload,
|
||||
self.options,
|
||||
)
|
||||
return
|
||||
self._attr_current_option = payload
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_current_option"})
|
||||
def message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT messages."""
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if not payload.strip(): # No output from template, ignore
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if payload.lower() == "none":
|
||||
self._attr_current_option = None
|
||||
return
|
||||
|
||||
if payload not in self.options:
|
||||
_LOGGER.error(
|
||||
"Invalid option for %s: '%s' (valid options: %s)",
|
||||
self.entity_id,
|
||||
payload,
|
||||
self.options,
|
||||
)
|
||||
return
|
||||
self._attr_current_option = payload
|
||||
|
||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||
# Force into optimistic mode.
|
||||
self._attr_assumed_state = True
|
||||
else:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": message_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
"state_topic": {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._message_received,
|
||||
{"_attr_current_option"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
if self._attr_assumed_state and (
|
||||
last_state := await self.async_get_last_state()
|
||||
|
@@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
@callback
|
||||
def _value_is_expired(self, *_: datetime) -> None:
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -48,12 +49,7 @@ from .const import (
|
||||
PAYLOAD_EMPTY_JSON,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MqttCommandTemplate,
|
||||
MqttValueTemplate,
|
||||
@@ -205,92 +201,94 @@ class MqttSiren(MqttEntity, SirenEntity):
|
||||
entity=self,
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_is_on", "_extra_attributes"})
|
||||
def state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
||||
@callback
|
||||
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
json_payload: dict[str, Any] = {}
|
||||
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
|
||||
json_payload = {STATE: payload}
|
||||
else:
|
||||
try:
|
||||
json_payload = json_loads_object(payload)
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
(
|
||||
"JSON payload detected after processing payload '%s' on"
|
||||
" topic %s"
|
||||
),
|
||||
json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"No valid (JSON) payload detected after processing payload"
|
||||
" '%s' on topic %s"
|
||||
),
|
||||
json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
json_payload: dict[str, Any] = {}
|
||||
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
|
||||
json_payload = {STATE: payload}
|
||||
else:
|
||||
try:
|
||||
json_payload = json_loads_object(payload)
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"JSON payload detected after processing payload '%s' on"
|
||||
" topic %s"
|
||||
),
|
||||
json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.warning(
|
||||
(
|
||||
"No valid (JSON) payload detected after processing payload"
|
||||
" '%s' on topic %s"
|
||||
),
|
||||
json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
if STATE in json_payload:
|
||||
if json_payload[STATE] == self._state_on:
|
||||
self._attr_is_on = True
|
||||
if json_payload[STATE] == self._state_off:
|
||||
self._attr_is_on = False
|
||||
if json_payload[STATE] == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
del json_payload[STATE]
|
||||
if STATE in json_payload:
|
||||
if json_payload[STATE] == self._state_on:
|
||||
self._attr_is_on = True
|
||||
if json_payload[STATE] == self._state_off:
|
||||
self._attr_is_on = False
|
||||
if json_payload[STATE] == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
del json_payload[STATE]
|
||||
|
||||
if json_payload:
|
||||
# process attributes
|
||||
try:
|
||||
params: SirenTurnOnServiceParameters
|
||||
params = vol.All(TURN_ON_SCHEMA)(json_payload)
|
||||
except vol.MultipleInvalid as invalid_siren_parameters:
|
||||
_LOGGER.warning(
|
||||
"Unable to update siren state attributes from payload '%s': %s",
|
||||
json_payload,
|
||||
invalid_siren_parameters,
|
||||
)
|
||||
return
|
||||
# To be able to track changes to self._extra_attributes we assign
|
||||
# a fresh copy to make the original tracked reference immutable.
|
||||
self._extra_attributes = dict(self._extra_attributes)
|
||||
self._update(process_turn_on_params(self, params))
|
||||
if json_payload:
|
||||
# process attributes
|
||||
try:
|
||||
params: SirenTurnOnServiceParameters
|
||||
params = vol.All(TURN_ON_SCHEMA)(json_payload)
|
||||
except vol.MultipleInvalid as invalid_siren_parameters:
|
||||
_LOGGER.warning(
|
||||
"Unable to update siren state attributes from payload '%s': %s",
|
||||
json_payload,
|
||||
invalid_siren_parameters,
|
||||
)
|
||||
return
|
||||
# To be able to track changes to self._extra_attributes we assign
|
||||
# a fresh copy to make the original tracked reference immutable.
|
||||
self._extra_attributes = dict(self._extra_attributes)
|
||||
self._update(process_turn_on_params(self, params))
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||
# Force into optimistic mode.
|
||||
self._optimistic = True
|
||||
else:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
CONF_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": state_message_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
CONF_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._state_message_received,
|
||||
{"_attr_is_on", "_extra_attributes"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict[str, Any] | None:
|
||||
|
@@ -2,14 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .. import mqtt
|
||||
from . import debug_info
|
||||
from .client import async_subscribe_internal
|
||||
from .const import DEFAULT_QOS
|
||||
from .models import MessageCallbackType
|
||||
|
||||
@@ -21,7 +22,7 @@ class EntitySubscription:
|
||||
hass: HomeAssistant
|
||||
topic: str | None
|
||||
message_callback: MessageCallbackType
|
||||
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None
|
||||
should_subscribe: bool | None
|
||||
unsubscribe_callback: Callable[[], None] | None
|
||||
qos: int = 0
|
||||
encoding: str = "utf-8"
|
||||
@@ -53,15 +54,16 @@ class EntitySubscription:
|
||||
self.hass, self.message_callback, self.topic, self.entity_id
|
||||
)
|
||||
|
||||
self.subscribe_task = mqtt.async_subscribe(
|
||||
hass, self.topic, self.message_callback, self.qos, self.encoding
|
||||
)
|
||||
self.should_subscribe = True
|
||||
|
||||
async def subscribe(self) -> None:
|
||||
@callback
|
||||
def subscribe(self) -> None:
|
||||
"""Subscribe to a topic."""
|
||||
if not self.subscribe_task:
|
||||
if not self.should_subscribe or not self.topic:
|
||||
return
|
||||
self.unsubscribe_callback = await self.subscribe_task
|
||||
self.unsubscribe_callback = async_subscribe_internal(
|
||||
self.hass, self.topic, self.message_callback, self.qos, self.encoding
|
||||
)
|
||||
|
||||
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
|
||||
"""Check if we should re-subscribe to the topic using the old state."""
|
||||
@@ -79,6 +81,7 @@ class EntitySubscription:
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_prepare_subscribe_topics(
|
||||
hass: HomeAssistant,
|
||||
new_state: dict[str, EntitySubscription] | None,
|
||||
@@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
|
||||
qos=value.get("qos", DEFAULT_QOS),
|
||||
encoding=value.get("encoding", "utf-8"),
|
||||
hass=hass,
|
||||
subscribe_task=None,
|
||||
should_subscribe=None,
|
||||
entity_id=value.get("entity_id", None),
|
||||
)
|
||||
# Get the current subscription state
|
||||
@@ -135,12 +138,29 @@ async def async_subscribe_topics(
|
||||
sub_state: dict[str, EntitySubscription],
|
||||
) -> None:
|
||||
"""(Re)Subscribe to a set of MQTT topics."""
|
||||
async_subscribe_topics_internal(hass, sub_state)
|
||||
|
||||
|
||||
@callback
|
||||
def async_subscribe_topics_internal(
|
||||
hass: HomeAssistant,
|
||||
sub_state: dict[str, EntitySubscription],
|
||||
) -> None:
|
||||
"""(Re)Subscribe to a set of MQTT topics.
|
||||
|
||||
This function is internal to the MQTT integration and should not be called
|
||||
from outside the integration.
|
||||
"""
|
||||
for sub in sub_state.values():
|
||||
await sub.subscribe()
|
||||
sub.subscribe()
|
||||
|
||||
|
||||
def async_unsubscribe_topics(
|
||||
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
||||
) -> dict[str, EntitySubscription]:
|
||||
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
||||
return async_prepare_subscribe_topics(hass, sub_state, {})
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def async_unsubscribe_topics(
|
||||
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
|
||||
) -> dict[str, EntitySubscription]:
|
||||
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
|
||||
|
||||
|
||||
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
@@ -36,12 +37,7 @@ from .const import (
|
||||
CONF_STATE_TOPIC,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import MqttValueTemplate, ReceiveMessage
|
||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
|
||||
@@ -118,42 +114,44 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
||||
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||
).async_render_with_possible_json_value
|
||||
|
||||
@callback
|
||||
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if payload == self._state_on:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._state_off:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_is_on"})
|
||||
def state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
if payload == self._state_on:
|
||||
self._attr_is_on = True
|
||||
elif payload == self._state_off:
|
||||
self._attr_is_on = False
|
||||
elif payload == PAYLOAD_NONE:
|
||||
self._attr_is_on = None
|
||||
|
||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||
# Force into optimistic mode.
|
||||
self._optimistic = True
|
||||
else:
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
CONF_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": state_message_received,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
return
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass,
|
||||
self._sub_state,
|
||||
{
|
||||
CONF_STATE_TOPIC: {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._state_message_received,
|
||||
{"_attr_is_on"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
if self._optimistic and (last_state := await self.async_get_last_state()):
|
||||
self._attr_is_on = last_state.state == STATE_ON
|
||||
|
@@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
|
||||
}
|
||||
},
|
||||
)
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_tear_down(self) -> None:
|
||||
"""Cleanup tag scanner."""
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
@@ -34,12 +35,7 @@ from .const import (
|
||||
CONF_RETAIN,
|
||||
CONF_STATE_TOPIC,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import (
|
||||
MessageCallbackType,
|
||||
MqttCommandTemplate,
|
||||
@@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity):
|
||||
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
|
||||
self._attr_assumed_state = bool(self._optimistic)
|
||||
|
||||
@callback
|
||||
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving state message via MQTT."""
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
|
||||
return
|
||||
self._attr_native_value = payload
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, Any] = {}
|
||||
|
||||
def add_subscription(
|
||||
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
|
||||
topics: dict[str, Any],
|
||||
topic: str,
|
||||
msg_callback: MessageCallbackType,
|
||||
tracked_attributes: set[str],
|
||||
) -> None:
|
||||
if self._config.get(topic) is not None:
|
||||
topics[topic] = {
|
||||
"topic": self._config[topic],
|
||||
"msg_callback": msg_callback,
|
||||
"msg_callback": partial(
|
||||
self._message_callback, msg_callback, tracked_attributes
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_native_value"})
|
||||
def handle_state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving state message via MQTT."""
|
||||
payload = str(self._value_template(msg.payload))
|
||||
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
|
||||
return
|
||||
self._attr_native_value = payload
|
||||
|
||||
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
|
||||
add_subscription(
|
||||
topics,
|
||||
CONF_STATE_TOPIC,
|
||||
self._handle_state_message_received,
|
||||
{"_attr_native_value"},
|
||||
)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
self.hass, self._sub_state, topics
|
||||
@@ -193,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_set_value(self, value: str) -> None:
|
||||
"""Change the text."""
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
@@ -32,12 +33,7 @@ from .const import (
|
||||
CONF_STATE_TOPIC,
|
||||
PAYLOAD_EMPTY_JSON,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
|
||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
from .util import valid_publish_topic, valid_subscribe_topic
|
||||
@@ -141,25 +137,104 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
||||
).async_render_with_possible_json_value,
|
||||
}
|
||||
|
||||
@callback
|
||||
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving state message via MQTT."""
|
||||
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
|
||||
|
||||
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
|
||||
json_payload: _MqttUpdatePayloadType = {}
|
||||
try:
|
||||
rendered_json_payload = json_loads(payload)
|
||||
if isinstance(rendered_json_payload, dict):
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"JSON payload detected after processing payload '%s' on"
|
||||
" topic %s"
|
||||
),
|
||||
rendered_json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"Non-dictionary JSON payload detected after processing"
|
||||
" payload '%s' on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload = {"installed_version": str(payload)}
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"No valid (JSON) payload detected after processing payload '%s'"
|
||||
" on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload["installed_version"] = str(payload)
|
||||
|
||||
if "installed_version" in json_payload:
|
||||
self._attr_installed_version = json_payload["installed_version"]
|
||||
|
||||
if "latest_version" in json_payload:
|
||||
self._attr_latest_version = json_payload["latest_version"]
|
||||
|
||||
if "title" in json_payload:
|
||||
self._attr_title = json_payload["title"]
|
||||
|
||||
if "release_summary" in json_payload:
|
||||
self._attr_release_summary = json_payload["release_summary"]
|
||||
|
||||
if "release_url" in json_payload:
|
||||
self._attr_release_url = json_payload["release_url"]
|
||||
|
||||
if "entity_picture" in json_payload:
|
||||
self._entity_picture = json_payload["entity_picture"]
|
||||
|
||||
@callback
|
||||
def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving latest version via MQTT."""
|
||||
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
|
||||
|
||||
if isinstance(latest_version, str) and latest_version != "":
|
||||
self._attr_latest_version = latest_version
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, Any] = {}
|
||||
|
||||
def add_subscription(
|
||||
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
|
||||
topics: dict[str, Any],
|
||||
topic: str,
|
||||
msg_callback: MessageCallbackType,
|
||||
tracked_attributes: set[str],
|
||||
) -> None:
|
||||
if self._config.get(topic) is not None:
|
||||
topics[topic] = {
|
||||
"topic": self._config[topic],
|
||||
"msg_callback": msg_callback,
|
||||
"msg_callback": partial(
|
||||
self._message_callback, msg_callback, tracked_attributes
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self,
|
||||
add_subscription(
|
||||
topics,
|
||||
CONF_STATE_TOPIC,
|
||||
self._handle_state_message_received,
|
||||
{
|
||||
"_attr_installed_version",
|
||||
"_attr_latest_version",
|
||||
@@ -169,84 +244,11 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
||||
"_entity_picture",
|
||||
},
|
||||
)
|
||||
def handle_state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving state message via MQTT."""
|
||||
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
|
||||
|
||||
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
||||
_LOGGER.debug(
|
||||
"Ignoring empty payload '%s' after rendering for topic %s",
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
return
|
||||
|
||||
json_payload: _MqttUpdatePayloadType = {}
|
||||
try:
|
||||
rendered_json_payload = json_loads(payload)
|
||||
if isinstance(rendered_json_payload, dict):
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"JSON payload detected after processing payload '%s' on"
|
||||
" topic %s"
|
||||
),
|
||||
rendered_json_payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"Non-dictionary JSON payload detected after processing"
|
||||
" payload '%s' on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload = {"installed_version": str(payload)}
|
||||
except JSON_DECODE_EXCEPTIONS:
|
||||
_LOGGER.debug(
|
||||
(
|
||||
"No valid (JSON) payload detected after processing payload '%s'"
|
||||
" on topic %s"
|
||||
),
|
||||
payload,
|
||||
msg.topic,
|
||||
)
|
||||
json_payload["installed_version"] = str(payload)
|
||||
|
||||
if "installed_version" in json_payload:
|
||||
self._attr_installed_version = json_payload["installed_version"]
|
||||
|
||||
if "latest_version" in json_payload:
|
||||
self._attr_latest_version = json_payload["latest_version"]
|
||||
|
||||
if "title" in json_payload:
|
||||
self._attr_title = json_payload["title"]
|
||||
|
||||
if "release_summary" in json_payload:
|
||||
self._attr_release_summary = json_payload["release_summary"]
|
||||
|
||||
if "release_url" in json_payload:
|
||||
self._attr_release_url = json_payload["release_url"]
|
||||
|
||||
if "entity_picture" in json_payload:
|
||||
self._entity_picture = json_payload["entity_picture"]
|
||||
|
||||
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(self, {"_attr_latest_version"})
|
||||
def handle_latest_version_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle receiving latest version via MQTT."""
|
||||
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
|
||||
|
||||
if isinstance(latest_version, str) and latest_version != "":
|
||||
self._attr_latest_version = latest_version
|
||||
|
||||
add_subscription(
|
||||
topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received
|
||||
topics,
|
||||
CONF_LATEST_VERSION_TOPIC,
|
||||
self._handle_latest_version_received,
|
||||
{"_attr_latest_version"},
|
||||
)
|
||||
|
||||
self._sub_state = subscription.async_prepare_subscribe_topics(
|
||||
@@ -255,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_install(
|
||||
self, version: str | None, backup: bool, **kwargs: Any
|
||||
|
@@ -8,6 +8,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -49,12 +50,7 @@ from .const import (
|
||||
CONF_STATE_TOPIC,
|
||||
DOMAIN,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import ReceiveMessage
|
||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
from .util import valid_publish_topic
|
||||
@@ -322,31 +318,32 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
|
||||
self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0)
|
||||
self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0)))
|
||||
|
||||
@callback
|
||||
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle state MQTT message."""
|
||||
payload = json_loads_object(msg.payload)
|
||||
if STATE in payload and (
|
||||
(state := payload[STATE]) in POSSIBLE_STATES or state is None
|
||||
):
|
||||
self._attr_state = (
|
||||
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
|
||||
)
|
||||
del payload[STATE]
|
||||
self._update_state_attributes(payload)
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics: dict[str, Any] = {}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self, {"_attr_battery_level", "_attr_fan_speed", "_attr_state"}
|
||||
)
|
||||
def state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle state MQTT message."""
|
||||
payload = json_loads_object(msg.payload)
|
||||
if STATE in payload and (
|
||||
(state := payload[STATE]) in POSSIBLE_STATES or state is None
|
||||
):
|
||||
self._attr_state = (
|
||||
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
|
||||
)
|
||||
del payload[STATE]
|
||||
self._update_state_attributes(payload)
|
||||
|
||||
if state_topic := self._config.get(CONF_STATE_TOPIC):
|
||||
topics["state_position_topic"] = {
|
||||
"topic": state_topic,
|
||||
"msg_callback": state_message_received,
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._state_message_received,
|
||||
{"_attr_battery_level", "_attr_fan_speed", "_attr_state"},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
@@ -356,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
|
||||
"""Publish a command."""
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -61,12 +62,7 @@ from .const import (
|
||||
DEFAULT_RETAIN,
|
||||
PAYLOAD_NONE,
|
||||
)
|
||||
from .debug_info import log_messages
|
||||
from .mixins import (
|
||||
MqttEntity,
|
||||
async_setup_entity_entry_helper,
|
||||
write_state_on_attr_change,
|
||||
)
|
||||
from .mixins import MqttEntity, async_setup_entity_entry_helper
|
||||
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
|
||||
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
|
||||
from .util import valid_publish_topic, valid_subscribe_topic
|
||||
@@ -302,65 +298,63 @@ class MqttValve(MqttEntity, ValveEntity):
|
||||
return
|
||||
self._update_state(state)
|
||||
|
||||
@callback
|
||||
def _state_message_received(self, msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
payload_dict: Any = None
|
||||
position_payload: Any = payload
|
||||
state_payload: Any = payload
|
||||
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
with suppress(*JSON_DECODE_EXCEPTIONS):
|
||||
payload_dict = json_loads(payload)
|
||||
if isinstance(payload_dict, dict):
|
||||
if self.reports_position and "position" not in payload_dict:
|
||||
_LOGGER.warning(
|
||||
"Missing required `position` attribute in json payload "
|
||||
"on topic '%s', got: %s",
|
||||
msg.topic,
|
||||
payload,
|
||||
)
|
||||
return
|
||||
if not self.reports_position and "state" not in payload_dict:
|
||||
_LOGGER.warning(
|
||||
"Missing required `state` attribute in json payload "
|
||||
" on topic '%s', got: %s",
|
||||
msg.topic,
|
||||
payload,
|
||||
)
|
||||
return
|
||||
position_payload = payload_dict.get("position")
|
||||
state_payload = payload_dict.get("state")
|
||||
|
||||
if self._config[CONF_REPORTS_POSITION]:
|
||||
self._process_position_valve_update(msg, position_payload, state_payload)
|
||||
else:
|
||||
self._process_binary_valve_update(msg, state_payload)
|
||||
|
||||
def _prepare_subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
topics = {}
|
||||
|
||||
@callback
|
||||
@log_messages(self.hass, self.entity_id)
|
||||
@write_state_on_attr_change(
|
||||
self,
|
||||
{
|
||||
"_attr_current_valve_position",
|
||||
"_attr_is_closed",
|
||||
"_attr_is_closing",
|
||||
"_attr_is_opening",
|
||||
},
|
||||
)
|
||||
def state_message_received(msg: ReceiveMessage) -> None:
|
||||
"""Handle new MQTT state messages."""
|
||||
payload = self._value_template(msg.payload)
|
||||
payload_dict: Any = None
|
||||
position_payload: Any = payload
|
||||
state_payload: Any = payload
|
||||
|
||||
if not payload:
|
||||
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
|
||||
return
|
||||
|
||||
with suppress(*JSON_DECODE_EXCEPTIONS):
|
||||
payload_dict = json_loads(payload)
|
||||
if isinstance(payload_dict, dict):
|
||||
if self.reports_position and "position" not in payload_dict:
|
||||
_LOGGER.warning(
|
||||
"Missing required `position` attribute in json payload "
|
||||
"on topic '%s', got: %s",
|
||||
msg.topic,
|
||||
payload,
|
||||
)
|
||||
return
|
||||
if not self.reports_position and "state" not in payload_dict:
|
||||
_LOGGER.warning(
|
||||
"Missing required `state` attribute in json payload "
|
||||
" on topic '%s', got: %s",
|
||||
msg.topic,
|
||||
payload,
|
||||
)
|
||||
return
|
||||
position_payload = payload_dict.get("position")
|
||||
state_payload = payload_dict.get("state")
|
||||
|
||||
if self._config[CONF_REPORTS_POSITION]:
|
||||
self._process_position_valve_update(
|
||||
msg, position_payload, state_payload
|
||||
)
|
||||
else:
|
||||
self._process_binary_valve_update(msg, state_payload)
|
||||
|
||||
if self._config.get(CONF_STATE_TOPIC):
|
||||
topics["state_topic"] = {
|
||||
"topic": self._config.get(CONF_STATE_TOPIC),
|
||||
"msg_callback": state_message_received,
|
||||
"msg_callback": partial(
|
||||
self._message_callback,
|
||||
self._state_message_received,
|
||||
{
|
||||
"_attr_current_valve_position",
|
||||
"_attr_is_closed",
|
||||
"_attr_is_closing",
|
||||
"_attr_is_opening",
|
||||
},
|
||||
),
|
||||
"entity_id": self.entity_id,
|
||||
"qos": self._config[CONF_QOS],
|
||||
"encoding": self._config[CONF_ENCODING] or None,
|
||||
}
|
||||
@@ -371,7 +365,7 @@ class MqttValve(MqttEntity, ValveEntity):
|
||||
|
||||
async def _subscribe_topics(self) -> None:
|
||||
"""(Re)Subscribe to topics."""
|
||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
|
||||
|
||||
async def async_open_valve(self) -> None:
|
||||
"""Move the valve up.
|
||||
|
@@ -9,6 +9,7 @@ from typing import Literal
|
||||
import ollama
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import MATCH_ALL
|
||||
@@ -138,6 +139,11 @@ class OllamaConversationEntity(
|
||||
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
|
||||
)
|
||||
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
{"messages": message_history.messages},
|
||||
)
|
||||
|
||||
# Get response
|
||||
try:
|
||||
response = await client.chat(
|
||||
|
@@ -31,14 +31,15 @@ from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_P,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_P,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
}
|
||||
)
|
||||
|
||||
RECOMMENDED_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
|
||||
CONF_PROMPT: DEFAULT_PROMPT,
|
||||
}
|
||||
|
||||
|
||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||
"""Validate the user input allows us to connect.
|
||||
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
return self.async_create_entry(
|
||||
title="ChatGPT",
|
||||
data=user_input,
|
||||
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
|
||||
options=RECOMMENDED_OPTIONS,
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
|
||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||
"""Initialize options flow."""
|
||||
self.config_entry = config_entry
|
||||
self.last_rendered_recommended = config_entry.options.get(
|
||||
CONF_RECOMMENDED, False
|
||||
)
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Manage the options."""
|
||||
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
|
||||
|
||||
if user_input is not None:
|
||||
if user_input[CONF_LLM_HASS_API] == "none":
|
||||
user_input.pop(CONF_LLM_HASS_API)
|
||||
return self.async_create_entry(title="", data=user_input)
|
||||
schema = openai_config_option_schema(self.hass, self.config_entry.options)
|
||||
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||
if user_input[CONF_LLM_HASS_API] == "none":
|
||||
user_input.pop(CONF_LLM_HASS_API)
|
||||
return self.async_create_entry(title="", data=user_input)
|
||||
|
||||
# Re-render the options again, now with the recommended options shown/hidden
|
||||
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
||||
|
||||
options = {
|
||||
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
|
||||
CONF_PROMPT: user_input[CONF_PROMPT],
|
||||
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
|
||||
}
|
||||
|
||||
schema = openai_config_option_schema(self.hass, options)
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=vol.Schema(schema),
|
||||
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):
|
||||
|
||||
def openai_config_option_schema(
|
||||
hass: HomeAssistant,
|
||||
options: MappingProxyType[str, Any],
|
||||
options: dict[str, Any] | MappingProxyType[str, Any],
|
||||
) -> dict:
|
||||
"""Return a schema for OpenAI completion options."""
|
||||
apis: list[SelectOptionDict] = [
|
||||
hass_apis: list[SelectOptionDict] = [
|
||||
SelectOptionDict(
|
||||
label="No control",
|
||||
value="none",
|
||||
)
|
||||
]
|
||||
apis.extend(
|
||||
hass_apis.extend(
|
||||
SelectOptionDict(
|
||||
label=api.name,
|
||||
value=api.id,
|
||||
@@ -144,38 +167,46 @@ def openai_config_option_schema(
|
||||
for api in llm.async_get_apis(hass)
|
||||
)
|
||||
|
||||
return {
|
||||
schema = {
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT)},
|
||||
default=DEFAULT_PROMPT,
|
||||
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
|
||||
default="none",
|
||||
): SelectSelector(SelectSelectorConfig(options=apis)),
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={
|
||||
# New key in HA 2023.4
|
||||
"suggested_value": options.get(CONF_CHAT_MODEL)
|
||||
},
|
||||
default=DEFAULT_CHAT_MODEL,
|
||||
): str,
|
||||
vol.Optional(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||
default=DEFAULT_MAX_TOKENS,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_TOP_P,
|
||||
description={"suggested_value": options.get(CONF_TOP_P)},
|
||||
default=DEFAULT_TOP_P,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
||||
default=DEFAULT_TEMPERATURE,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
}
|
||||
|
||||
if options.get(CONF_RECOMMENDED):
|
||||
return schema
|
||||
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_CHAT_MODEL,
|
||||
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
|
||||
default=RECOMMENDED_CHAT_MODEL,
|
||||
): str,
|
||||
vol.Optional(
|
||||
CONF_MAX_TOKENS,
|
||||
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
|
||||
default=RECOMMENDED_MAX_TOKENS,
|
||||
): int,
|
||||
vol.Optional(
|
||||
CONF_TOP_P,
|
||||
description={"suggested_value": options.get(CONF_TOP_P)},
|
||||
default=RECOMMENDED_TOP_P,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
|
||||
vol.Optional(
|
||||
CONF_TEMPERATURE,
|
||||
description={"suggested_value": options.get(CONF_TEMPERATURE)},
|
||||
default=RECOMMENDED_TEMPERATURE,
|
||||
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
|
||||
}
|
||||
)
|
||||
return schema
|
||||
|
@@ -4,13 +4,15 @@ import logging
|
||||
|
||||
DOMAIN = "openai_conversation"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
CONF_PROMPT = "prompt"
|
||||
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
|
||||
CONF_CHAT_MODEL = "chat_model"
|
||||
DEFAULT_CHAT_MODEL = "gpt-4o"
|
||||
RECOMMENDED_CHAT_MODEL = "gpt-4o"
|
||||
CONF_MAX_TOKENS = "max_tokens"
|
||||
DEFAULT_MAX_TOKENS = 150
|
||||
RECOMMENDED_MAX_TOKENS = 150
|
||||
CONF_TOP_P = "top_p"
|
||||
DEFAULT_TOP_P = 1.0
|
||||
RECOMMENDED_TOP_P = 1.0
|
||||
CONF_TEMPERATURE = "temperature"
|
||||
DEFAULT_TEMPERATURE = 1.0
|
||||
RECOMMENDED_TEMPERATURE = 1.0
|
||||
|
@@ -8,6 +8,7 @@ import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import assist_pipeline, conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -22,13 +23,13 @@ from .const import (
|
||||
CONF_PROMPT,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_MAX_TOKENS,
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_TEMPERATURE,
|
||||
DEFAULT_TOP_P,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
RECOMMENDED_TOP_P,
|
||||
)
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
@@ -97,15 +98,14 @@ class OpenAIConversationEntity(
|
||||
self, user_input: conversation.ConversationInput
|
||||
) -> conversation.ConversationResult:
|
||||
"""Process a sentence."""
|
||||
options = self.entry.options
|
||||
intent_response = intent.IntentResponse(language=user_input.language)
|
||||
llm_api: llm.API | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
|
||||
if self.entry.options.get(CONF_LLM_HASS_API):
|
||||
if options.get(CONF_LLM_HASS_API):
|
||||
try:
|
||||
llm_api = llm.async_get_api(
|
||||
self.hass, self.entry.options[CONF_LLM_HASS_API]
|
||||
)
|
||||
llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
|
||||
except HomeAssistantError as err:
|
||||
LOGGER.error("Error getting LLM API: %s", err)
|
||||
intent_response.async_set_error(
|
||||
@@ -117,26 +117,12 @@ class OpenAIConversationEntity(
|
||||
)
|
||||
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
|
||||
|
||||
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
|
||||
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
|
||||
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
|
||||
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
|
||||
|
||||
if user_input.conversation_id in self.history:
|
||||
conversation_id = user_input.conversation_id
|
||||
messages = self.history[conversation_id]
|
||||
else:
|
||||
conversation_id = ulid.ulid_now()
|
||||
try:
|
||||
prompt = template.Template(
|
||||
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
)
|
||||
|
||||
if llm_api:
|
||||
empty_tool_input = llm.ToolInput(
|
||||
tool_name="",
|
||||
@@ -149,11 +135,24 @@ class OpenAIConversationEntity(
|
||||
device_id=user_input.device_id,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
await llm_api.async_get_api_prompt(empty_tool_input)
|
||||
+ "\n"
|
||||
+ prompt
|
||||
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
|
||||
|
||||
else:
|
||||
api_prompt = llm.PROMPT_NO_API_CONFIGURED
|
||||
|
||||
prompt = "\n".join(
|
||||
(
|
||||
template.Template(
|
||||
options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
|
||||
).async_render(
|
||||
{
|
||||
"ha_name": self.hass.config.location_name,
|
||||
},
|
||||
parse_result=False,
|
||||
),
|
||||
api_prompt,
|
||||
)
|
||||
)
|
||||
|
||||
except TemplateError as err:
|
||||
LOGGER.error("Error rendering prompt: %s", err)
|
||||
@@ -170,7 +169,10 @@ class OpenAIConversationEntity(
|
||||
|
||||
messages.append({"role": "user", "content": user_input.text})
|
||||
|
||||
LOGGER.debug("Prompt for %s: %s", model, messages)
|
||||
LOGGER.debug("Prompt: %s", messages)
|
||||
trace.async_conversation_trace_append(
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
|
||||
)
|
||||
|
||||
client = self.hass.data[DOMAIN][self.entry.entry_id]
|
||||
|
||||
@@ -178,12 +180,12 @@ class OpenAIConversationEntity(
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
try:
|
||||
result = await client.chat.completions.create(
|
||||
model=model,
|
||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
user=conversation_id,
|
||||
)
|
||||
except openai.OpenAIError as err:
|
||||
|
@@ -22,7 +22,8 @@
|
||||
"max_tokens": "Maximum tokens to return in response",
|
||||
"temperature": "Temperature",
|
||||
"top_p": "Top P",
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||
"recommended": "Recommended model settings"
|
||||
},
|
||||
"data_description": {
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||
|
@@ -30,7 +30,6 @@ from .util import (
|
||||
|
||||
PLATFORMS = [Platform.MEDIA_PLAYER]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"async_browse_media",
|
||||
"DOMAIN",
|
||||
@@ -50,7 +49,10 @@ class HomeAssistantSpotifyData:
|
||||
session: OAuth2Session
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
type SpotifyConfigEntry = ConfigEntry[HomeAssistantSpotifyData]
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: SpotifyConfigEntry) -> bool:
|
||||
"""Set up Spotify from a config entry."""
|
||||
implementation = await async_get_config_entry_implementation(hass, entry)
|
||||
session = OAuth2Session(hass, entry, implementation)
|
||||
@@ -100,8 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
)
|
||||
await device_coordinator.async_config_entry_first_refresh()
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
hass.data[DOMAIN][entry.entry_id] = HomeAssistantSpotifyData(
|
||||
entry.runtime_data = HomeAssistantSpotifyData(
|
||||
client=spotify,
|
||||
current_user=current_user,
|
||||
devices=device_coordinator,
|
||||
@@ -117,6 +118,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Spotify config entry."""
|
||||
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
|
||||
del hass.data[DOMAIN][entry.entry_id]
|
||||
return unload_ok
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from enum import StrEnum
|
||||
from functools import partial
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from spotipy import Spotify
|
||||
import yarl
|
||||
@@ -22,6 +22,9 @@ from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
|
||||
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES
|
||||
from .util import fetch_image_url
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import HomeAssistantSpotifyData
|
||||
|
||||
BROWSE_LIMIT = 48
|
||||
|
||||
|
||||
@@ -140,21 +143,21 @@ async def async_browse_media(
|
||||
|
||||
# Check if caller is requesting the root nodes
|
||||
if media_content_type is None and media_content_id is None:
|
||||
children = []
|
||||
for config_entry_id in hass.data[DOMAIN]:
|
||||
config_entry = hass.config_entries.async_get_entry(config_entry_id)
|
||||
assert config_entry is not None
|
||||
children.append(
|
||||
BrowseMedia(
|
||||
title=config_entry.title,
|
||||
media_class=MediaClass.APP,
|
||||
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry_id}",
|
||||
media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
|
||||
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
)
|
||||
config_entries = hass.config_entries.async_entries(
|
||||
DOMAIN, include_disabled=False, include_ignore=False
|
||||
)
|
||||
children = [
|
||||
BrowseMedia(
|
||||
title=config_entry.title,
|
||||
media_class=MediaClass.APP,
|
||||
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry.entry_id}",
|
||||
media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
|
||||
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
)
|
||||
for config_entry in config_entries
|
||||
]
|
||||
return BrowseMedia(
|
||||
title="Spotify",
|
||||
media_class=MediaClass.APP,
|
||||
@@ -171,9 +174,15 @@ async def async_browse_media(
|
||||
|
||||
# Check for config entry specifier, and extract Spotify URI
|
||||
parsed_url = yarl.URL(media_content_id)
|
||||
if (info := hass.data[DOMAIN].get(parsed_url.host)) is None:
|
||||
|
||||
if (
|
||||
parsed_url.host is None
|
||||
or (entry := hass.config_entries.async_get_entry(parsed_url.host)) is None
|
||||
or not isinstance(entry.runtime_data, HomeAssistantSpotifyData)
|
||||
):
|
||||
raise BrowseError("Invalid Spotify account specified")
|
||||
media_content_id = parsed_url.name
|
||||
info = entry.runtime_data
|
||||
|
||||
result = await async_browse_media_internal(
|
||||
hass,
|
||||
|
@@ -22,7 +22,6 @@ from homeassistant.components.media_player import (
|
||||
MediaType,
|
||||
RepeatMode,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_ID
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -30,7 +29,7 @@ from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from . import HomeAssistantSpotifyData
|
||||
from . import HomeAssistantSpotifyData, SpotifyConfigEntry
|
||||
from .browse_media import async_browse_media_internal
|
||||
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES
|
||||
from .util import fetch_image_url
|
||||
@@ -70,12 +69,12 @@ SPOTIFY_DJ_PLAYLIST = {"uri": "spotify:playlist:37i9dQZF1EYkqdzj48dyYq", "name":
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
entry: ConfigEntry,
|
||||
entry: SpotifyConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Spotify based on a config entry."""
|
||||
spotify = SpotifyMediaPlayer(
|
||||
hass.data[DOMAIN][entry.entry_id],
|
||||
entry.runtime_data,
|
||||
entry.data[CONF_ID],
|
||||
entry.title,
|
||||
)
|
||||
|
@@ -4,15 +4,14 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aioswitcher.bridge import SwitcherBridge
|
||||
from aioswitcher.device import SwitcherBase
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
|
||||
from .const import DATA_DEVICE, DOMAIN
|
||||
from .coordinator import SwitcherDataUpdateCoordinator
|
||||
from .utils import async_start_bridge, async_stop_bridge
|
||||
|
||||
PLATFORMS = [
|
||||
Platform.BUTTON,
|
||||
@@ -25,20 +24,20 @@ PLATFORMS = [
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
type SwitcherConfigEntry = ConfigEntry[dict[str, SwitcherDataUpdateCoordinator]]
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
|
||||
"""Set up Switcher from a config entry."""
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
hass.data[DOMAIN][DATA_DEVICE] = {}
|
||||
|
||||
@callback
|
||||
def on_device_data_callback(device: SwitcherBase) -> None:
|
||||
"""Use as a callback for device data."""
|
||||
|
||||
coordinators = entry.runtime_data
|
||||
|
||||
# Existing device update device data
|
||||
if device.device_id in hass.data[DOMAIN][DATA_DEVICE]:
|
||||
coordinator: SwitcherDataUpdateCoordinator = hass.data[DOMAIN][DATA_DEVICE][
|
||||
device.device_id
|
||||
]
|
||||
if coordinator := coordinators.get(device.device_id):
|
||||
coordinator.async_set_updated_data(device)
|
||||
return
|
||||
|
||||
@@ -52,18 +51,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
device.device_type.hex_rep,
|
||||
)
|
||||
|
||||
coordinator = hass.data[DOMAIN][DATA_DEVICE][device.device_id] = (
|
||||
SwitcherDataUpdateCoordinator(hass, entry, device)
|
||||
)
|
||||
coordinator = SwitcherDataUpdateCoordinator(hass, entry, device)
|
||||
coordinator.async_setup()
|
||||
coordinators[device.device_id] = coordinator
|
||||
|
||||
# Must be ready before dispatcher is called
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
|
||||
await async_start_bridge(hass, on_device_data_callback)
|
||||
entry.runtime_data = {}
|
||||
bridge = SwitcherBridge(on_device_data_callback)
|
||||
await bridge.start()
|
||||
|
||||
async def stop_bridge(event: Event) -> None:
|
||||
await async_stop_bridge(hass)
|
||||
async def stop_bridge(event: Event | None = None) -> None:
|
||||
await bridge.stop()
|
||||
|
||||
entry.async_on_unload(stop_bridge)
|
||||
|
||||
entry.async_on_unload(
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge)
|
||||
@@ -72,12 +74,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
await async_stop_bridge(hass)
|
||||
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
if unload_ok:
|
||||
hass.data[DOMAIN].pop(DATA_DEVICE)
|
||||
|
||||
return unload_ok
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
@@ -15,7 +15,6 @@ from aioswitcher.api.remotes import SwitcherBreezeRemote
|
||||
from aioswitcher.device import DeviceCategory
|
||||
|
||||
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import EntityCategory
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -25,6 +24,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
|
||||
from . import SwitcherConfigEntry
|
||||
from .const import SIGNAL_DEVICE_ADD
|
||||
from .coordinator import SwitcherDataUpdateCoordinator
|
||||
from .utils import get_breeze_remote_manager
|
||||
@@ -78,7 +78,7 @@ THERMOSTAT_BUTTONS = [
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
config_entry: SwitcherConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Switcher button from config entry."""
|
||||
|
@@ -25,7 +25,6 @@ from homeassistant.components.climate import (
|
||||
ClimateEntityFeature,
|
||||
HVACMode,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -35,6 +34,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
|
||||
from . import SwitcherConfigEntry
|
||||
from .const import SIGNAL_DEVICE_ADD
|
||||
from .coordinator import SwitcherDataUpdateCoordinator
|
||||
from .utils import get_breeze_remote_manager
|
||||
@@ -61,7 +61,7 @@ HA_TO_DEVICE_FAN = {value: key for key, value in DEVICE_FAN_TO_HA.items()}
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
config_entry: SwitcherConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Switcher climate from config entry."""
|
||||
|
@@ -2,9 +2,6 @@
|
||||
|
||||
DOMAIN = "switcher_kis"
|
||||
|
||||
DATA_BRIDGE = "bridge"
|
||||
DATA_DEVICE = "device"
|
||||
|
||||
DISCOVERY_TIME_SEC = 12
|
||||
|
||||
SIGNAL_DEVICE_ADD = "switcher_device_add"
|
||||
|
@@ -6,24 +6,23 @@ from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.diagnostics import async_redact_data
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DATA_DEVICE, DOMAIN
|
||||
from . import SwitcherConfigEntry
|
||||
|
||||
TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"}
|
||||
|
||||
|
||||
async def async_get_config_entry_diagnostics(
|
||||
hass: HomeAssistant, entry: ConfigEntry
|
||||
hass: HomeAssistant, entry: SwitcherConfigEntry
|
||||
) -> dict[str, Any]:
|
||||
"""Return diagnostics for a config entry."""
|
||||
devices = hass.data[DOMAIN][DATA_DEVICE]
|
||||
coordinators = entry.runtime_data
|
||||
|
||||
return async_redact_data(
|
||||
{
|
||||
"entry": entry.as_dict(),
|
||||
"devices": [asdict(devices[d].data) for d in devices],
|
||||
"devices": [asdict(coordinators[d].data) for d in coordinators],
|
||||
},
|
||||
TO_REDACT,
|
||||
)
|
||||
|
@@ -3,9 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aioswitcher.api.remotes import SwitcherBreezeRemoteManager
|
||||
from aioswitcher.bridge import SwitcherBase, SwitcherBridge
|
||||
@@ -13,29 +11,11 @@ from aioswitcher.bridge import SwitcherBase, SwitcherBridge
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import singleton
|
||||
|
||||
from .const import DATA_BRIDGE, DISCOVERY_TIME_SEC, DOMAIN
|
||||
from .const import DISCOVERY_TIME_SEC
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_start_bridge(
|
||||
hass: HomeAssistant, on_device_callback: Callable[[SwitcherBase], Any]
|
||||
) -> None:
|
||||
"""Start switcher UDP bridge."""
|
||||
bridge = hass.data[DOMAIN][DATA_BRIDGE] = SwitcherBridge(on_device_callback)
|
||||
_LOGGER.debug("Starting Switcher bridge")
|
||||
await bridge.start()
|
||||
|
||||
|
||||
async def async_stop_bridge(hass: HomeAssistant) -> None:
|
||||
"""Stop switcher UDP bridge."""
|
||||
bridge: SwitcherBridge = hass.data[DOMAIN].get(DATA_BRIDGE)
|
||||
if bridge is not None:
|
||||
_LOGGER.debug("Stopping Switcher bridge")
|
||||
await bridge.stop()
|
||||
hass.data[DOMAIN].pop(DATA_BRIDGE)
|
||||
|
||||
|
||||
async def async_has_devices(hass: HomeAssistant) -> bool:
|
||||
"""Discover Switcher devices."""
|
||||
_LOGGER.debug("Starting discovery")
|
||||
|
@@ -30,6 +30,7 @@ PLATFORMS: Final = [
|
||||
Platform.BINARY_SENSOR,
|
||||
Platform.CLIMATE,
|
||||
Platform.COVER,
|
||||
Platform.DEVICE_TRACKER,
|
||||
Platform.LOCK,
|
||||
Platform.SELECT,
|
||||
Platform.SENSOR,
|
||||
|
85
homeassistant/components/teslemetry/device_tracker.py
Normal file
85
homeassistant/components/teslemetry/device_tracker.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Device tracker platform for Teslemetry integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.device_tracker import SourceType
|
||||
from homeassistant.components.device_tracker.config_entry import TrackerEntity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .entity import TeslemetryVehicleEntity
|
||||
from .models import TeslemetryVehicleData
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
|
||||
) -> None:
|
||||
"""Set up the Teslemetry device tracker platform from a config entry."""
|
||||
|
||||
async_add_entities(
|
||||
klass(vehicle)
|
||||
for klass in (
|
||||
TeslemetryDeviceTrackerLocationEntity,
|
||||
TeslemetryDeviceTrackerRouteEntity,
|
||||
)
|
||||
for vehicle in entry.runtime_data.vehicles
|
||||
)
|
||||
|
||||
|
||||
class TeslemetryDeviceTrackerEntity(TeslemetryVehicleEntity, TrackerEntity):
|
||||
"""Base class for Teslemetry tracker entities."""
|
||||
|
||||
lat_key: str
|
||||
lon_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vehicle: TeslemetryVehicleData,
|
||||
) -> None:
|
||||
"""Initialize the device tracker."""
|
||||
super().__init__(vehicle, self.key)
|
||||
|
||||
def _async_update_attrs(self) -> None:
|
||||
"""Update the attributes of the device tracker."""
|
||||
|
||||
self._attr_available = (
|
||||
self.get(self.lat_key, False) is not None
|
||||
and self.get(self.lon_key, False) is not None
|
||||
)
|
||||
|
||||
@property
|
||||
def latitude(self) -> float | None:
|
||||
"""Return latitude value of the device."""
|
||||
return self.get(self.lat_key)
|
||||
|
||||
@property
|
||||
def longitude(self) -> float | None:
|
||||
"""Return longitude value of the device."""
|
||||
return self.get(self.lon_key)
|
||||
|
||||
@property
|
||||
def source_type(self) -> SourceType:
|
||||
"""Return the source type of the device tracker."""
|
||||
return SourceType.GPS
|
||||
|
||||
|
||||
class TeslemetryDeviceTrackerLocationEntity(TeslemetryDeviceTrackerEntity):
|
||||
"""Vehicle location device tracker class."""
|
||||
|
||||
key = "location"
|
||||
lat_key = "drive_state_latitude"
|
||||
lon_key = "drive_state_longitude"
|
||||
|
||||
|
||||
class TeslemetryDeviceTrackerRouteEntity(TeslemetryDeviceTrackerEntity):
|
||||
"""Vehicle navigation device tracker class."""
|
||||
|
||||
key = "route"
|
||||
lat_key = "drive_state_active_route_latitude"
|
||||
lon_key = "drive_state_active_route_longitude"
|
||||
|
||||
@property
|
||||
def location_name(self) -> str | None:
|
||||
"""Return a location name for the current location of the device."""
|
||||
return self.get("drive_state_active_route_destination")
|
@@ -109,6 +109,7 @@
|
||||
"off": "mdi:car-seat"
|
||||
}
|
||||
},
|
||||
|
||||
"components_customer_preferred_export_rule": {
|
||||
"default": "mdi:transmission-tower",
|
||||
"state": {
|
||||
@@ -126,6 +127,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"device_tracker": {
|
||||
"location": {
|
||||
"default": "mdi:map-marker"
|
||||
},
|
||||
"route": {
|
||||
"default": "mdi:routes"
|
||||
}
|
||||
},
|
||||
"cover": {
|
||||
"charge_state_charge_port_door_open": {
|
||||
"default": "mdi:ev-plug-ccs2"
|
||||
|
@@ -111,6 +111,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"device_tracker": {
|
||||
"location": {
|
||||
"name": "Location"
|
||||
},
|
||||
"route": {
|
||||
"name": "Route"
|
||||
}
|
||||
},
|
||||
"lock": {
|
||||
"charge_state_charge_port_latch": {
|
||||
"name": "Charge cable lock"
|
||||
|
@@ -13,7 +13,7 @@
|
||||
"velbus-packet",
|
||||
"velbus-protocol"
|
||||
],
|
||||
"requirements": ["velbus-aio==2024.4.1"],
|
||||
"requirements": ["velbus-aio==2024.5.1"],
|
||||
"usb": [
|
||||
{
|
||||
"vid": "10CF",
|
||||
|
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -157,16 +156,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
|
||||
"""Unload Withings config entry."""
|
||||
"""Unload vera config entry."""
|
||||
controller_data: ControllerData = get_controller_data(hass, config_entry)
|
||||
|
||||
tasks: list[Awaitable] = [
|
||||
hass.config_entries.async_forward_entry_unload(config_entry, platform)
|
||||
for platform in get_configured_platforms(controller_data)
|
||||
]
|
||||
tasks.append(hass.async_add_executor_job(controller_data.controller.stop))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
await asyncio.gather(
|
||||
*(
|
||||
hass.config_entries.async_unload_platforms(
|
||||
config_entry, get_configured_platforms(controller_data)
|
||||
),
|
||||
hass.async_add_executor_job(controller_data.controller.stop),
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
|
@@ -1,6 +1,5 @@
|
||||
"""Support for Zigbee Home Automation devices."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import logging
|
||||
@@ -238,12 +237,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
|
||||
websocket_api.async_unload_api(hass)
|
||||
|
||||
# our components don't have unload methods so no need to look at return values
|
||||
await asyncio.gather(
|
||||
*(
|
||||
hass.config_entries.async_forward_entry_unload(config_entry, platform)
|
||||
for platform in PLATFORMS
|
||||
)
|
||||
)
|
||||
await hass.config_entries.async_unload_platforms(config_entry, PLATFORMS)
|
||||
|
||||
return True
|
||||
|
||||
|
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Coroutine
|
||||
from contextlib import suppress
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -958,14 +957,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
|
||||
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
|
||||
|
||||
tasks: list[Coroutine] = [
|
||||
hass.config_entries.async_forward_entry_unload(entry, platform)
|
||||
platforms = [
|
||||
platform
|
||||
for platform, task in driver_events.platform_setup_tasks.items()
|
||||
if not task.cancel()
|
||||
]
|
||||
|
||||
unload_ok = all(await asyncio.gather(*tasks)) if tasks else True
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
|
||||
|
||||
if client.connected and client.driver:
|
||||
await async_disable_server_logging_if_needed(hass, entry, client.driver)
|
||||
|
@@ -3,12 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
|
||||
from homeassistant.components.conversation.trace import (
|
||||
ConversationTraceEventType,
|
||||
async_conversation_trace_append,
|
||||
)
|
||||
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -116,6 +120,10 @@ class API(ABC):
|
||||
|
||||
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
|
||||
"""Call a LLM tool, validate args and return the response."""
|
||||
async_conversation_trace_append(
|
||||
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
|
||||
)
|
||||
|
||||
for tool in self.async_get_tools():
|
||||
if tool.name == tool_input.tool_name:
|
||||
break
|
||||
@@ -191,7 +199,10 @@ class AssistAPI(API):
|
||||
|
||||
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
|
||||
"""Return the prompt for the API."""
|
||||
prompt = "Call the intent tools to control Home Assistant. Just pass the name to the intent."
|
||||
prompt = (
|
||||
"Call the intent tools to control Home Assistant. "
|
||||
"Just pass the name to the intent."
|
||||
)
|
||||
if tool_input.device_id:
|
||||
device_reg = device_registry.async_get(self.hass)
|
||||
device = device_reg.async_get(tool_input.device_id)
|
||||
|
@@ -1821,7 +1821,7 @@ pyegps==0.2.5
|
||||
pyenphase==1.20.3
|
||||
|
||||
# homeassistant.components.envisalink
|
||||
pyenvisalink==4.6
|
||||
pyenvisalink==4.7
|
||||
|
||||
# homeassistant.components.ephember
|
||||
pyephember==0.3.1
|
||||
@@ -2817,7 +2817,7 @@ vallox-websocket-api==5.1.1
|
||||
vehicle==2.2.1
|
||||
|
||||
# homeassistant.components.velbus
|
||||
velbus-aio==2024.4.1
|
||||
velbus-aio==2024.5.1
|
||||
|
||||
# homeassistant.components.venstar
|
||||
venstarcolortouch==0.19
|
||||
|
@@ -2185,7 +2185,7 @@ vallox-websocket-api==5.1.1
|
||||
vehicle==2.2.1
|
||||
|
||||
# homeassistant.components.velbus
|
||||
velbus-aio==2024.4.1
|
||||
velbus-aio==2024.5.1
|
||||
|
||||
# homeassistant.components.venstar
|
||||
venstarcolortouch==0.19
|
||||
|
@@ -117,7 +117,6 @@ NO_IOT_CLASS = [
|
||||
# https://github.com/home-assistant/developers.home-assistant/pull/1512
|
||||
NO_DIAGNOSTICS = [
|
||||
"dlna_dms",
|
||||
"fronius",
|
||||
"gdacs",
|
||||
"geonetnz_quakes",
|
||||
"google_assistant_sdk",
|
||||
|
@@ -2,7 +2,9 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.core import Context, HomeAssistant, State
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
@@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None:
|
||||
) as mock_process,
|
||||
patch("homeassistant.util.dt.utcnow", return_value=now),
|
||||
):
|
||||
intent_response = intent.IntentResponse(language="en")
|
||||
intent_response.async_set_speech("response text")
|
||||
mock_process.return_value = conversation.ConversationResult(
|
||||
response=intent_response,
|
||||
)
|
||||
await hass.services.async_call(
|
||||
"conversation",
|
||||
"process",
|
||||
|
80
tests/components/conversation/test_trace.py
Normal file
80
tests/components/conversation/test_trace.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Test for the conversation traces."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(hass: HomeAssistant):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "conversation", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
|
||||
|
||||
async def test_converation_trace(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
sl_setup: None,
|
||||
) -> None:
|
||||
"""Test tracing a conversation."""
|
||||
await conversation.async_converse(
|
||||
hass, "add apples to my shopping list", None, Context()
|
||||
)
|
||||
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
assert last_trace.get("events")
|
||||
assert len(last_trace.get("events")) == 1
|
||||
trace_event = last_trace["events"][0]
|
||||
assert (
|
||||
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||
)
|
||||
assert trace_event.get("data")
|
||||
assert trace_event["data"].get("text") == "add apples to my shopping list"
|
||||
assert last_trace.get("result")
|
||||
assert (
|
||||
last_trace["result"]
|
||||
.get("response", {})
|
||||
.get("speech", {})
|
||||
.get("plain", {})
|
||||
.get("speech")
|
||||
== "Added apples"
|
||||
)
|
||||
|
||||
|
||||
async def test_converation_trace_error(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
sl_setup: None,
|
||||
) -> None:
|
||||
"""Test tracing a conversation."""
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
|
||||
side_effect=HomeAssistantError("Failed to talk to agent"),
|
||||
),
|
||||
pytest.raises(HomeAssistantError),
|
||||
):
|
||||
await conversation.async_converse(
|
||||
hass, "add apples to my shopping list", None, Context()
|
||||
)
|
||||
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
assert last_trace.get("events")
|
||||
assert len(last_trace.get("events")) == 1
|
||||
trace_event = last_trace["events"][0]
|
||||
assert (
|
||||
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
|
||||
)
|
||||
assert last_trace.get("error") == "Failed to talk to agent"
|
@@ -25,6 +25,7 @@ async def setup_fronius_integration(
|
||||
"""Create the Fronius integration."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
entry_id="f1e2b9837e8adaed6fa682acaa216fd8",
|
||||
unique_id=unique_id, # has to match mocked logger unique_id
|
||||
data={
|
||||
CONF_HOST: MOCK_HOST,
|
||||
|
370
tests/components/fronius/snapshots/test_diagnostics.ambr
Normal file
370
tests/components/fronius/snapshots/test_diagnostics.ambr
Normal file
@@ -0,0 +1,370 @@
|
||||
# serializer version: 1
|
||||
# name: test_diagnostics
|
||||
dict({
|
||||
'config_entry': dict({
|
||||
'data': dict({
|
||||
'host': 'http://fronius',
|
||||
'is_logger': True,
|
||||
}),
|
||||
'disabled_by': None,
|
||||
'domain': 'fronius',
|
||||
'entry_id': 'f1e2b9837e8adaed6fa682acaa216fd8',
|
||||
'minor_version': 1,
|
||||
'options': dict({
|
||||
}),
|
||||
'pref_disable_new_entities': False,
|
||||
'pref_disable_polling': False,
|
||||
'source': 'user',
|
||||
'title': 'Mock Title',
|
||||
'unique_id': '**REDACTED**',
|
||||
'version': 1,
|
||||
}),
|
||||
'coordinators': dict({
|
||||
'inverters': dict({
|
||||
'1': dict({
|
||||
'current_ac': dict({
|
||||
'unit': 'A',
|
||||
'value': 5.19,
|
||||
}),
|
||||
'current_dc': dict({
|
||||
'unit': 'A',
|
||||
'value': 2.19,
|
||||
}),
|
||||
'energy_day': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 1113,
|
||||
}),
|
||||
'energy_total': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 44188000,
|
||||
}),
|
||||
'energy_year': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 25508798,
|
||||
}),
|
||||
'error_code': dict({
|
||||
'value': 0,
|
||||
}),
|
||||
'frequency_ac': dict({
|
||||
'unit': 'Hz',
|
||||
'value': 49.94,
|
||||
}),
|
||||
'led_color': dict({
|
||||
'value': 2,
|
||||
}),
|
||||
'led_state': dict({
|
||||
'value': 0,
|
||||
}),
|
||||
'power_ac': dict({
|
||||
'unit': 'W',
|
||||
'value': 1190,
|
||||
}),
|
||||
'status': dict({
|
||||
'Code': 0,
|
||||
'Reason': '',
|
||||
'UserMessage': '',
|
||||
}),
|
||||
'status_code': dict({
|
||||
'value': 7,
|
||||
}),
|
||||
'timestamp': dict({
|
||||
'value': '2021-10-07T10:01:17+02:00',
|
||||
}),
|
||||
'voltage_ac': dict({
|
||||
'unit': 'V',
|
||||
'value': 227.9,
|
||||
}),
|
||||
'voltage_dc': dict({
|
||||
'unit': 'V',
|
||||
'value': 518,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'logger': dict({
|
||||
'system': dict({
|
||||
'cash_factor': dict({
|
||||
'unit': 'EUR/kWh',
|
||||
'value': 0.07800000160932541,
|
||||
}),
|
||||
'co2_factor': dict({
|
||||
'unit': 'kg/kWh',
|
||||
'value': 0.5299999713897705,
|
||||
}),
|
||||
'delivery_factor': dict({
|
||||
'unit': 'EUR/kWh',
|
||||
'value': 0.15000000596046448,
|
||||
}),
|
||||
'hardware_platform': dict({
|
||||
'value': 'wilma',
|
||||
}),
|
||||
'hardware_version': dict({
|
||||
'value': '2.4E',
|
||||
}),
|
||||
'product_type': dict({
|
||||
'value': 'fronius-datamanager-card',
|
||||
}),
|
||||
'software_version': dict({
|
||||
'value': '3.18.7-1',
|
||||
}),
|
||||
'status': dict({
|
||||
'Code': 0,
|
||||
'Reason': '',
|
||||
'UserMessage': '',
|
||||
}),
|
||||
'time_zone': dict({
|
||||
'value': 'CEST',
|
||||
}),
|
||||
'time_zone_location': dict({
|
||||
'value': 'Vienna',
|
||||
}),
|
||||
'timestamp': dict({
|
||||
'value': '2021-10-06T23:56:32+02:00',
|
||||
}),
|
||||
'unique_identifier': '**REDACTED**',
|
||||
'utc_offset': dict({
|
||||
'value': 7200,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'meter': dict({
|
||||
'0': dict({
|
||||
'current_ac_phase_1': dict({
|
||||
'unit': 'A',
|
||||
'value': 7.755,
|
||||
}),
|
||||
'current_ac_phase_2': dict({
|
||||
'unit': 'A',
|
||||
'value': 6.68,
|
||||
}),
|
||||
'current_ac_phase_3': dict({
|
||||
'unit': 'A',
|
||||
'value': 10.102,
|
||||
}),
|
||||
'enable': dict({
|
||||
'value': 1,
|
||||
}),
|
||||
'energy_reactive_ac_consumed': dict({
|
||||
'unit': 'VArh',
|
||||
'value': 59960790,
|
||||
}),
|
||||
'energy_reactive_ac_produced': dict({
|
||||
'unit': 'VArh',
|
||||
'value': 723160,
|
||||
}),
|
||||
'energy_real_ac_minus': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 35623065,
|
||||
}),
|
||||
'energy_real_ac_plus': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 15303334,
|
||||
}),
|
||||
'energy_real_consumed': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 15303334,
|
||||
}),
|
||||
'energy_real_produced': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 35623065,
|
||||
}),
|
||||
'frequency_phase_average': dict({
|
||||
'unit': 'Hz',
|
||||
'value': 50,
|
||||
}),
|
||||
'manufacturer': dict({
|
||||
'value': 'Fronius',
|
||||
}),
|
||||
'meter_location': dict({
|
||||
'value': 0,
|
||||
}),
|
||||
'model': dict({
|
||||
'value': 'Smart Meter 63A',
|
||||
}),
|
||||
'power_apparent': dict({
|
||||
'unit': 'VA',
|
||||
'value': 5592.57,
|
||||
}),
|
||||
'power_apparent_phase_1': dict({
|
||||
'unit': 'VA',
|
||||
'value': 1772.793,
|
||||
}),
|
||||
'power_apparent_phase_2': dict({
|
||||
'unit': 'VA',
|
||||
'value': 1527.048,
|
||||
}),
|
||||
'power_apparent_phase_3': dict({
|
||||
'unit': 'VA',
|
||||
'value': 2333.562,
|
||||
}),
|
||||
'power_factor': dict({
|
||||
'value': 1,
|
||||
}),
|
||||
'power_factor_phase_1': dict({
|
||||
'value': -0.99,
|
||||
}),
|
||||
'power_factor_phase_2': dict({
|
||||
'value': -0.99,
|
||||
}),
|
||||
'power_factor_phase_3': dict({
|
||||
'value': 0.99,
|
||||
}),
|
||||
'power_reactive': dict({
|
||||
'unit': 'VAr',
|
||||
'value': 2.87,
|
||||
}),
|
||||
'power_reactive_phase_1': dict({
|
||||
'unit': 'VAr',
|
||||
'value': 51.48,
|
||||
}),
|
||||
'power_reactive_phase_2': dict({
|
||||
'unit': 'VAr',
|
||||
'value': 115.63,
|
||||
}),
|
||||
'power_reactive_phase_3': dict({
|
||||
'unit': 'VAr',
|
||||
'value': -164.24,
|
||||
}),
|
||||
'power_real': dict({
|
||||
'unit': 'W',
|
||||
'value': 5592.57,
|
||||
}),
|
||||
'power_real_phase_1': dict({
|
||||
'unit': 'W',
|
||||
'value': 1765.55,
|
||||
}),
|
||||
'power_real_phase_2': dict({
|
||||
'unit': 'W',
|
||||
'value': 1515.8,
|
||||
}),
|
||||
'power_real_phase_3': dict({
|
||||
'unit': 'W',
|
||||
'value': 2311.22,
|
||||
}),
|
||||
'serial': '**REDACTED**',
|
||||
'visible': dict({
|
||||
'value': 1,
|
||||
}),
|
||||
'voltage_ac_phase_1': dict({
|
||||
'unit': 'V',
|
||||
'value': 228.6,
|
||||
}),
|
||||
'voltage_ac_phase_2': dict({
|
||||
'unit': 'V',
|
||||
'value': 228.6,
|
||||
}),
|
||||
'voltage_ac_phase_3': dict({
|
||||
'unit': 'V',
|
||||
'value': 231,
|
||||
}),
|
||||
'voltage_ac_phase_to_phase_12': dict({
|
||||
'unit': 'V',
|
||||
'value': 395.9,
|
||||
}),
|
||||
'voltage_ac_phase_to_phase_23': dict({
|
||||
'unit': 'V',
|
||||
'value': 398,
|
||||
}),
|
||||
'voltage_ac_phase_to_phase_31': dict({
|
||||
'unit': 'V',
|
||||
'value': 398,
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'ohmpilot': None,
|
||||
'power_flow': dict({
|
||||
'power_flow': dict({
|
||||
'energy_day': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 1101.7000732421875,
|
||||
}),
|
||||
'energy_total': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 44188000,
|
||||
}),
|
||||
'energy_year': dict({
|
||||
'unit': 'Wh',
|
||||
'value': 25508788,
|
||||
}),
|
||||
'meter_location': dict({
|
||||
'value': 'grid',
|
||||
}),
|
||||
'meter_mode': dict({
|
||||
'value': 'meter',
|
||||
}),
|
||||
'power_battery': dict({
|
||||
'unit': 'W',
|
||||
'value': None,
|
||||
}),
|
||||
'power_grid': dict({
|
||||
'unit': 'W',
|
||||
'value': 1703.74,
|
||||
}),
|
||||
'power_load': dict({
|
||||
'unit': 'W',
|
||||
'value': -2814.74,
|
||||
}),
|
||||
'power_photovoltaics': dict({
|
||||
'unit': 'W',
|
||||
'value': 1111,
|
||||
}),
|
||||
'relative_autonomy': dict({
|
||||
'unit': '%',
|
||||
'value': 39.4707859340472,
|
||||
}),
|
||||
'relative_self_consumption': dict({
|
||||
'unit': '%',
|
||||
'value': 100,
|
||||
}),
|
||||
'status': dict({
|
||||
'Code': 0,
|
||||
'Reason': '',
|
||||
'UserMessage': '',
|
||||
}),
|
||||
'timestamp': dict({
|
||||
'value': '2021-10-07T10:00:43+02:00',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'storage': None,
|
||||
}),
|
||||
'inverter_info': dict({
|
||||
'inverters': list([
|
||||
dict({
|
||||
'custom_name': dict({
|
||||
'value': 'Symo 20',
|
||||
}),
|
||||
'device_id': dict({
|
||||
'value': '1',
|
||||
}),
|
||||
'device_type': dict({
|
||||
'manufacturer': 'Fronius',
|
||||
'model': 'Symo 20.0-3-M',
|
||||
'value': 121,
|
||||
}),
|
||||
'error_code': dict({
|
||||
'value': 0,
|
||||
}),
|
||||
'pv_power': dict({
|
||||
'unit': 'W',
|
||||
'value': 23100,
|
||||
}),
|
||||
'show': dict({
|
||||
'value': 1,
|
||||
}),
|
||||
'status_code': dict({
|
||||
'value': 7,
|
||||
}),
|
||||
'unique_id': '**REDACTED**',
|
||||
}),
|
||||
]),
|
||||
'status': dict({
|
||||
'Code': 0,
|
||||
'Reason': '',
|
||||
'UserMessage': '',
|
||||
}),
|
||||
'timestamp': dict({
|
||||
'value': '2021-10-07T13:41:00+02:00',
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
31
tests/components/fronius/test_diagnostics.py
Normal file
31
tests/components/fronius/test_diagnostics.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Tests for the diagnostics data provided by the KNX integration."""
|
||||
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import mock_responses, setup_fronius_integration
|
||||
|
||||
from tests.components.diagnostics import get_diagnostics_for_config_entry
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
async def test_diagnostics(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test diagnostics."""
|
||||
mock_responses(aioclient_mock)
|
||||
entry = await setup_fronius_integration(hass)
|
||||
|
||||
assert (
|
||||
await get_diagnostics_for_config_entry(
|
||||
hass,
|
||||
hass_client,
|
||||
entry,
|
||||
)
|
||||
== snapshot
|
||||
)
|
@@ -1,4 +1,114 @@
|
||||
# serializer version: 1
|
||||
# name: test_chat_history
|
||||
list([
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'generation_config': dict({
|
||||
'max_output_tokens': 150,
|
||||
'temperature': 1.0,
|
||||
'top_k': 64,
|
||||
'top_p': 0.95,
|
||||
}),
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'safety_settings': dict({
|
||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
}),
|
||||
'tools': None,
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().start_chat',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'parts': 'Ok',
|
||||
'role': 'model',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().start_chat().send_message_async',
|
||||
tuple(
|
||||
'1st user request',
|
||||
),
|
||||
dict({
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'generation_config': dict({
|
||||
'max_output_tokens': 150,
|
||||
'temperature': 1.0,
|
||||
'top_k': 64,
|
||||
'top_p': 0.95,
|
||||
}),
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'safety_settings': dict({
|
||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
}),
|
||||
'tools': None,
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().start_chat',
|
||||
tuple(
|
||||
),
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'parts': 'Ok',
|
||||
'role': 'model',
|
||||
}),
|
||||
dict({
|
||||
'parts': '1st user request',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'parts': '1st model response',
|
||||
'role': 'model',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
),
|
||||
tuple(
|
||||
'().start_chat().send_message_async',
|
||||
tuple(
|
||||
'2nd user request',
|
||||
),
|
||||
dict({
|
||||
}),
|
||||
),
|
||||
])
|
||||
# ---
|
||||
# name: test_default_prompt[config_entry_options0-None]
|
||||
list([
|
||||
tuple(
|
||||
@@ -14,10 +124,10 @@
|
||||
}),
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'safety_settings': dict({
|
||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
}),
|
||||
'tools': None,
|
||||
}),
|
||||
@@ -29,7 +139,10 @@
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'Answer in plain text. Keep it simple and to the point.',
|
||||
'parts': '''
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
@@ -64,10 +177,10 @@
|
||||
}),
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'safety_settings': dict({
|
||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
}),
|
||||
'tools': None,
|
||||
}),
|
||||
@@ -79,7 +192,10 @@
|
||||
dict({
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': 'Answer in plain text. Keep it simple and to the point.',
|
||||
'parts': '''
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
@@ -114,10 +230,10 @@
|
||||
}),
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'safety_settings': dict({
|
||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
}),
|
||||
'tools': None,
|
||||
}),
|
||||
@@ -130,8 +246,8 @@
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
@@ -167,10 +283,10 @@
|
||||
}),
|
||||
'model_name': 'models/gemini-1.5-flash-latest',
|
||||
'safety_settings': dict({
|
||||
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
|
||||
'HATE': 'BLOCK_LOW_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
|
||||
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
|
||||
}),
|
||||
'tools': None,
|
||||
}),
|
||||
@@ -183,8 +299,8 @@
|
||||
'history': list([
|
||||
dict({
|
||||
'parts': '''
|
||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||
Answer in plain text. Keep it simple and to the point.
|
||||
Call the intent tools to control Home Assistant. Just pass the name to the intent.
|
||||
''',
|
||||
'role': 'user',
|
||||
}),
|
||||
|
@@ -2,12 +2,14 @@
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from google.api_core.exceptions import ClientError
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
import google.generativeai.types as genai_types
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -150,6 +152,57 @@ async def test_default_prompt(
|
||||
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
|
||||
|
||||
|
||||
async def test_chat_history(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the agent keeps track of the chat history."""
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
chat_response = MagicMock()
|
||||
mock_chat.send_message_async.return_value = chat_response
|
||||
mock_part = MagicMock()
|
||||
mock_part.function_call = None
|
||||
chat_response.parts = [mock_part]
|
||||
chat_response.text = "1st model response"
|
||||
mock_chat.history = [
|
||||
{"role": "user", "parts": "prompt"},
|
||||
{"role": "model", "parts": "Ok"},
|
||||
{"role": "user", "parts": "1st user request"},
|
||||
{"role": "model", "parts": "1st model response"},
|
||||
]
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"1st user request",
|
||||
None,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||
== "1st model response"
|
||||
)
|
||||
chat_response.text = "2nd model response"
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"2nd user request",
|
||||
result.conversation_id,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert (
|
||||
result.response.as_dict()["speech"]["plain"]["speech"]
|
||||
== "2nd model response"
|
||||
)
|
||||
|
||||
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||
)
|
||||
@@ -233,6 +286,20 @@ async def test_function_call(
|
||||
),
|
||||
)
|
||||
|
||||
# Test conversating tracing
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
trace_events = last_trace.get("events", [])
|
||||
assert [event["event_type"] for event in trace_events] == [
|
||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
||||
]
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||
@@ -325,7 +392,7 @@ async def test_error_handling(
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
mock_chat.send_message_async.side_effect = ClientError("some error")
|
||||
mock_chat.send_message_async.side_effect = GoogleAPICallError("some error")
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
@@ -340,7 +407,28 @@ async def test_error_handling(
|
||||
async def test_blocked_response(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test response was blocked."""
|
||||
"""Test blocked response."""
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
|
||||
"finish_reason: SAFETY\n"
|
||||
)
|
||||
result = await conversation.async_converse(
|
||||
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"The message got blocked by your safety settings"
|
||||
)
|
||||
|
||||
|
||||
async def test_empty_response(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
) -> None:
|
||||
"""Test empty response."""
|
||||
with patch("google.generativeai.GenerativeModel") as mock_model:
|
||||
mock_chat = AsyncMock()
|
||||
mock_model.return_value.start_chat.return_value = mock_chat
|
||||
@@ -358,6 +446,32 @@ async def test_blocked_response(
|
||||
)
|
||||
|
||||
|
||||
async def test_invalid_llm_api(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
) -> None:
|
||||
"""Test handling of invalid llm api."""
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id=mock_config_entry.entry_id,
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
|
||||
"Error preparing LLM API: API invalid_llm_api not found"
|
||||
)
|
||||
|
||||
|
||||
async def test_template_error(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
|
@@ -529,16 +529,16 @@ async def test_non_unique_triggers(
|
||||
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[0].data["some"] == "press1"
|
||||
assert calls[1].data["some"] == "press2"
|
||||
all_calls = {calls[0].data["some"], calls[1].data["some"]}
|
||||
assert all_calls == {"press1", "press2"}
|
||||
|
||||
# Trigger second config references to same trigger
|
||||
# and triggers both attached instances.
|
||||
async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press")
|
||||
await hass.async_block_till_done()
|
||||
assert len(calls) == 2
|
||||
assert calls[0].data["some"] == "press1"
|
||||
assert calls[1].data["some"] == "press2"
|
||||
all_calls = {calls[0].data["some"], calls[1].data["some"]}
|
||||
assert all_calls == {"press1", "press2"}
|
||||
|
||||
# Removing the first trigger will clean up
|
||||
calls.clear()
|
||||
|
@@ -4,6 +4,7 @@ import asyncio
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
@@ -1050,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
|
||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
|
||||
|
||||
async def test_subscribe_mqtt_config_entry_disabled(
|
||||
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
|
||||
) -> None:
|
||||
"""Test the subscription of a topic when MQTT config entry is disabled."""
|
||||
mqtt_mock.connected = True
|
||||
|
||||
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
|
||||
assert mqtt_config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
|
||||
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
|
||||
|
||||
await hass.config_entries.async_set_disabled_by(
|
||||
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
|
||||
)
|
||||
mqtt_mock.connected = False
|
||||
|
||||
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
|
||||
await mqtt.async_subscribe(hass, "test-topic", record_calls)
|
||||
|
||||
|
||||
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
|
||||
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
|
||||
async def test_subscribe_and_resubscribe(
|
||||
@@ -2912,8 +2934,8 @@ async def test_message_callback_exception_gets_logged(
|
||||
await mqtt_mock_entry()
|
||||
|
||||
@callback
|
||||
def bad_handler(*args) -> None:
|
||||
"""Record calls."""
|
||||
def bad_handler(msg: ReceiveMessage) -> None:
|
||||
"""Handle callback."""
|
||||
raise ValueError("This is a bad message callback")
|
||||
|
||||
await mqtt.async_subscribe(hass, "test-topic", bad_handler)
|
||||
@@ -2926,6 +2948,40 @@ async def test_message_callback_exception_gets_logged(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.no_fail_on_log_exception
|
||||
async def test_message_partial_callback_exception_gets_logged(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
mqtt_mock_entry: MqttMockHAClientGenerator,
|
||||
) -> None:
|
||||
"""Test exception raised by message handler."""
|
||||
await mqtt_mock_entry()
|
||||
|
||||
@callback
|
||||
def bad_handler(msg: ReceiveMessage) -> None:
|
||||
"""Handle callback."""
|
||||
raise ValueError("This is a bad message callback")
|
||||
|
||||
def parial_handler(
|
||||
msg_callback: MessageCallbackType,
|
||||
attributes: set[str],
|
||||
msg: ReceiveMessage,
|
||||
) -> None:
|
||||
"""Partial callback handler."""
|
||||
msg_callback(msg)
|
||||
|
||||
await mqtt.async_subscribe(
|
||||
hass, "test-topic", partial(parial_handler, bad_handler, {"some_attr"})
|
||||
)
|
||||
async_fire_mqtt_message(hass, "test-topic", "test")
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert (
|
||||
"Exception in bad_handler when handling msg on 'test-topic':"
|
||||
" 'test'" in caplog.text
|
||||
)
|
||||
|
||||
|
||||
async def test_mqtt_ws_subscription(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
@@ -3787,7 +3843,7 @@ async def test_unload_config_entry(
|
||||
async def test_publish_or_subscribe_without_valid_config_entry(
|
||||
hass: HomeAssistant, record_calls: MessageCallbackType
|
||||
) -> None:
|
||||
"""Test internal publish function with bas use cases."""
|
||||
"""Test internal publish function with bad use cases."""
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await mqtt.async_publish(
|
||||
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None
|
||||
|
@@ -11,8 +11,12 @@ from .const import DEFAULT_FORECAST, DEFAULT_OBSERVATION
|
||||
@pytest.fixture
|
||||
def mock_simple_nws():
|
||||
"""Mock pynws SimpleNWS with default values."""
|
||||
|
||||
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws:
|
||||
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
|
||||
with (
|
||||
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
|
||||
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
|
||||
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
|
||||
):
|
||||
instance = mock_nws.return_value
|
||||
instance.set_station = AsyncMock(return_value=None)
|
||||
instance.update_observation = AsyncMock(return_value=None)
|
||||
@@ -29,7 +33,12 @@ def mock_simple_nws():
|
||||
@pytest.fixture
|
||||
def mock_simple_nws_times_out():
|
||||
"""Mock pynws SimpleNWS that times out."""
|
||||
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws:
|
||||
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
|
||||
with (
|
||||
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
|
||||
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
|
||||
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
|
||||
):
|
||||
instance = mock_nws.return_value
|
||||
instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError)
|
||||
|
@@ -1,7 +1,6 @@
|
||||
"""Tests for the NWS weather component."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
@@ -24,7 +23,6 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
|
||||
|
||||
from .const import (
|
||||
@@ -127,47 +125,43 @@ async def test_data_caching_error_observation(
|
||||
caplog,
|
||||
) -> None:
|
||||
"""Test caching of data with errors."""
|
||||
with (
|
||||
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
|
||||
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
|
||||
):
|
||||
instance = mock_simple_nws.return_value
|
||||
instance = mock_simple_nws.return_value
|
||||
|
||||
entry = MockConfigEntry(
|
||||
domain=nws.DOMAIN,
|
||||
data=NWS_CONFIG,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
entry = MockConfigEntry(
|
||||
domain=nws.DOMAIN,
|
||||
data=NWS_CONFIG,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state.state == "sunny"
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state.state == "sunny"
|
||||
|
||||
# data is still valid even when update fails
|
||||
instance.update_observation.side_effect = NwsNoDataError("Test")
|
||||
# data is still valid even when update fails
|
||||
instance.update_observation.side_effect = NwsNoDataError("Test")
|
||||
|
||||
freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state.state == "sunny"
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state.state == "sunny"
|
||||
|
||||
assert (
|
||||
"NWS observation update failed, but data still valid. Last success: "
|
||||
in caplog.text
|
||||
)
|
||||
assert (
|
||||
"NWS observation update failed, but data still valid. Last success: "
|
||||
in caplog.text
|
||||
)
|
||||
|
||||
# data is no longer valid after OBSERVATION_VALID_TIME
|
||||
freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
# data is no longer valid after OBSERVATION_VALID_TIME
|
||||
freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1))
|
||||
async_fire_time_changed(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
assert "Error fetching NWS observation station ABC data: Test" in caplog.text
|
||||
assert "Error fetching NWS observation station ABC data: Test" in caplog.text
|
||||
|
||||
|
||||
async def test_no_data_error_observation(
|
||||
@@ -302,26 +296,23 @@ async def test_error_observation(
|
||||
hass: HomeAssistant, mock_simple_nws, no_sensor
|
||||
) -> None:
|
||||
"""Test error during update observation."""
|
||||
utc_time = dt_util.utcnow()
|
||||
with patch("homeassistant.components.nws.coordinator.utcnow") as mock_utc:
|
||||
mock_utc.return_value = utc_time
|
||||
instance = mock_simple_nws.return_value
|
||||
# first update fails
|
||||
instance.update_observation.side_effect = aiohttp.ClientError
|
||||
instance = mock_simple_nws.return_value
|
||||
# first update fails
|
||||
instance.update_observation.side_effect = aiohttp.ClientError
|
||||
|
||||
entry = MockConfigEntry(
|
||||
domain=nws.DOMAIN,
|
||||
data=NWS_CONFIG,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
entry = MockConfigEntry(
|
||||
domain=nws.DOMAIN,
|
||||
data=NWS_CONFIG,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
instance.update_observation.assert_called_once()
|
||||
instance.update_observation.assert_called_once()
|
||||
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
state = hass.states.get("weather.abc")
|
||||
assert state
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_new_config_entry(hass: HomeAssistant, no_sensor) -> None:
|
||||
|
@@ -6,6 +6,7 @@ from ollama import Message, ResponseError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import conversation, ollama
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
@@ -110,6 +111,19 @@ async def test_chat(
|
||||
), result
|
||||
assert result.response.speech["plain"]["speech"] == "test response"
|
||||
|
||||
# Test Conversation tracing
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
trace_events = last_trace.get("events", [])
|
||||
assert [event["event_type"] for event in trace_events] == [
|
||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
]
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
|
||||
|
||||
|
||||
async def test_message_history_trimming(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
|
||||
|
@@ -9,9 +9,17 @@ import pytest
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.openai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
CONF_MAX_TOKENS,
|
||||
CONF_PROMPT,
|
||||
CONF_RECOMMENDED,
|
||||
CONF_TEMPERATURE,
|
||||
CONF_TOP_P,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TOP_P,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
@@ -75,7 +83,7 @@ async def test_options(
|
||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert options["data"]["prompt"] == "Speak like a pirate"
|
||||
assert options["data"]["max_tokens"] == 200
|
||||
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
|
||||
assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
|
||||
|
||||
assert result2["type"] is FlowResultType.FORM
|
||||
assert result2["errors"] == {"base": error}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("current_options", "new_options", "expected_options"),
|
||||
[
|
||||
(
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "none",
|
||||
CONF_PROMPT: "bla",
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: "Speak like a pirate",
|
||||
CONF_TEMPERATURE: 0.3,
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: "Speak like a pirate",
|
||||
CONF_TEMPERATURE: 0.3,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||
},
|
||||
),
|
||||
(
|
||||
{
|
||||
CONF_RECOMMENDED: False,
|
||||
CONF_PROMPT: "Speak like a pirate",
|
||||
CONF_TEMPERATURE: 0.3,
|
||||
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
|
||||
CONF_TOP_P: RECOMMENDED_TOP_P,
|
||||
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "assist",
|
||||
CONF_PROMPT: "",
|
||||
},
|
||||
{
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: "assist",
|
||||
CONF_PROMPT: "",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_options_switching(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry,
|
||||
mock_init_component,
|
||||
current_options,
|
||||
new_options,
|
||||
expected_options,
|
||||
) -> None:
|
||||
"""Test the options form."""
|
||||
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
|
||||
options_flow = await hass.config_entries.options.async_init(
|
||||
mock_config_entry.entry_id
|
||||
)
|
||||
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
|
||||
options_flow = await hass.config_entries.options.async_configure(
|
||||
options_flow["flow_id"],
|
||||
{
|
||||
**current_options,
|
||||
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
|
||||
},
|
||||
)
|
||||
options = await hass.config_entries.options.async_configure(
|
||||
options_flow["flow_id"],
|
||||
new_options,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
assert options["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert options["data"] == expected_options
|
||||
|
@@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.conversation import trace
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -200,6 +201,20 @@ async def test_function_call(
|
||||
),
|
||||
)
|
||||
|
||||
# Test Conversation tracing
|
||||
traces = trace.async_get_traces()
|
||||
assert traces
|
||||
last_trace = traces[-1].as_dict()
|
||||
trace_events = last_trace.get("events", [])
|
||||
assert [event["event_type"] for event in trace_events] == [
|
||||
trace.ConversationTraceEventType.ASYNC_PROCESS,
|
||||
trace.ConversationTraceEventType.AGENT_DETAIL,
|
||||
trace.ConversationTraceEventType.LLM_TOOL_CALL,
|
||||
]
|
||||
# AGENT_DETAIL event contains the raw prompt passed to the model
|
||||
detail_event = trace_events[1]
|
||||
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
|
||||
|
||||
|
||||
@patch(
|
||||
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
from pyplaato.models.airlock import PlaatoAirlock
|
||||
from pyplaato.models.device import PlaatoDeviceType
|
||||
from pyplaato.models.keg import PlaatoKeg
|
||||
@@ -23,6 +24,7 @@ AIRLOCK_DATA = {}
|
||||
KEG_DATA = {}
|
||||
|
||||
|
||||
@freeze_time("2024-05-24 12:00:00", tz_offset=0)
|
||||
async def init_integration(
|
||||
hass: HomeAssistant, device_type: PlaatoDeviceType
|
||||
) -> MockConfigEntry:
|
||||
|
@@ -492,7 +492,6 @@ async def test_block_set_mode_auth_error(
|
||||
{ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
|
@@ -227,7 +227,6 @@ async def test_block_set_value_auth_error(
|
||||
{ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
|
@@ -618,7 +618,6 @@ async def test_rpc_sleeping_update_entity_service(
|
||||
service_data={ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Entity should be available after update_entity service call
|
||||
state = hass.states.get(entity_id)
|
||||
@@ -667,7 +666,6 @@ async def test_block_sleeping_update_entity_service(
|
||||
service_data={ATTR_ENTITY_ID: entity_id},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Entity should be available after update_entity service call
|
||||
state = hass.states.get(entity_id)
|
||||
|
@@ -230,7 +230,6 @@ async def test_block_set_state_auth_error(
|
||||
{ATTR_ENTITY_ID: "switch.test_name_channel_1"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
@@ -374,7 +373,6 @@ async def test_rpc_auth_error(
|
||||
{ATTR_ENTITY_ID: "switch.test_switch_0"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
|
@@ -207,7 +207,6 @@ async def test_block_update_auth_error(
|
||||
{ATTR_ENTITY_ID: "update.test_name_firmware_update"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
@@ -669,7 +668,6 @@ async def test_rpc_update_auth_error(
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
await hass.async_block_till_done()
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
|
@@ -18,9 +18,15 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
||||
@pytest.fixture
|
||||
def mock_bridge(request):
|
||||
"""Return a mocked SwitcherBridge."""
|
||||
with patch(
|
||||
"homeassistant.components.switcher_kis.utils.SwitcherBridge", autospec=True
|
||||
) as bridge_mock:
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.switcher_kis.SwitcherBridge", autospec=True
|
||||
) as bridge_mock,
|
||||
patch(
|
||||
"homeassistant.components.switcher_kis.utils.SwitcherBridge",
|
||||
new=bridge_mock,
|
||||
),
|
||||
):
|
||||
bridge = bridge_mock.return_value
|
||||
|
||||
bridge.devices = []
|
||||
|
@@ -4,11 +4,7 @@ from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.switcher_kis.const import (
|
||||
DATA_DEVICE,
|
||||
DOMAIN,
|
||||
MAX_UPDATE_INTERVAL_SEC,
|
||||
)
|
||||
from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.const import STATE_UNAVAILABLE
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -24,15 +20,14 @@ async def test_update_fail(
|
||||
hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test entities state unavailable when updates fail.."""
|
||||
await init_integration(hass)
|
||||
entry = await init_integration(hass)
|
||||
assert mock_bridge
|
||||
|
||||
mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_bridge.is_running is True
|
||||
assert len(hass.data[DOMAIN]) == 2
|
||||
assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2
|
||||
assert len(entry.runtime_data) == 2
|
||||
|
||||
async_fire_time_changed(
|
||||
hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1)
|
||||
@@ -77,11 +72,9 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None:
|
||||
|
||||
assert entry.state is ConfigEntryState.LOADED
|
||||
assert mock_bridge.is_running is True
|
||||
assert len(hass.data[DOMAIN]) == 2
|
||||
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state is ConfigEntryState.NOT_LOADED
|
||||
assert mock_bridge.is_running is False
|
||||
assert len(hass.data[DOMAIN]) == 0
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user