Merge branch 'dev' into jbouwh-mqtt-device-discovery

This commit is contained in:
J. Nick Koston
2024-05-25 11:39:47 -10:00
committed by GitHub
106 changed files with 3208 additions and 1980 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import dataclasses
import logging
from typing import Any
@@ -20,6 +21,11 @@ from .models import (
ConversationInput,
ConversationResult,
)
from .trace import (
ConversationTraceEvent,
ConversationTraceEventType,
async_conversation_trace,
)
_LOGGER = logging.getLogger(__name__)
@@ -84,15 +90,23 @@ async def async_converse(
language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text)
return await method(
ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
device_id=device_id,
language=language,
)
conversation_input = ConversationInput(
text=text,
context=context,
conversation_id=conversation_id,
device_id=device_id,
language=language,
)
with async_conversation_trace() as trace:
trace.add_event(
ConversationTraceEvent(
ConversationTraceEventType.ASYNC_PROCESS,
dataclasses.asdict(conversation_input),
)
)
result = await method(conversation_input)
trace.set_result(**result.as_dict())
return result
class AgentManager:

View File

@@ -0,0 +1,118 @@
"""Debug traces for conversation."""
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field
import enum
from typing import Any
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict
STORED_TRACES = 3
class ConversationTraceEventType(enum.StrEnum):
"""Type of an event emitted during a conversation."""
ASYNC_PROCESS = "async_process"
"""The conversation is started from user input."""
AGENT_DETAIL = "agent_detail"
"""Event detail added by a conversation agent."""
LLM_TOOL_CALL = "llm_tool_call"
"""An LLM Tool call"""
@dataclass(frozen=True)
class ConversationTraceEvent:
"""Event emitted during a conversation."""
event_type: ConversationTraceEventType
data: dict[str, Any] | None = None
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
class ConversationTrace:
"""Stores debug data related to a conversation."""
def __init__(self) -> None:
"""Initialize ConversationTrace."""
self._trace_id = ulid_util.ulid_now()
self._events: list[ConversationTraceEvent] = []
self._error: Exception | None = None
self._result: dict[str, Any] = {}
@property
def trace_id(self) -> str:
"""Identifier for this trace."""
return self._trace_id
def add_event(self, event: ConversationTraceEvent) -> None:
"""Add an event to the trace."""
self._events.append(event)
def set_error(self, ex: Exception) -> None:
"""Set error."""
self._error = ex
def set_result(self, **kwargs: Any) -> None:
"""Set result."""
self._result = {**kwargs}
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this ConversationTrace."""
result: dict[str, Any] = {
"id": self._trace_id,
"events": [asdict(event) for event in self._events],
}
if self._error is not None:
result["error"] = str(self._error) or self._error.__class__.__name__
if self._result is not None:
result["result"] = self._result
return result
_current_trace: ContextVar[ConversationTrace | None] = ContextVar(
"current_trace", default=None
)
_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict(
size_limit=STORED_TRACES
)
def async_conversation_trace_append(
event_type: ConversationTraceEventType, event_data: dict[str, Any]
) -> None:
"""Append a ConversationTraceEvent to the current active trace."""
trace = _current_trace.get()
if not trace:
return
trace.add_event(ConversationTraceEvent(event_type, event_data))
@contextmanager
def async_conversation_trace() -> Generator[ConversationTrace, None]:
"""Create a new active ConversationTrace."""
trace = ConversationTrace()
token = _current_trace.set(trace)
_recent_traces[trace.trace_id] = trace
try:
yield trace
except Exception as ex:
trace.set_error(ex)
raise
finally:
_current_trace.reset(token)
def async_get_traces() -> list[ConversationTrace]:
"""Get the most recent traces."""
return list(_recent_traces.values())
def async_clear_traces() -> None:
"""Clear all traces."""
_recent_traces.clear()

View File

@@ -5,5 +5,5 @@
"documentation": "https://www.home-assistant.io/integrations/envisalink",
"iot_class": "local_push",
"loggers": ["pyenvisalink"],
"requirements": ["pyenvisalink==4.6"]
"requirements": ["pyenvisalink==4.7"]
}

View File

@@ -1,6 +1,5 @@
"""Support for Arduino-compatible Microcontrollers through Firmata."""
import asyncio
from copy import copy
import logging
@@ -212,16 +211,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Shutdown and close a Firmata board for a config entry."""
_LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME])
unload_entries = []
for conf, platform in CONF_PLATFORM_MAP.items():
if conf in config_entry.data:
unload_entries.append(
hass.config_entries.async_forward_entry_unload(config_entry, platform)
)
results = []
if unload_entries:
results = await asyncio.gather(*unload_entries)
results: list[bool] = []
if platforms := [
platform
for conf, platform in CONF_PLATFORM_MAP.items()
if conf in config_entry.data
]:
results.append(
await hass.config_entries.async_unload_platforms(config_entry, platforms)
)
results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset())
return False not in results

View File

@@ -11,12 +11,13 @@ from .const import (
CONF_DAMPING_EVENING,
CONF_DAMPING_MORNING,
CONF_MODULES_POWER,
DOMAIN,
)
from .coordinator import ForecastSolarDataUpdateCoordinator
PLATFORMS = [Platform.SENSOR]
type ForecastSolarConfigEntry = ConfigEntry[ForecastSolarDataUpdateCoordinator]
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Migrate old config entry."""
@@ -36,12 +37,14 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup_entry(
hass: HomeAssistant, entry: ForecastSolarConfigEntry
) -> bool:
"""Set up Forecast.Solar from a config entry."""
coordinator = ForecastSolarDataUpdateCoordinator(hass, entry)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
@@ -52,11 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:

View File

@@ -4,15 +4,11 @@ from __future__ import annotations
from typing import Any
from forecast_solar import Estimate
from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .const import DOMAIN
from . import ForecastSolarConfigEntry
TO_REDACT = {
CONF_API_KEY,
@@ -22,10 +18,10 @@ TO_REDACT = {
async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry
hass: HomeAssistant, entry: ForecastSolarConfigEntry
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
coordinator: DataUpdateCoordinator[Estimate] = hass.data[DOMAIN][entry.entry_id]
coordinator = entry.runtime_data
return {
"entry": {

View File

@@ -4,19 +4,21 @@ from __future__ import annotations
from homeassistant.core import HomeAssistant
from .const import DOMAIN
from .coordinator import ForecastSolarDataUpdateCoordinator
async def async_get_solar_forecast(
hass: HomeAssistant, config_entry_id: str
) -> dict[str, dict[str, float | int]] | None:
"""Get solar forecast for a config entry ID."""
if (coordinator := hass.data[DOMAIN].get(config_entry_id)) is None:
if (
entry := hass.config_entries.async_get_entry(config_entry_id)
) is None or not isinstance(entry.runtime_data, ForecastSolarDataUpdateCoordinator):
return None
return {
"wh_hours": {
timestamp.isoformat(): val
for timestamp, val in coordinator.data.wh_period.items()
for timestamp, val in entry.runtime_data.data.wh_period.items()
}
}

View File

@@ -16,7 +16,6 @@ from homeassistant.components.sensor import (
SensorEntityDescription,
SensorStateClass,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import UnitOfEnergy, UnitOfPower
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
@@ -24,6 +23,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import ForecastSolarConfigEntry
from .const import DOMAIN
from .coordinator import ForecastSolarDataUpdateCoordinator
@@ -133,10 +133,12 @@ SENSORS: tuple[ForecastSolarSensorEntityDescription, ...] = (
async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
hass: HomeAssistant,
entry: ForecastSolarConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Defer sensor setup to the shared sensor module."""
coordinator: ForecastSolarDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id]
coordinator = entry.runtime_data
async_add_entities(
ForecastSolarSensorEntity(

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import logging
from typing import Final
from typing import Any, Final
from homeassistant.components.button import (
ButtonDeviceClass,
@@ -30,7 +30,7 @@ _LOGGER = logging.getLogger(__name__)
class FritzButtonDescription(ButtonEntityDescription):
"""Class to describe a Button entity."""
press_action: Callable
press_action: Callable[[AvmWrapper], Any]
BUTTONS: Final = [

View File

@@ -57,9 +57,6 @@ ERROR_UPNP_NOT_CONFIGURED = "upnp_not_configured"
ERROR_UNKNOWN = "unknown_error"
FRITZ_SERVICES = "fritz_services"
SERVICE_REBOOT = "reboot"
SERVICE_RECONNECT = "reconnect"
SERVICE_CLEANUP = "cleanup"
SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password"
SWITCH_TYPE_DEFLECTION = "CallDeflection"

View File

@@ -46,9 +46,6 @@ from .const import (
DEFAULT_USERNAME,
DOMAIN,
FRITZ_EXCEPTIONS,
SERVICE_CLEANUP,
SERVICE_REBOOT,
SERVICE_RECONNECT,
SERVICE_SET_GUEST_WIFI_PW,
MeshRoles,
)
@@ -730,30 +727,6 @@ class FritzBoxTools(DataUpdateCoordinator[UpdateCoordinatorDataType]):
)
try:
if service_call.service == SERVICE_REBOOT:
_LOGGER.warning(
'Service "fritz.reboot" is deprecated, please use the corresponding'
" button entity instead"
)
await self.async_trigger_reboot()
return
if service_call.service == SERVICE_RECONNECT:
_LOGGER.warning(
'Service "fritz.reconnect" is deprecated, please use the'
" corresponding button entity instead"
)
await self.async_trigger_reconnect()
return
if service_call.service == SERVICE_CLEANUP:
_LOGGER.warning(
'Service "fritz.cleanup" is deprecated, please use the'
" corresponding button entity instead"
)
await self.async_trigger_cleanup(config_entry)
return
if service_call.service == SERVICE_SET_GUEST_WIFI_PW:
await self.async_trigger_set_guest_password(
service_call.data.get("password"),

View File

@@ -11,14 +11,7 @@ from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import async_extract_config_entry_ids
from .const import (
DOMAIN,
FRITZ_SERVICES,
SERVICE_CLEANUP,
SERVICE_REBOOT,
SERVICE_RECONNECT,
SERVICE_SET_GUEST_WIFI_PW,
)
from .const import DOMAIN, FRITZ_SERVICES, SERVICE_SET_GUEST_WIFI_PW
from .coordinator import AvmWrapper
_LOGGER = logging.getLogger(__name__)
@@ -32,9 +25,6 @@ SERVICE_SCHEMA_SET_GUEST_WIFI_PW = vol.Schema(
)
SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [
(SERVICE_CLEANUP, None),
(SERVICE_REBOOT, None),
(SERVICE_RECONNECT, None),
(SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW),
]

View File

@@ -1,31 +1,3 @@
reconnect:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
reboot:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
cleanup:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
set_guest_wifi_password:
fields:
device_id:

View File

@@ -144,42 +144,12 @@
}
},
"services": {
"reconnect": {
"name": "[%key:component::fritz::entity::button::reconnect::name%]",
"description": "Reconnects your FRITZ!Box internet connection.",
"fields": {
"device_id": {
"name": "Fritz!Box Device",
"description": "Select the Fritz!Box to reconnect."
}
}
},
"reboot": {
"name": "Reboot",
"description": "Reboots your FRITZ!Box.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"description": "Select the Fritz!Box to reboot."
}
}
},
"cleanup": {
"name": "Remove stale device tracker entities",
"description": "Remove FRITZ!Box stale device_tracker entities.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"description": "Select the Fritz!Box to check."
}
}
},
"set_guest_wifi_password": {
"name": "Set guest Wi-Fi password",
"description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"name": "Fritz!Box Device",
"description": "Select the Fritz!Box to configure."
},
"password": {

View File

@@ -0,0 +1,46 @@
"""Diagnostics support for Fronius."""
from typing import Any
from homeassistant.components.diagnostics import async_redact_data
from homeassistant.core import HomeAssistant
from . import FroniusConfigEntry
TO_REDACT = {"unique_id", "unique_identifier", "serial"}
async def async_get_config_entry_diagnostics(
hass: HomeAssistant, config_entry: FroniusConfigEntry
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
diag: dict[str, Any] = {}
solar_net = config_entry.runtime_data
fronius = solar_net.fronius
diag["config_entry"] = config_entry.as_dict()
diag["inverter_info"] = await fronius.inverter_info()
diag["coordinators"] = {"inverters": {}}
for inv in solar_net.inverter_coordinators:
diag["coordinators"]["inverters"] |= inv.data
diag["coordinators"]["logger"] = (
solar_net.logger_coordinator.data if solar_net.logger_coordinator else None
)
diag["coordinators"]["meter"] = (
solar_net.meter_coordinator.data if solar_net.meter_coordinator else None
)
diag["coordinators"]["ohmpilot"] = (
solar_net.ohmpilot_coordinator.data if solar_net.ohmpilot_coordinator else None
)
diag["coordinators"]["power_flow"] = (
solar_net.power_flow_coordinator.data
if solar_net.power_flow_coordinator
else None
)
diag["coordinators"]["storage"] = (
solar_net.storage_coordinator.data if solar_net.storage_coordinator else None
)
return async_redact_data(diag, TO_REDACT)

View File

@@ -181,8 +181,7 @@ async def google_generative_ai_config_option_schema(
schema = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,

View File

@@ -22,4 +22,4 @@ CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"

View File

@@ -5,13 +5,14 @@ from __future__ import annotations
from typing import Any, Literal
import google.ai.generativelanguage as glm
from google.api_core.exceptions import ClientError
from google.api_core.exceptions import GoogleAPICallError
import google.generativeai as genai
import google.generativeai.types as genai_types
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
@@ -205,15 +206,6 @@ class GoogleGenerativeAIConversationEntity(
messages = [{}, {}]
try:
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api:
empty_tool_input = llm.ToolInput(
tool_name="",
@@ -226,9 +218,24 @@ class GoogleGenerativeAIConversationEntity(
device_id=user_input.device_id,
)
prompt = (
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
else:
api_prompt = llm.PROMPT_NO_API_CONFIGURED
prompt = "\n".join(
(
template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
)
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
@@ -244,6 +251,9 @@ class GoogleGenerativeAIConversationEntity(
messages[1] = {"role": "model", "parts": "Ok"}
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
chat = model.start_chat(history=messages)
chat_request = user_input.text
@@ -252,15 +262,25 @@ class GoogleGenerativeAIConversationEntity(
try:
chat_response = await chat.send_message_async(chat_request)
except (
ClientError,
GoogleAPICallError,
ValueError,
genai_types.BlockedPromptException,
genai_types.StopCandidateException,
) as err:
LOGGER.error("Error sending message: %s", err)
LOGGER.error("Error sending message: %s %s", type(err), err)
if isinstance(
err, genai_types.StopCandidateException
) and "finish_reason: SAFETY\n" in str(err):
error = "The message got blocked by your safety settings"
else:
error = (
f"Sorry, I had a problem talking to Google Generative AI: {err}"
)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}",
error,
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id

View File

@@ -1,6 +1,6 @@
{
"domain": "integration",
"name": "Integration - Riemann sum integral",
"name": "Integral",
"after_dependencies": ["counter"],
"codeowners": ["@dgomes"],
"config_flow": true,

View File

@@ -1,5 +1,5 @@
{
"title": "Integration - Riemann sum integral sensor",
"title": "Integral sensor",
"config": {
"step": {
"user": {

View File

@@ -21,7 +21,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
CONF_VALIDATOR = "validator"
CONF_SECRET = "secret"
URL = "/api/meraki"
VERSION = "2.0"
ACCEPTED_VERSIONS = ["2.0", "2.1"]
_LOGGER = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class MerakiView(HomeAssistantView):
if data["secret"] != self.secret:
_LOGGER.error("Invalid Secret received from Meraki")
return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY)
if data["version"] != VERSION:
if data["version"] not in ACCEPTED_VERSIONS:
_LOGGER.error("Invalid API version: %s", data["version"])
return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY)
_LOGGER.debug("Valid Secret")

View File

@@ -6,6 +6,6 @@
"documentation": "https://www.home-assistant.io/integrations/minecraft_server",
"iot_class": "local_polling",
"loggers": ["dnspython", "mcstatus"],
"quality_scale": "gold",
"quality_scale": "platinum",
"requirements": ["mcstatus==11.1.1"]
}

View File

@@ -39,6 +39,7 @@ from .client import ( # noqa: F401
MQTT,
async_publish,
async_subscribe,
async_subscribe_internal,
publish,
subscribe,
)
@@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
def collect_msg(msg: ReceiveMessage) -> None:
messages.append((msg.topic, str(msg.payload).replace("\n", "")))
unsub = await async_subscribe(hass, call.data["topic"], collect_msg)
unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
def write_dump() -> None:
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
@@ -459,7 +460,7 @@ async def websocket_subscribe(
# Perform UTF-8 decoding directly in callback routine
qos: int = msg.get("qos", DEFAULT_QOS)
connection.subscriptions[msg["id"]] = await async_subscribe(
connection.subscriptions[msg["id"]] = async_subscribe_internal(
hass, msg["topic"], forward_messages, encoding=None, qos=qos
)
@@ -522,24 +523,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
mqtt_client = mqtt_data.client
# Unload publish and dump services.
hass.services.async_remove(
DOMAIN,
SERVICE_PUBLISH,
)
hass.services.async_remove(
DOMAIN,
SERVICE_DUMP,
)
hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
hass.services.async_remove(DOMAIN, SERVICE_DUMP)
# Stop the discovery
await discovery.async_stop(hass)
# Unload the platforms
await asyncio.gather(
*(
hass.config_entries.async_forward_entry_unload(entry, component)
for component in mqtt_data.platforms_loaded
)
)
await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
mqtt_data.platforms_loaded = set()
await asyncio.sleep(0)
# Unsubscribe reload dispatchers

View File

@@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_alarm_disarm(self, code: str | None = None) -> None:
"""Send disarm command.

View File

@@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_: Any) -> None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from base64 import b64decode
from functools import partial
import logging
from typing import TYPE_CHECKING
@@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription
from .config import MQTT_BASE_SCHEMA
from .const import CONF_QOS, CONF_TOPIC
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@@ -97,27 +97,31 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema."""
return DISCOVERY_SCHEMA
@callback
def _image_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._config[CONF_TOPIC],
"msg_callback": message_received,
"msg_callback": partial(
self._message_callback,
self._image_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": None,
}
@@ -126,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_camera_image(
self, width: int | None = None, height: int | None = None

View File

@@ -77,7 +77,6 @@ from .const import (
)
from .models import (
DATA_MQTT,
AsyncMessageCallbackType,
MessageCallbackType,
MqttData,
PublishMessage,
@@ -184,7 +183,7 @@ async def async_publish(
async def async_subscribe(
hass: HomeAssistant,
topic: str,
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
@@ -192,13 +191,25 @@ async def async_subscribe(
Call the return value to unsubscribe.
"""
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
@callback
def async_subscribe_internal(
hass: HomeAssistant,
topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
This function is internal to the MQTT integration
and may change at any time. It should not be considered
a stable API.
Call the return value to unsubscribe.
"""
try:
mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc:
@@ -209,12 +220,15 @@ async def async_subscribe(
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
) from exc
return await mqtt_data.client.async_subscribe(
topic,
msg_callback,
qos,
encoding,
)
client = mqtt_data.client
if not client.connected and not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return client.async_subscribe(topic, msg_callback, qos, encoding)
@bind_hass
@@ -429,10 +443,10 @@ class MQTT:
self.config_entry = config_entry
self.conf = conf
self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict(
list
self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
set
)
self._wildcard_subscriptions: list[Subscription] = []
self._wildcard_subscriptions: set[Subscription] = set()
# _retained_topics prevents a Subscription from receiving a
# retained message more than once per topic. This prevents flooding
# already active subscribers when new subscribers subscribe to a topic
@@ -452,7 +466,7 @@ class MQTT:
self._should_reconnect: bool = True
self._available_future: asyncio.Future[bool] | None = None
self._max_qos: dict[str, int] = {} # topic, max qos
self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
@@ -789,9 +803,9 @@ class MQTT:
The caller is responsible clearing the cache of _matching_subscriptions.
"""
if subscription.is_simple_match:
self._simple_subscriptions[subscription.topic].append(subscription)
self._simple_subscriptions[subscription.topic].add(subscription)
else:
self._wildcard_subscriptions.append(subscription)
self._wildcard_subscriptions.add(subscription)
@callback
def _async_untrack_subscription(self, subscription: Subscription) -> None:
@@ -820,8 +834,8 @@ class MQTT:
"""Queue requested subscriptions."""
for subscription in subscriptions:
topic, qos = subscription
max_qos = max(qos, self._max_qos.setdefault(topic, qos))
self._max_qos[topic] = max_qos
if (max_qos := self._max_qos[topic]) < qos:
self._max_qos[topic] = (max_qos := qos)
self._pending_subscriptions[topic] = max_qos
# Cancel any pending unsubscribe since we are subscribing now
if topic in self._pending_unsubscribes:
@@ -832,26 +846,29 @@ class MQTT:
def _exception_message(
self,
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
msg: ReceiveMessage,
) -> str:
"""Return a string with the exception message."""
# if msg_callback is a partial we return the name of the first argument
if isinstance(msg_callback, partial):
call_back_name = getattr(msg_callback.args[0], "__name__") # type: ignore[unreachable]
else:
call_back_name = getattr(msg_callback, "__name__")
return (
f"Exception in {msg_callback.__name__} when handling msg on "
f"Exception in {call_back_name} when handling msg on "
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
)
async def async_subscribe(
@callback
def async_subscribe(
self,
topic: str,
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int,
encoding: str | None = None,
) -> Callable[[], None]:
"""Set up a subscription to a topic with the provided qos.
This method is a coroutine.
"""
"""Set up a subscription to a topic with the provided qos."""
if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!")
@@ -877,18 +894,18 @@ class MQTT:
if self.connected:
self._async_queue_subscriptions(((topic, qos),))
@callback
def async_remove() -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(topic)
return partial(self._async_remove, subscription)
return async_remove
@callback
def _async_remove(self, subscription: Subscription) -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(subscription.topic)
@callback
def _async_unsubscribe(self, topic: str) -> None:
@@ -1257,9 +1274,7 @@ class MQTT:
last_discovery = self._mqtt_data.last_discovery
last_subscribe = now if self._pending_subscriptions else self._last_subscribe
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
while now < wait_until:
await asyncio.sleep(wait_until - now)
now = time.monotonic()
@@ -1267,9 +1282,7 @@ class MQTT:
last_subscribe = (
now if self._pending_subscriptions else self._last_subscribe
)
wait_until = max(
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:

View File

@@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
if self._topic[topic] is not None:

View File

@@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the cover up.

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import TYPE_CHECKING
@@ -32,13 +33,7 @@ from homeassistant.helpers.typing import ConfigType
from . import subscription
from .config import MQTT_BASE_SCHEMA
from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC
from .debug_info import log_messages
from .mixins import (
CONF_JSON_ATTRS_TOPIC,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import CONF_JSON_ATTRS_TOPIC, MqttEntity, async_setup_entity_entry_helper
from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_subscribe_topic
@@ -119,33 +114,31 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
def _tracker_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_HOME]:
self._location_name = STATE_HOME
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
self._location_name = STATE_NOT_HOME
elif payload == self._config[CONF_PAYLOAD_RESET]:
self._location_name = None
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, str)
self._location_name = msg.payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_location_name"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_HOME]:
self._location_name = STATE_HOME
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
self._location_name = STATE_NOT_HOME
elif payload == self._config[CONF_PAYLOAD_RESET]:
self._location_name = None
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, str)
self._location_name = msg.payload
state_topic: str | None = self._config.get(CONF_STATE_TOPIC)
if state_topic is None:
return
@@ -155,7 +148,12 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
{
"state_topic": {
"topic": state_topic,
"msg_callback": message_received,
"msg_callback": partial(
self._message_callback,
self._tracker_message_received,
{"_location_name"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
}
},
@@ -168,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property
def latitude(self) -> float | None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import Any
@@ -31,7 +32,6 @@ from .const import (
PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
DATA_MQTT,
@@ -113,90 +113,91 @@ class MqttEvent(MqttEntity, EventEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
def _event_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if msg.retain:
_LOGGER.debug(
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
msg.payload,
msg.topic,
)
return
event_attributes: dict[str, Any] = {}
event_type: str
try:
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if (
not payload
or payload is PayloadSentinel.DEFAULT
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
):
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
try:
event_attributes = json_loads_object(payload)
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
_LOGGER.debug(
(
"JSON event data detected after processing payload '%s' on"
" topic %s, type %s, attributes %s"
),
payload,
msg.topic,
event_type,
event_attributes,
)
except KeyError:
_LOGGER.warning(
("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
payload,
msg.topic,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid JSON event payload detected, "
"value after processing payload"
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
try:
self._trigger_event(event_type, event_attributes)
except ValueError:
_LOGGER.warning(
"Invalid event type %s for %s received on topic %s, payload %s",
event_type,
self.entity_id,
msg.topic,
payload,
)
return
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if msg.retain:
_LOGGER.debug(
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
msg.payload,
msg.topic,
)
return
event_attributes: dict[str, Any] = {}
event_type: str
try:
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if (
not payload
or payload is PayloadSentinel.DEFAULT
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
):
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
try:
event_attributes = json_loads_object(payload)
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
_LOGGER.debug(
(
"JSON event data detected after processing payload '%s' on"
" topic %s, type %s, attributes %s"
),
payload,
msg.topic,
event_type,
event_attributes,
)
except KeyError:
_LOGGER.warning(
(
"`event_type` missing in JSON event payload, "
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid JSON event payload detected, "
"value after processing payload"
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
try:
self._trigger_event(event_type, event_attributes)
except ValueError:
_LOGGER.warning(
"Invalid event type %s for %s received on topic %s, payload %s",
event_type,
self.entity_id,
msg.topic,
payload,
)
return
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received,
"msg_callback": partial(
self._message_callback,
self._event_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@@ -207,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
import math
from typing import Any
@@ -49,12 +50,7 @@ from .const import (
CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MessageCallbackType,
MqttCommandTemplate,
@@ -338,137 +334,142 @@ class MqttFan(MqttEntity, FanEntity):
for key, tpl in value_templates.items()
}
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
@callback
def _percentage_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload
)
if not rendered_percentage_payload:
_LOGGER.debug("Ignoring empty speed from '%s'", msg.topic)
return
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
self._attr_percentage = None
return
try:
percentage = ranged_value_to_percentage(
self._speed_range, int(rendered_percentage_payload)
)
except ValueError:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
if percentage < 0 or percentage > 100:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
self._attr_percentage = percentage
@callback
def _preset_mode_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode."""
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None
return
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return
if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload,
msg.topic,
preset_mode,
)
return
self._attr_preset_mode = preset_mode
@callback
def _oscillation_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic)
return
if payload == self._payload["OSCILLATE_ON_PAYLOAD"]:
self._attr_oscillating = True
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
self._attr_oscillating = False
@callback
def _direction_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the direction."""
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
if not direction:
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
return
self._attr_current_direction = str(direction)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
def add_subscribe_topic(
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
) -> bool:
"""Add a topic to subscribe to."""
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
return has_topic
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
add_subscribe_topic(CONF_STATE_TOPIC, state_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_percentage"})
def percentage_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload
)
if not rendered_percentage_payload:
_LOGGER.debug("Ignoring empty speed from '%s'", msg.topic)
return
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
self._attr_percentage = None
return
try:
percentage = ranged_value_to_percentage(
self._speed_range, int(rendered_percentage_payload)
)
except ValueError:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
if percentage < 0 or percentage > 100:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
self._attr_percentage = percentage
add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_preset_mode"})
def preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode."""
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None
return
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return
if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload,
msg.topic,
preset_mode,
)
return
self._attr_preset_mode = preset_mode
add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_oscillating"})
def oscillation_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic)
return
if payload == self._payload["OSCILLATE_ON_PAYLOAD"]:
self._attr_oscillating = True
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
self._attr_oscillating = False
if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received):
add_subscribe_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
add_subscribe_topic(
CONF_PERCENTAGE_STATE_TOPIC, self._percentage_received, {"_attr_percentage"}
)
add_subscribe_topic(
CONF_PRESET_MODE_STATE_TOPIC,
self._preset_mode_received,
{"_attr_preset_mode"},
)
if add_subscribe_topic(
CONF_OSCILLATION_STATE_TOPIC,
self._oscillation_received,
{"_attr_oscillating"},
):
self._attr_oscillating = False
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_current_direction"})
def direction_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the direction."""
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
if not direction:
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
return
self._attr_current_direction = str(direction)
add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received)
add_subscribe_topic(
CONF_DIRECTION_STATE_TOPIC,
self._direction_received,
{"_attr_current_direction"},
)
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
@@ -476,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property
def is_on(self) -> bool | None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import Any
@@ -51,12 +52,7 @@ from .const import (
CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -284,164 +280,166 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
topics: dict[str, dict[str, Any]],
topic: str,
msg_callback: Callable[[ReceiveMessage], None],
tracked_attributes: set[str],
) -> None:
"""Add a subscription."""
qos: int = self._config[CONF_QOS]
if topic in self._topic and self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": qos,
"encoding": self._config[CONF_ENCODING] or None,
}
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
@callback
def _action_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
if not action_payload or action_payload == PAYLOAD_NONE:
_LOGGER.debug("Ignoring empty action from '%s'", msg.topic)
return
try:
self._attr_action = HumidifierAction(str(action_payload))
except ValueError:
_LOGGER.error(
"'%s' received on topic %s. '%s' is not a valid action",
msg.payload,
msg.topic,
action_payload,
)
return
@callback
def _current_humidity_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the current humidity."""
rendered_current_humidity_payload = self._value_templates[
ATTR_CURRENT_HUMIDITY
](msg.payload)
if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_current_humidity = None
return
if not rendered_current_humidity_payload:
_LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic)
return
try:
current_humidity = round(float(rendered_current_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
if current_humidity < 0 or current_humidity > 100:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
self._attr_current_humidity = current_humidity
@callback
def _target_humidity_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the target humidity."""
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
msg.payload
)
if not rendered_target_humidity_payload:
_LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic)
return
if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_target_humidity = None
return
try:
target_humidity = round(float(rendered_target_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
if (
target_humidity < self._attr_min_humidity
or target_humidity > self._attr_max_humidity
):
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
self._attr_target_humidity = target_humidity
@callback
def _mode_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for mode."""
mode = str(self._value_templates[ATTR_MODE](msg.payload))
if mode == self._payload["MODE_RESET"]:
self._attr_mode = None
return
if not mode:
_LOGGER.debug("Ignoring empty mode from '%s'", msg.topic)
return
if not self.available_modes or mode not in self.available_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid mode",
msg.payload,
msg.topic,
mode,
)
return
self._attr_mode = mode
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
self.add_subscription(topics, CONF_STATE_TOPIC, state_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_action"})
def action_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
if not action_payload or action_payload == PAYLOAD_NONE:
_LOGGER.debug("Ignoring empty action from '%s'", msg.topic)
return
try:
self._attr_action = HumidifierAction(str(action_payload))
except ValueError:
_LOGGER.error(
"'%s' received on topic %s. '%s' is not a valid action",
msg.payload,
msg.topic,
action_payload,
)
return
self.add_subscription(topics, CONF_ACTION_TOPIC, action_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_current_humidity"})
def current_humidity_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the current humidity."""
rendered_current_humidity_payload = self._value_templates[
ATTR_CURRENT_HUMIDITY
](msg.payload)
if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_current_humidity = None
return
if not rendered_current_humidity_payload:
_LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic)
return
try:
current_humidity = round(float(rendered_current_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
if current_humidity < 0 or current_humidity > 100:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
self._attr_current_humidity = current_humidity
self.add_subscription(
topics, CONF_CURRENT_HUMIDITY_TOPIC, current_humidity_received
topics, CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}
)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_target_humidity"})
def target_humidity_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the target humidity."""
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
msg.payload
)
if not rendered_target_humidity_payload:
_LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic)
return
if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_target_humidity = None
return
try:
target_humidity = round(float(rendered_target_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
if (
target_humidity < self._attr_min_humidity
or target_humidity > self._attr_max_humidity
):
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
self._attr_target_humidity = target_humidity
self.add_subscription(
topics, CONF_TARGET_HUMIDITY_STATE_TOPIC, target_humidity_received
topics, CONF_ACTION_TOPIC, self._action_received, {"_attr_action"}
)
self.add_subscription(
topics,
CONF_CURRENT_HUMIDITY_TOPIC,
self._current_humidity_received,
{"_attr_current_humidity"},
)
self.add_subscription(
topics,
CONF_TARGET_HUMIDITY_STATE_TOPIC,
self._target_humidity_received,
{"_attr_target_humidity"},
)
self.add_subscription(
topics, CONF_MODE_STATE_TOPIC, self._mode_received, {"_attr_mode"}
)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_mode"})
def mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for mode."""
mode = str(self._value_templates[ATTR_MODE](msg.payload))
if mode == self._payload["MODE_RESET"]:
self._attr_mode = None
return
if not mode:
_LOGGER.debug("Ignoring empty mode from '%s'", msg.topic)
return
if not self.available_modes or mode not in self.available_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid mode",
msg.payload,
msg.topic,
mode,
)
return
self._attr_mode = mode
self.add_subscription(topics, CONF_MODE_STATE_TOPIC, mode_received)
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
@@ -449,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on the entity.

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from base64 import b64decode
import binascii
from collections.abc import Callable
from functools import partial
import logging
from typing import TYPE_CHECKING, Any
@@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
from . import subscription
from .config import MQTT_BASE_SCHEMA
from .const import CONF_ENCODING, CONF_QOS
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
DATA_MQTT,
@@ -143,6 +143,45 @@ class MqttImage(MqttEntity, ImageEntity):
config.get(CONF_URL_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
def _image_data_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err:
_LOGGER.error(
"Error processing image data received at topic %s: %s",
msg.topic,
err,
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
@callback
def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
url = cv.url(self._url_template(msg.payload))
self._attr_image_url = url
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
except vol.Invalid:
_LOGGER.error(
"Invalid image URL '%s' received at topic %s",
msg.payload,
msg.topic,
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@@ -159,56 +198,15 @@ class MqttImage(MqttEntity, ImageEntity):
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"msg_callback": partial(self._message_callback, msg_callback, None),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": encoding,
}
return has_topic
@callback
@log_messages(self.hass, self.entity_id)
def image_data_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err:
_LOGGER.error(
"Error processing image data received at topic %s: %s",
msg.topic,
err,
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@callback
@log_messages(self.hass, self.entity_id)
def image_from_url_request_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
url = cv.url(self._url_template(msg.payload))
self._attr_image_url = url
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
except vol.Invalid:
_LOGGER.error(
"Invalid image URL '%s' received at topic %s",
msg.payload,
msg.topic,
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
@@ -216,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_image(self) -> bytes | None:
"""Return bytes of image."""

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable
import contextlib
from functools import partial
import logging
import voluptuous as vol
@@ -31,12 +32,7 @@ from .const import (
DEFAULT_OPTIMISTIC,
DEFAULT_RETAIN,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -150,57 +146,59 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self
).async_render
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload:
_LOGGER.debug(
"Invalid empty activity payload from topic %s, for entity %s",
msg.topic,
self.entity_id,
)
return
if payload.lower() == "none":
self._attr_activity = None
return
try:
self._attr_activity = LawnMowerActivity(payload)
except ValueError:
_LOGGER.error(
"Invalid activity for %s: '%s' (valid activities: %s)",
self.entity_id,
payload,
[option.value for option in LawnMowerActivity],
)
return
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_activity"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload:
_LOGGER.debug(
"Invalid empty activity payload from topic %s, for entity %s",
msg.topic,
self.entity_id,
)
return
if payload.lower() == "none":
self._attr_activity = None
return
try:
self._attr_activity = LawnMowerActivity(payload)
except ValueError:
_LOGGER.error(
"Invalid activity for %s: '%s' (valid activities: %s)",
self.entity_id,
payload,
[option.value for option in LawnMowerActivity],
)
return
if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None:
# Force into optimistic mode.
self._attr_assumed_state = True
else:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
CONF_ACTIVITY_STATE_TOPIC: {
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
"msg_callback": message_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
CONF_ACTIVITY_STATE_TOPIC: {
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
"msg_callback": partial(
self._message_callback,
self._message_received,
{"_attr_activity"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and (
last_state := await self.async_get_last_state()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import Any, cast
@@ -53,8 +54,7 @@ from ..const import (
CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE,
)
from ..debug_info import log_messages
from ..mixins import MqttEntity, write_state_on_attr_change
from ..mixins import MqttEntity
from ..models import (
MessageCallbackType,
MqttCommandTemplate,
@@ -378,263 +378,248 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
attr: bool = getattr(self, f"_optimistic_{attribute}")
return attr
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.NONE
)
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
if payload == self._payload["on"]:
self._attr_is_on = True
elif payload == self._payload["off"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
@callback
def _brightness_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for the brightness."""
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic)
return
device_value = float(payload)
if device_value == 0:
_LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic)
return
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
self._attr_brightness = min(round(percent_bright * 255), 255)
@callback
def _rgbx_received(
self,
msg: ReceiveMessage,
template: str,
color_mode: ColorMode,
convert_color: Callable[..., tuple[int, ...]],
) -> tuple[int, ...] | None:
"""Process MQTT messages for RGBW and RGBWW."""
payload = self._value_templates[template](msg.payload, PayloadSentinel.DEFAULT)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty %s message from '%s'", color_mode, msg.topic)
return None
color = tuple(int(val) for val in str(payload).split(","))
if self._optimistic_color_mode:
self._attr_color_mode = color_mode
if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None:
rgb = convert_color(*color)
brightness = max(rgb)
if brightness == 0:
_LOGGER.debug(
"Ignoring %s message with zero rgb brightness from '%s'",
color_mode,
msg.topic,
)
return None
self._attr_brightness = brightness
# Normalize the color to 100% brightness
color = tuple(
min(round(channel / brightness * 255), 255) for channel in color
)
return color
@callback
def _rgb_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGB."""
rgb = self._rgbx_received(
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
)
if rgb is None:
return
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
@callback
def _rgbw_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBW."""
rgbw = self._rgbx_received(
msg,
CONF_RGBW_VALUE_TEMPLATE,
ColorMode.RGBW,
color_util.color_rgbw_to_rgb,
)
if rgbw is None:
return
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
@callback
def _rgbww_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBWW."""
@callback
def _converter(
r: int, g: int, b: int, cw: int, ww: int
) -> tuple[int, int, int]:
min_kelvin = color_util.color_temperature_mired_to_kelvin(self.max_mireds)
max_kelvin = color_util.color_temperature_mired_to_kelvin(self.min_mireds)
return color_util.color_rgbww_to_rgb(
r, g, b, cw, ww, min_kelvin, max_kelvin
)
rgbww = self._rgbx_received(
msg,
CONF_RGBWW_VALUE_TEMPLATE,
ColorMode.RGBWW,
_converter,
)
if rgbww is None:
return
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
@callback
def _color_mode_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color mode."""
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic)
return
self._attr_color_mode = ColorMode(str(payload))
@callback
def _color_temp_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color temperature."""
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic)
return
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.COLOR_TEMP
self._attr_color_temp = int(payload)
@callback
def _effect_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for effect."""
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic)
return
self._attr_effect = str(payload)
@callback
def _hs_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for hs color."""
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
return
try:
hs_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.HS
self._attr_hs_color = cast(tuple[float, float], hs_color)
except ValueError:
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
@callback
def _xy_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for xy color."""
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic)
return
xy_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.XY
self._attr_xy_color = cast(tuple[float, float], xy_color)
def _prepare_subscribe_topics(self) -> None: # noqa: C901
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
def add_topic(topic: str, msg_callback: MessageCallbackType) -> None:
def add_topic(
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
) -> None:
"""Add a topic."""
if self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.NONE
)
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
if payload == self._payload["on"]:
self._attr_is_on = True
elif payload == self._payload["off"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
if self._topic[CONF_STATE_TOPIC] is not None:
topics[CONF_STATE_TOPIC] = {
"topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_brightness"})
def brightness_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for the brightness."""
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic)
return
device_value = float(payload)
if device_value == 0:
_LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic)
return
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
self._attr_brightness = min(round(percent_bright * 255), 255)
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
@callback
def _rgbx_received(
msg: ReceiveMessage,
template: str,
color_mode: ColorMode,
convert_color: Callable[..., tuple[int, ...]],
) -> tuple[int, ...] | None:
"""Handle new MQTT messages for RGBW and RGBWW."""
payload = self._value_templates[template](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug(
"Ignoring empty %s message from '%s'", color_mode, msg.topic
)
return None
color = tuple(int(val) for val in str(payload).split(","))
if self._optimistic_color_mode:
self._attr_color_mode = color_mode
if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None:
rgb = convert_color(*color)
brightness = max(rgb)
if brightness == 0:
_LOGGER.debug(
"Ignoring %s message with zero rgb brightness from '%s'",
color_mode,
msg.topic,
)
return None
self._attr_brightness = brightness
# Normalize the color to 100% brightness
color = tuple(
min(round(channel / brightness * 255), 255) for channel in color
)
return color
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}
add_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
add_topic(
CONF_BRIGHTNESS_STATE_TOPIC, self._brightness_received, {"_attr_brightness"}
)
def rgb_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGB."""
rgb = _rgbx_received(
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
)
if rgb is None:
return
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}
add_topic(
CONF_RGB_STATE_TOPIC,
self._rgb_received,
{"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"},
)
def rgbw_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBW."""
rgbw = _rgbx_received(
msg,
CONF_RGBW_VALUE_TEMPLATE,
ColorMode.RGBW,
color_util.color_rgbw_to_rgb,
)
if rgbw is None:
return
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"}
add_topic(
CONF_RGBW_STATE_TOPIC,
self._rgbw_received,
{"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"},
)
add_topic(
CONF_RGBWW_STATE_TOPIC,
self._rgbww_received,
{"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"},
)
add_topic(
CONF_COLOR_MODE_STATE_TOPIC, self._color_mode_received, {"_attr_color_mode"}
)
add_topic(
CONF_COLOR_TEMP_STATE_TOPIC,
self._color_temp_received,
{"_attr_color_mode", "_attr_color_temp"},
)
add_topic(CONF_EFFECT_STATE_TOPIC, self._effect_received, {"_attr_effect"})
add_topic(
CONF_HS_STATE_TOPIC,
self._hs_received,
{"_attr_color_mode", "_attr_hs_color"},
)
add_topic(
CONF_XY_STATE_TOPIC,
self._xy_received,
{"_attr_color_mode", "_attr_xy_color"},
)
def rgbww_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBWW."""
@callback
def _converter(
r: int, g: int, b: int, cw: int, ww: int
) -> tuple[int, int, int]:
min_kelvin = color_util.color_temperature_mired_to_kelvin(
self.max_mireds
)
max_kelvin = color_util.color_temperature_mired_to_kelvin(
self.min_mireds
)
return color_util.color_rgbww_to_rgb(
r, g, b, cw, ww, min_kelvin, max_kelvin
)
rgbww = _rgbx_received(
msg,
CONF_RGBWW_VALUE_TEMPLATE,
ColorMode.RGBWW,
_converter,
)
if rgbww is None:
return
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode"})
def color_mode_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color mode."""
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic)
return
self._attr_color_mode = ColorMode(str(payload))
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"})
def color_temp_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color temperature."""
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic)
return
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.COLOR_TEMP
self._attr_color_temp = int(payload)
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_effect"})
def effect_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for effect."""
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic)
return
self._attr_effect = str(payload)
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"})
def hs_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for hs color."""
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
return
try:
hs_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.HS
self._attr_hs_color = cast(tuple[float, float], hs_color)
except ValueError:
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
add_topic(CONF_HS_STATE_TOPIC, hs_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"})
def xy_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for xy color."""
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic)
return
xy_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.XY
self._attr_xy_color = cast(tuple[float, float], xy_color)
add_topic(CONF_XY_STATE_TOPIC, xy_received)
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
@@ -642,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
def restore_state(

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress
from functools import partial
import logging
from typing import TYPE_CHECKING, Any, cast
@@ -66,8 +67,7 @@ from ..const import (
CONF_STATE_TOPIC,
DOMAIN as MQTT_DOMAIN,
)
from ..debug_info import log_messages
from ..mixins import MqttEntity, write_state_on_attr_change
from ..mixins import MqttEntity
from ..models import ReceiveMessage
from ..schemas import MQTT_ENTITY_COMMON_SCHEMA
from ..util import valid_subscribe_topic
@@ -414,118 +414,121 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
self.entity_id,
)
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
values = json_loads_object(msg.payload)
if values["state"] == "ON":
self._attr_is_on = True
elif values["state"] == "OFF":
self._attr_is_on = False
elif values["state"] is None:
self._attr_is_on = None
if (
self._deprecated_color_handling
and color_supported(self.supported_color_modes)
and "color" in values
):
# Deprecated color handling
if values["color"] is None:
self._attr_hs_color = None
else:
self._update_color(values)
if not self._deprecated_color_handling and "color_mode" in values:
self._update_color(values)
if brightness_supported(self.supported_color_modes):
try:
if brightness := values["brightness"]:
if TYPE_CHECKING:
assert isinstance(brightness, float)
self._attr_brightness = color_util.value_to_brightness(
(1, self._config[CONF_BRIGHTNESS_SCALE]), brightness
)
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except KeyError:
pass
except (TypeError, ValueError):
_LOGGER.warning(
"Invalid brightness value '%s' received for entity %s",
values["brightness"],
self.entity_id,
)
if (
self._deprecated_color_handling
and self.supported_color_modes
and ColorMode.COLOR_TEMP in self.supported_color_modes
):
# Deprecated color handling
try:
if values["color_temp"] is None:
self._attr_color_temp = None
else:
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
except KeyError:
pass
except ValueError:
_LOGGER.warning(
"Invalid color temp value '%s' received for entity %s",
values["color_temp"],
self.entity_id,
)
# Allow to switch back to color_temp
if "color" not in values:
self._attr_hs_color = None
if self.supported_features and LightEntityFeature.EFFECT:
with suppress(KeyError):
self._attr_effect = cast(str, values["effect"])
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
#
if self._topic[CONF_STATE_TOPIC] is None:
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"_attr_brightness",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
"_attr_rgb_color",
"_attr_rgbw_color",
"_attr_rgbww_color",
"_attr_xy_color",
"color_mode",
CONF_STATE_TOPIC: {
"topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": partial(
self._message_callback,
self._state_received,
{
"_attr_brightness",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
"_attr_rgb_color",
"_attr_rgbw_color",
"_attr_rgbww_color",
"_attr_xy_color",
"color_mode",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
values = json_loads_object(msg.payload)
if values["state"] == "ON":
self._attr_is_on = True
elif values["state"] == "OFF":
self._attr_is_on = False
elif values["state"] is None:
self._attr_is_on = None
if (
self._deprecated_color_handling
and color_supported(self.supported_color_modes)
and "color" in values
):
# Deprecated color handling
if values["color"] is None:
self._attr_hs_color = None
else:
self._update_color(values)
if not self._deprecated_color_handling and "color_mode" in values:
self._update_color(values)
if brightness_supported(self.supported_color_modes):
try:
if brightness := values["brightness"]:
if TYPE_CHECKING:
assert isinstance(brightness, float)
self._attr_brightness = color_util.value_to_brightness(
(1, self._config[CONF_BRIGHTNESS_SCALE]), brightness
)
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except KeyError:
pass
except (TypeError, ValueError):
_LOGGER.warning(
"Invalid brightness value '%s' received for entity %s",
values["brightness"],
self.entity_id,
)
if (
self._deprecated_color_handling
and self.supported_color_modes
and ColorMode.COLOR_TEMP in self.supported_color_modes
):
# Deprecated color handling
try:
if values["color_temp"] is None:
self._attr_color_temp = None
else:
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
except KeyError:
pass
except ValueError:
_LOGGER.warning(
"Invalid color temp value '%s' received for entity %s",
values["color_temp"],
self.entity_id,
)
# Allow to switch back to color_temp
if "color" not in values:
self._attr_hs_color = None
if self.supported_features and LightEntityFeature.EFFECT:
with suppress(KeyError):
self._attr_effect = cast(str, values["effect"])
if self._topic[CONF_STATE_TOPIC] is not None:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import Any
@@ -44,8 +45,7 @@ from ..const import (
CONF_STATE_TOPIC,
PAYLOAD_NONE,
)
from ..debug_info import log_messages
from ..mixins import MqttEntity, write_state_on_attr_change
from ..mixins import MqttEntity
from ..models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -188,107 +188,107 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
# Support for ct + hs, prioritize hs
self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
if state == STATE_ON:
self._attr_is_on = True
elif state == STATE_OFF:
self._attr_is_on = False
elif state == PAYLOAD_NONE:
self._attr_is_on = None
else:
_LOGGER.warning("Invalid state value received")
if CONF_BRIGHTNESS_TEMPLATE in self._config:
try:
if brightness := int(
self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload)
):
self._attr_brightness = brightness
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except ValueError:
_LOGGER.warning("Invalid brightness value received from %s", msg.topic)
if CONF_COLOR_TEMP_TEMPLATE in self._config:
try:
color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE](
msg.payload
)
self._attr_color_temp = (
int(color_temp) if color_temp != "None" else None
)
except ValueError:
_LOGGER.warning("Invalid color temperature value received")
if (
CONF_RED_TEMPLATE in self._config
and CONF_GREEN_TEMPLATE in self._config
and CONF_BLUE_TEMPLATE in self._config
):
try:
red = self._value_templates[CONF_RED_TEMPLATE](msg.payload)
green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload)
blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload)
if red == "None" and green == "None" and blue == "None":
self._attr_hs_color = None
else:
self._attr_hs_color = color_util.color_RGB_to_hs(
int(red), int(green), int(blue)
)
self._update_color_mode()
except ValueError:
_LOGGER.warning("Invalid color value received")
if CONF_EFFECT_TEMPLATE in self._config:
effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload))
if (
effect_list := self._config[CONF_EFFECT_LIST]
) and effect in effect_list:
self._attr_effect = effect
else:
_LOGGER.warning("Unsupported effect value received")
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
if self._topics[CONF_STATE_TOPIC] is None:
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"_attr_brightness",
"_attr_color_mode",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
"state_topic": {
"topic": self._topics[CONF_STATE_TOPIC],
"msg_callback": partial(
self._message_callback,
self._state_received,
{
"_attr_brightness",
"_attr_color_mode",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
if state == STATE_ON:
self._attr_is_on = True
elif state == STATE_OFF:
self._attr_is_on = False
elif state == PAYLOAD_NONE:
self._attr_is_on = None
else:
_LOGGER.warning("Invalid state value received")
if CONF_BRIGHTNESS_TEMPLATE in self._config:
try:
if brightness := int(
self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload)
):
self._attr_brightness = brightness
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except ValueError:
_LOGGER.warning(
"Invalid brightness value received from %s", msg.topic
)
if CONF_COLOR_TEMP_TEMPLATE in self._config:
try:
color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE](
msg.payload
)
self._attr_color_temp = (
int(color_temp) if color_temp != "None" else None
)
except ValueError:
_LOGGER.warning("Invalid color temperature value received")
if (
CONF_RED_TEMPLATE in self._config
and CONF_GREEN_TEMPLATE in self._config
and CONF_BLUE_TEMPLATE in self._config
):
try:
red = self._value_templates[CONF_RED_TEMPLATE](msg.payload)
green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload)
blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload)
if red == "None" and green == "None" and blue == "None":
self._attr_hs_color = None
else:
self._attr_hs_color = color_util.color_RGB_to_hs(
int(red), int(green), int(blue)
)
self._update_color_mode()
except ValueError:
_LOGGER.warning("Invalid color value received")
if CONF_EFFECT_TEMPLATE in self._config:
effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload))
if (
effect_list := self._config[CONF_EFFECT_LIST]
) and effect in effect_list:
self._attr_effect = effect
else:
_LOGGER.warning("Unsupported effect value received")
if self._topics[CONF_STATE_TOPIC] is not None:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._topics[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state()
if self._optimistic and last_state:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
import re
from typing import Any
@@ -36,12 +37,7 @@ from .const import (
CONF_STATE_OPENING,
CONF_STATE_TOPIC,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -186,57 +182,58 @@ class MqttLock(MqttEntity, LockEntity):
self._valid_states = [config[state] for state in STATE_CONFIG_KEYS]
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new lock state messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_RESET]:
# Reset the state to `unknown`
self._attr_is_locked = None
elif payload in self._valid_states:
self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED]
self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING]
self._attr_is_open = payload == self._config[CONF_STATE_OPEN]
self._attr_is_opening = payload == self._config[CONF_STATE_OPENING]
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
topics: dict[str, dict[str, Any]]
qos: int = self._config[CONF_QOS]
encoding: str | None = self._config[CONF_ENCODING] or None
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
{
"_attr_is_jammed",
"_attr_is_locked",
"_attr_is_locking",
"_attr_is_open",
"_attr_is_opening",
"_attr_is_unlocking",
},
)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new lock state messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_RESET]:
# Reset the state to `unknown`
self._attr_is_locked = None
elif payload in self._valid_states:
self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED]
self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING]
self._attr_is_open = payload == self._config[CONF_STATE_OPEN]
self._attr_is_opening = payload == self._config[CONF_STATE_OPENING]
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode.
self._optimistic = True
else:
topics[CONF_STATE_TOPIC] = {
return
topics = {
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received,
"msg_callback": partial(
self._message_callback,
self._message_received,
{
"_attr_is_jammed",
"_attr_is_locked",
"_attr_is_locking",
"_attr_is_open",
"_attr_is_opening",
"_attr_is_unlocking",
},
),
"entity_id": self.entity_id,
CONF_QOS: qos,
CONF_ENCODING: encoding,
}
}
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
@@ -246,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_lock(self, **kwargs: Any) -> None:
"""Lock the device.

View File

@@ -114,7 +114,7 @@ from .models import (
from .subscription import (
EntitySubscription,
async_prepare_subscribe_topics,
async_subscribe_topics,
async_subscribe_topics_internal,
async_unsubscribe_topics,
)
from .util import mqtt_config_entry_enabled
@@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
"""Subscribe MQTT events."""
await super().async_added_to_hass()
self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics()
self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message."""
@@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message."""
await self._attributes_subscribe_topics()
self._attributes_subscribe_topics()
def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
},
)
async def _attributes_subscribe_topics(self) -> None:
@callback
def _attributes_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state)
async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed."""
@@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
"""Subscribe MQTT events."""
await super().async_added_to_hass()
self._availability_prepare_subscribe_topics()
await self._availability_subscribe_topics()
self._availability_subscribe_topics()
self.async_on_remove(
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
)
@@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message."""
await self._availability_subscribe_topics()
self._availability_subscribe_topics()
def _availability_setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup."""
@@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
self._available[topic] = False
self._available_latest = False
async def _availability_subscribe_topics(self) -> None:
@callback
def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state)
async_subscribe_topics_internal(self.hass, self._availability_sub_state)
@callback
def async_mqtt_connect(self) -> None:
@@ -1254,13 +1256,15 @@ class MqttEntity(
def _message_callback(
self,
msg_callback: MessageCallbackType,
attributes: set[str],
attributes: set[str] | None,
msg: ReceiveMessage,
) -> None:
"""Process the message callback."""
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes
)
if attributes is not None:
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
(attribute, getattr(self, attribute, UNDEFINED))
for attribute in attributes
)
mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
msg.subscribed_topic
@@ -1274,7 +1278,7 @@ class MqttEntity(
_LOGGER.warning(exc)
return
if self._attrs_have_changed(attrs_snapshot):
if attributes is not None and self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self)

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from ast import literal_eval
import asyncio
from collections import deque
from collections.abc import Callable, Coroutine
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import StrEnum
import logging
@@ -70,7 +70,6 @@ class ReceiveMessage:
timestamp: float
type AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
type MessageCallbackType = Callable[[ReceiveMessage], None]

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
import voluptuous as vol
@@ -41,12 +42,7 @@ from .const import (
CONF_RETAIN,
CONF_STATE_TOPIC,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -165,64 +161,66 @@ class MqttNumber(MqttEntity, RestoreNumber):
self._attr_native_step = config[CONF_STEP]
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
num_value: int | float | None
payload = str(self._value_template(msg.payload))
if not payload.strip():
_LOGGER.debug("Ignoring empty state update from '%s'", msg.topic)
return
try:
if payload == self._config[CONF_PAYLOAD_RESET]:
num_value = None
elif payload.isnumeric():
num_value = int(payload)
else:
num_value = float(payload)
except ValueError:
_LOGGER.warning("Payload '%s' is not a Number", msg.payload)
return
if num_value is not None and (
num_value < self.min_value or num_value > self.max_value
):
_LOGGER.error(
"Invalid value for %s: %s (range %s - %s)",
self.entity_id,
num_value,
self.min_value,
self.max_value,
)
return
self._attr_native_value = num_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_native_value"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
num_value: int | float | None
payload = str(self._value_template(msg.payload))
if not payload.strip():
_LOGGER.debug("Ignoring empty state update from '%s'", msg.topic)
return
try:
if payload == self._config[CONF_PAYLOAD_RESET]:
num_value = None
elif payload.isnumeric():
num_value = int(payload)
else:
num_value = float(payload)
except ValueError:
_LOGGER.warning("Payload '%s' is not a Number", msg.payload)
return
if num_value is not None and (
num_value < self.min_value or num_value > self.max_value
):
_LOGGER.error(
"Invalid value for %s: %s (range %s - %s)",
self.entity_id,
num_value,
self.min_value,
self.max_value,
)
return
self._attr_native_value = num_value
if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode.
self._attr_assumed_state = True
else:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": partial(
self._message_callback,
self._message_received,
{"_attr_native_value"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and (
last_number_data := await self.async_get_last_number_data()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
import voluptuous as vol
@@ -27,12 +28,7 @@ from .const import (
CONF_RETAIN,
CONF_STATE_TOPIC,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -113,56 +109,58 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload.lower() == "none":
self._attr_current_option = None
return
if payload not in self.options:
_LOGGER.error(
"Invalid option for %s: '%s' (valid options: %s)",
self.entity_id,
payload,
self.options,
)
return
self._attr_current_option = payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_current_option"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload.lower() == "none":
self._attr_current_option = None
return
if payload not in self.options:
_LOGGER.error(
"Invalid option for %s: '%s' (valid options: %s)",
self.entity_id,
payload,
self.options,
)
return
self._attr_current_option = payload
if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode.
self._attr_assumed_state = True
else:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": partial(
self._message_callback,
self._message_received,
{"_attr_current_option"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and (
last_state := await self.async_get_last_state()

View File

@@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback
def _value_is_expired(self, *_: datetime) -> None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import Any, cast
@@ -48,12 +49,7 @@ from .const import (
PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MqttCommandTemplate,
MqttValueTemplate,
@@ -205,92 +201,94 @@ class MqttSiren(MqttEntity, SirenEntity):
entity=self,
).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on", "_extra_attributes"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON:
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON:
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
json_payload: dict[str, Any] = {}
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
json_payload = {STATE: payload}
else:
try:
json_payload = json_loads_object(payload)
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
(
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
json_payload,
msg.topic,
)
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid (JSON) payload detected after processing payload"
" '%s' on topic %s"
),
json_payload,
msg.topic,
)
return
json_payload: dict[str, Any] = {}
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
json_payload = {STATE: payload}
else:
try:
json_payload = json_loads_object(payload)
_LOGGER.debug(
(
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
json_payload,
msg.topic,
)
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid (JSON) payload detected after processing payload"
" '%s' on topic %s"
),
json_payload,
msg.topic,
)
return
if STATE in json_payload:
if json_payload[STATE] == self._state_on:
self._attr_is_on = True
if json_payload[STATE] == self._state_off:
self._attr_is_on = False
if json_payload[STATE] == PAYLOAD_NONE:
self._attr_is_on = None
del json_payload[STATE]
if STATE in json_payload:
if json_payload[STATE] == self._state_on:
self._attr_is_on = True
if json_payload[STATE] == self._state_off:
self._attr_is_on = False
if json_payload[STATE] == PAYLOAD_NONE:
self._attr_is_on = None
del json_payload[STATE]
if json_payload:
# process attributes
try:
params: SirenTurnOnServiceParameters
params = vol.All(TURN_ON_SCHEMA)(json_payload)
except vol.MultipleInvalid as invalid_siren_parameters:
_LOGGER.warning(
"Unable to update siren state attributes from payload '%s': %s",
json_payload,
invalid_siren_parameters,
)
return
# To be able to track changes to self._extra_attributes we assign
# a fresh copy to make the original tracked reference immutable.
self._extra_attributes = dict(self._extra_attributes)
self._update(process_turn_on_params(self, params))
if json_payload:
# process attributes
try:
params: SirenTurnOnServiceParameters
params = vol.All(TURN_ON_SCHEMA)(json_payload)
except vol.MultipleInvalid as invalid_siren_parameters:
_LOGGER.warning(
"Unable to update siren state attributes from payload '%s': %s",
json_payload,
invalid_siren_parameters,
)
return
# To be able to track changes to self._extra_attributes we assign
# a fresh copy to make the original tracked reference immutable.
self._extra_attributes = dict(self._extra_attributes)
self._update(process_turn_on_params(self, params))
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_is_on", "_extra_attributes"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property
def extra_state_attributes(self) -> dict[str, Any] | None:

View File

@@ -2,14 +2,15 @@
from __future__ import annotations
from collections.abc import Callable, Coroutine
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from .. import mqtt
from . import debug_info
from .client import async_subscribe_internal
from .const import DEFAULT_QOS
from .models import MessageCallbackType
@@ -21,7 +22,7 @@ class EntitySubscription:
hass: HomeAssistant
topic: str | None
message_callback: MessageCallbackType
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None
should_subscribe: bool | None
unsubscribe_callback: Callable[[], None] | None
qos: int = 0
encoding: str = "utf-8"
@@ -53,15 +54,16 @@ class EntitySubscription:
self.hass, self.message_callback, self.topic, self.entity_id
)
self.subscribe_task = mqtt.async_subscribe(
hass, self.topic, self.message_callback, self.qos, self.encoding
)
self.should_subscribe = True
async def subscribe(self) -> None:
@callback
def subscribe(self) -> None:
"""Subscribe to a topic."""
if not self.subscribe_task:
if not self.should_subscribe or not self.topic:
return
self.unsubscribe_callback = await self.subscribe_task
self.unsubscribe_callback = async_subscribe_internal(
self.hass, self.topic, self.message_callback, self.qos, self.encoding
)
def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
"""Check if we should re-subscribe to the topic using the old state."""
@@ -79,6 +81,7 @@ class EntitySubscription:
)
@callback
def async_prepare_subscribe_topics(
hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None,
@@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"),
hass=hass,
subscribe_task=None,
should_subscribe=None,
entity_id=value.get("entity_id", None),
)
# Get the current subscription state
@@ -135,12 +138,29 @@ async def async_subscribe_topics(
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics."""
async_subscribe_topics_internal(hass, sub_state)
@callback
def async_subscribe_topics_internal(
hass: HomeAssistant,
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics.
This function is internal to the MQTT integration and should not be called
from outside the integration.
"""
for sub in sub_state.values():
await sub.subscribe()
sub.subscribe()
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return async_prepare_subscribe_topics(hass, sub_state, {})
if TYPE_CHECKING:
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import Any
import voluptuous as vol
@@ -36,12 +37,7 @@ from .const import (
CONF_STATE_TOPIC,
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@@ -118,42 +114,44 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
if payload == self._state_on:
self._attr_is_on = True
elif payload == self._state_off:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
if payload == self._state_on:
self._attr_is_on = True
elif payload == self._state_off:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode.
self._optimistic = True
else:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
return
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_is_on"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()):
self._attr_is_on = last_state.state == STATE_ON

View File

@@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
}
},
)
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_tear_down(self) -> None:
"""Cleanup tag scanner."""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
import re
from typing import Any
@@ -34,12 +35,7 @@ from .const import (
CONF_RETAIN,
CONF_STATE_TOPIC,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import (
MessageCallbackType,
MqttCommandTemplate,
@@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity):
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
self._attr_assumed_state = bool(self._optimistic)
@callback
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = str(self._value_template(msg.payload))
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return
self._attr_native_value = payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None:
if self._config.get(topic) is not None:
topics[topic] = {
"topic": self._config[topic],
"msg_callback": msg_callback,
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_native_value"})
def handle_state_message_received(msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = str(self._value_template(msg.payload))
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return
self._attr_native_value = payload
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
add_subscription(
topics,
CONF_STATE_TOPIC,
self._handle_state_message_received,
{"_attr_native_value"},
)
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics
@@ -193,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_set_value(self, value: str) -> None:
"""Change the text."""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from functools import partial
import logging
from typing import Any, TypedDict, cast
@@ -32,12 +33,7 @@ from .const import (
CONF_STATE_TOPIC,
PAYLOAD_EMPTY_JSON,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic, valid_subscribe_topic
@@ -141,25 +137,104 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
).async_render_with_possible_json_value,
}
@callback
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON:
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
json_payload: _MqttUpdatePayloadType = {}
try:
rendered_json_payload = json_loads(payload)
if isinstance(rendered_json_payload, dict):
_LOGGER.debug(
(
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
rendered_json_payload,
msg.topic,
)
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
else:
_LOGGER.debug(
(
"Non-dictionary JSON payload detected after processing"
" payload '%s' on topic %s"
),
payload,
msg.topic,
)
json_payload = {"installed_version": str(payload)}
except JSON_DECODE_EXCEPTIONS:
_LOGGER.debug(
(
"No valid (JSON) payload detected after processing payload '%s'"
" on topic %s"
),
payload,
msg.topic,
)
json_payload["installed_version"] = str(payload)
if "installed_version" in json_payload:
self._attr_installed_version = json_payload["installed_version"]
if "latest_version" in json_payload:
self._attr_latest_version = json_payload["latest_version"]
if "title" in json_payload:
self._attr_title = json_payload["title"]
if "release_summary" in json_payload:
self._attr_release_summary = json_payload["release_summary"]
if "release_url" in json_payload:
self._attr_release_url = json_payload["release_url"]
if "entity_picture" in json_payload:
self._entity_picture = json_payload["entity_picture"]
@callback
def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT."""
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
if isinstance(latest_version, str) and latest_version != "":
self._attr_latest_version = latest_version
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None:
if self._config.get(topic) is not None:
topics[topic] = {
"topic": self._config[topic],
"msg_callback": msg_callback,
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
add_subscription(
topics,
CONF_STATE_TOPIC,
self._handle_state_message_received,
{
"_attr_installed_version",
"_attr_latest_version",
@@ -169,84 +244,11 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
"_entity_picture",
},
)
def handle_state_message_received(msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON:
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
json_payload: _MqttUpdatePayloadType = {}
try:
rendered_json_payload = json_loads(payload)
if isinstance(rendered_json_payload, dict):
_LOGGER.debug(
(
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
rendered_json_payload,
msg.topic,
)
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
else:
_LOGGER.debug(
(
"Non-dictionary JSON payload detected after processing"
" payload '%s' on topic %s"
),
payload,
msg.topic,
)
json_payload = {"installed_version": str(payload)}
except JSON_DECODE_EXCEPTIONS:
_LOGGER.debug(
(
"No valid (JSON) payload detected after processing payload '%s'"
" on topic %s"
),
payload,
msg.topic,
)
json_payload["installed_version"] = str(payload)
if "installed_version" in json_payload:
self._attr_installed_version = json_payload["installed_version"]
if "latest_version" in json_payload:
self._attr_latest_version = json_payload["latest_version"]
if "title" in json_payload:
self._attr_title = json_payload["title"]
if "release_summary" in json_payload:
self._attr_release_summary = json_payload["release_summary"]
if "release_url" in json_payload:
self._attr_release_url = json_payload["release_url"]
if "entity_picture" in json_payload:
self._entity_picture = json_payload["entity_picture"]
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_latest_version"})
def handle_latest_version_received(msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT."""
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
if isinstance(latest_version, str) and latest_version != "":
self._attr_latest_version = latest_version
add_subscription(
topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received
topics,
CONF_LATEST_VERSION_TOPIC,
self._handle_latest_version_received,
{"_attr_latest_version"},
)
self._sub_state = subscription.async_prepare_subscribe_topics(
@@ -255,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_install(
self, version: str | None, backup: bool, **kwargs: Any

View File

@@ -8,6 +8,7 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
import logging
from typing import Any, cast
@@ -49,12 +50,7 @@ from .const import (
CONF_STATE_TOPIC,
DOMAIN,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic
@@ -322,31 +318,32 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0)
self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0)))
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle state MQTT message."""
payload = json_loads_object(msg.payload)
if STATE in payload and (
(state := payload[STATE]) in POSSIBLE_STATES or state is None
):
self._attr_state = (
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
)
del payload[STATE]
self._update_state_attributes(payload)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_battery_level", "_attr_fan_speed", "_attr_state"}
)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle state MQTT message."""
payload = json_loads_object(msg.payload)
if STATE in payload and (
(state := payload[STATE]) in POSSIBLE_STATES or state is None
):
self._attr_state = (
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
)
del payload[STATE]
self._update_state_attributes(payload)
if state_topic := self._config.get(CONF_STATE_TOPIC):
topics["state_position_topic"] = {
"topic": state_topic,
"msg_callback": state_message_received,
"msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_battery_level", "_attr_fan_speed", "_attr_state"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@@ -356,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
"""Publish a command."""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from contextlib import suppress
from functools import partial
import logging
from typing import Any
@@ -61,12 +62,7 @@ from .const import (
DEFAULT_RETAIN,
PAYLOAD_NONE,
)
from .debug_info import log_messages
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic, valid_subscribe_topic
@@ -302,65 +298,63 @@ class MqttValve(MqttEntity, ValveEntity):
return
self._update_state(state)
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
payload_dict: Any = None
position_payload: Any = payload
state_payload: Any = payload
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
with suppress(*JSON_DECODE_EXCEPTIONS):
payload_dict = json_loads(payload)
if isinstance(payload_dict, dict):
if self.reports_position and "position" not in payload_dict:
_LOGGER.warning(
"Missing required `position` attribute in json payload "
"on topic '%s', got: %s",
msg.topic,
payload,
)
return
if not self.reports_position and "state" not in payload_dict:
_LOGGER.warning(
"Missing required `state` attribute in json payload "
" on topic '%s', got: %s",
msg.topic,
payload,
)
return
position_payload = payload_dict.get("position")
state_payload = payload_dict.get("state")
if self._config[CONF_REPORTS_POSITION]:
self._process_position_valve_update(msg, position_payload, state_payload)
else:
self._process_binary_valve_update(msg, state_payload)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics = {}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
{
"_attr_current_valve_position",
"_attr_is_closed",
"_attr_is_closing",
"_attr_is_opening",
},
)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
payload_dict: Any = None
position_payload: Any = payload
state_payload: Any = payload
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
with suppress(*JSON_DECODE_EXCEPTIONS):
payload_dict = json_loads(payload)
if isinstance(payload_dict, dict):
if self.reports_position and "position" not in payload_dict:
_LOGGER.warning(
"Missing required `position` attribute in json payload "
"on topic '%s', got: %s",
msg.topic,
payload,
)
return
if not self.reports_position and "state" not in payload_dict:
_LOGGER.warning(
"Missing required `state` attribute in json payload "
" on topic '%s', got: %s",
msg.topic,
payload,
)
return
position_payload = payload_dict.get("position")
state_payload = payload_dict.get("state")
if self._config[CONF_REPORTS_POSITION]:
self._process_position_valve_update(
msg, position_payload, state_payload
)
else:
self._process_binary_valve_update(msg, state_payload)
if self._config.get(CONF_STATE_TOPIC):
topics["state_topic"] = {
"topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received,
"msg_callback": partial(
self._message_callback,
self._state_message_received,
{
"_attr_current_valve_position",
"_attr_is_closed",
"_attr_is_closing",
"_attr_is_opening",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@@ -371,7 +365,7 @@ class MqttValve(MqttEntity, ValveEntity):
async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state)
subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_valve(self) -> None:
"""Move the valve up.

View File

@@ -9,6 +9,7 @@ from typing import Literal
import ollama
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
@@ -138,6 +139,11 @@ class OllamaConversationEntity(
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": message_history.messages},
)
# Get response
try:
response = await client.chat(

View File

@@ -31,14 +31,15 @@ from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
)
_LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
}
)
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT,
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry(
title="ChatGPT",
data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
options=RECOMMENDED_OPTIONS,
)
return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
self.last_rendered_recommended = config_entry.options.get(
CONF_RECOMMENDED, False
)
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
if user_input is not None:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
schema = openai_config_option_schema(self.hass, self.config_entry.options)
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
schema = openai_config_option_schema(self.hass, options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):
def openai_config_option_schema(
hass: HomeAssistant,
options: MappingProxyType[str, Any],
options: dict[str, Any] | MappingProxyType[str, Any],
) -> dict:
"""Return a schema for OpenAI completion options."""
apis: list[SelectOptionDict] = [
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
hass_apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
@@ -144,38 +167,46 @@ def openai_config_option_schema(
for api in llm.async_get_apis(hass)
)
return {
schema = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional(
CONF_CHAT_MODEL,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
if options.get(CONF_RECOMMENDED):
return schema
schema.update(
{
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=RECOMMENDED_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=RECOMMENDED_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=RECOMMENDED_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
}
)
return schema

View File

@@ -4,13 +4,15 @@ import logging
DOMAIN = "openai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_RECOMMENDED = "recommended"
CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-4o"
RECOMMENDED_CHAT_MODEL = "gpt-4o"
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
RECOMMENDED_MAX_TOKENS = 150
CONF_TOP_P = "top_p"
DEFAULT_TOP_P = 1.0
RECOMMENDED_TOP_P = 1.0
CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 1.0
RECOMMENDED_TEMPERATURE = 1.0

View File

@@ -8,6 +8,7 @@ import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
@@ -22,13 +23,13 @@ from .const import (
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
)
# Max number of back and forth with the LLM to generate a response
@@ -97,15 +98,14 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API):
if options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
@@ -117,26 +117,12 @@ class OpenAIConversationEntity(
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid_now()
try:
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api:
empty_tool_input = llm.ToolInput(
tool_name="",
@@ -149,11 +135,24 @@ class OpenAIConversationEntity(
device_id=user_input.device_id,
)
prompt = (
await llm_api.async_get_api_prompt(empty_tool_input)
+ "\n"
+ prompt
api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
else:
api_prompt = llm.PROMPT_NO_API_CONFIGURED
prompt = "\n".join(
(
template.Template(
options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
)
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
@@ -170,7 +169,10 @@ class OpenAIConversationEntity(
messages.append({"role": "user", "content": user_input.text})
LOGGER.debug("Prompt for %s: %s", model, messages)
LOGGER.debug("Prompt: %s", messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
client = self.hass.data[DOMAIN][self.entry.entry_id]
@@ -178,12 +180,12 @@ class OpenAIConversationEntity(
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=model,
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages,
tools=tools,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=conversation_id,
)
except openai.OpenAIError as err:

View File

@@ -22,7 +22,8 @@
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."

View File

@@ -30,7 +30,6 @@ from .util import (
PLATFORMS = [Platform.MEDIA_PLAYER]
__all__ = [
"async_browse_media",
"DOMAIN",
@@ -50,7 +49,10 @@ class HomeAssistantSpotifyData:
session: OAuth2Session
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
type SpotifyConfigEntry = ConfigEntry[HomeAssistantSpotifyData]
async def async_setup_entry(hass: HomeAssistant, entry: SpotifyConfigEntry) -> bool:
"""Set up Spotify from a config entry."""
implementation = await async_get_config_entry_implementation(hass, entry)
session = OAuth2Session(hass, entry, implementation)
@@ -100,8 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
await device_coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})
hass.data[DOMAIN][entry.entry_id] = HomeAssistantSpotifyData(
entry.runtime_data = HomeAssistantSpotifyData(
client=spotify,
current_user=current_user,
devices=device_coordinator,
@@ -117,6 +118,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Spotify config entry."""
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
del hass.data[DOMAIN][entry.entry_id]
return unload_ok
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from enum import StrEnum
from functools import partial
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from spotipy import Spotify
import yarl
@@ -22,6 +22,9 @@ from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES
from .util import fetch_image_url
if TYPE_CHECKING:
from . import HomeAssistantSpotifyData
BROWSE_LIMIT = 48
@@ -140,21 +143,21 @@ async def async_browse_media(
# Check if caller is requesting the root nodes
if media_content_type is None and media_content_id is None:
children = []
for config_entry_id in hass.data[DOMAIN]:
config_entry = hass.config_entries.async_get_entry(config_entry_id)
assert config_entry is not None
children.append(
BrowseMedia(
title=config_entry.title,
media_class=MediaClass.APP,
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry_id}",
media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
can_play=False,
can_expand=True,
)
config_entries = hass.config_entries.async_entries(
DOMAIN, include_disabled=False, include_ignore=False
)
children = [
BrowseMedia(
title=config_entry.title,
media_class=MediaClass.APP,
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry.entry_id}",
media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
can_play=False,
can_expand=True,
)
for config_entry in config_entries
]
return BrowseMedia(
title="Spotify",
media_class=MediaClass.APP,
@@ -171,9 +174,15 @@ async def async_browse_media(
# Check for config entry specifier, and extract Spotify URI
parsed_url = yarl.URL(media_content_id)
if (info := hass.data[DOMAIN].get(parsed_url.host)) is None:
if (
parsed_url.host is None
or (entry := hass.config_entries.async_get_entry(parsed_url.host)) is None
or not isinstance(entry.runtime_data, HomeAssistantSpotifyData)
):
raise BrowseError("Invalid Spotify account specified")
media_content_id = parsed_url.name
info = entry.runtime_data
result = await async_browse_media_internal(
hass,

View File

@@ -22,7 +22,6 @@ from homeassistant.components.media_player import (
MediaType,
RepeatMode,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@@ -30,7 +29,7 @@ from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.dt import utcnow
from . import HomeAssistantSpotifyData
from . import HomeAssistantSpotifyData, SpotifyConfigEntry
from .browse_media import async_browse_media_internal
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES
from .util import fetch_image_url
@@ -70,12 +69,12 @@ SPOTIFY_DJ_PLAYLIST = {"uri": "spotify:playlist:37i9dQZF1EYkqdzj48dyYq", "name":
async def async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
entry: SpotifyConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Spotify based on a config entry."""
spotify = SpotifyMediaPlayer(
hass.data[DOMAIN][entry.entry_id],
entry.runtime_data,
entry.data[CONF_ID],
entry.title,
)

View File

@@ -4,15 +4,14 @@ from __future__ import annotations
import logging
from aioswitcher.bridge import SwitcherBridge
from aioswitcher.device import SwitcherBase
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform
from homeassistant.core import Event, HomeAssistant, callback
from .const import DATA_DEVICE, DOMAIN
from .coordinator import SwitcherDataUpdateCoordinator
from .utils import async_start_bridge, async_stop_bridge
PLATFORMS = [
Platform.BUTTON,
@@ -25,20 +24,20 @@ PLATFORMS = [
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
type SwitcherConfigEntry = ConfigEntry[dict[str, SwitcherDataUpdateCoordinator]]
async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
"""Set up Switcher from a config entry."""
hass.data.setdefault(DOMAIN, {})
hass.data[DOMAIN][DATA_DEVICE] = {}
@callback
def on_device_data_callback(device: SwitcherBase) -> None:
"""Use as a callback for device data."""
coordinators = entry.runtime_data
# Existing device update device data
if device.device_id in hass.data[DOMAIN][DATA_DEVICE]:
coordinator: SwitcherDataUpdateCoordinator = hass.data[DOMAIN][DATA_DEVICE][
device.device_id
]
if coordinator := coordinators.get(device.device_id):
coordinator.async_set_updated_data(device)
return
@@ -52,18 +51,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
device.device_type.hex_rep,
)
coordinator = hass.data[DOMAIN][DATA_DEVICE][device.device_id] = (
SwitcherDataUpdateCoordinator(hass, entry, device)
)
coordinator = SwitcherDataUpdateCoordinator(hass, entry, device)
coordinator.async_setup()
coordinators[device.device_id] = coordinator
# Must be ready before dispatcher is called
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
await async_start_bridge(hass, on_device_data_callback)
entry.runtime_data = {}
bridge = SwitcherBridge(on_device_data_callback)
await bridge.start()
async def stop_bridge(event: Event) -> None:
await async_stop_bridge(hass)
async def stop_bridge(event: Event | None = None) -> None:
await bridge.stop()
entry.async_on_unload(stop_bridge)
entry.async_on_unload(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge)
@@ -72,12 +74,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
"""Unload a config entry."""
await async_stop_bridge(hass)
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(DATA_DEVICE)
return unload_ok
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@@ -15,7 +15,6 @@ from aioswitcher.api.remotes import SwitcherBreezeRemote
from aioswitcher.device import DeviceCategory
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -25,6 +24,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import SwitcherConfigEntry
from .const import SIGNAL_DEVICE_ADD
from .coordinator import SwitcherDataUpdateCoordinator
from .utils import get_breeze_remote_manager
@@ -78,7 +78,7 @@ THERMOSTAT_BUTTONS = [
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
config_entry: SwitcherConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Switcher button from config entry."""

View File

@@ -25,7 +25,6 @@ from homeassistant.components.climate import (
ClimateEntityFeature,
HVACMode,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@@ -35,6 +34,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import SwitcherConfigEntry
from .const import SIGNAL_DEVICE_ADD
from .coordinator import SwitcherDataUpdateCoordinator
from .utils import get_breeze_remote_manager
@@ -61,7 +61,7 @@ HA_TO_DEVICE_FAN = {value: key for key, value in DEVICE_FAN_TO_HA.items()}
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
config_entry: SwitcherConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Switcher climate from config entry."""

View File

@@ -2,9 +2,6 @@
DOMAIN = "switcher_kis"
DATA_BRIDGE = "bridge"
DATA_DEVICE = "device"
DISCOVERY_TIME_SEC = 12
SIGNAL_DEVICE_ADD = "switcher_device_add"

View File

@@ -6,24 +6,23 @@ from dataclasses import asdict
from typing import Any
from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from .const import DATA_DEVICE, DOMAIN
from . import SwitcherConfigEntry
TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"}
async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry
hass: HomeAssistant, entry: SwitcherConfigEntry
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
devices = hass.data[DOMAIN][DATA_DEVICE]
coordinators = entry.runtime_data
return async_redact_data(
{
"entry": entry.as_dict(),
"devices": [asdict(devices[d].data) for d in devices],
"devices": [asdict(coordinators[d].data) for d in coordinators],
},
TO_REDACT,
)

View File

@@ -3,9 +3,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import logging
from typing import Any
from aioswitcher.api.remotes import SwitcherBreezeRemoteManager
from aioswitcher.bridge import SwitcherBase, SwitcherBridge
@@ -13,29 +11,11 @@ from aioswitcher.bridge import SwitcherBase, SwitcherBridge
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import singleton
from .const import DATA_BRIDGE, DISCOVERY_TIME_SEC, DOMAIN
from .const import DISCOVERY_TIME_SEC
_LOGGER = logging.getLogger(__name__)
async def async_start_bridge(
hass: HomeAssistant, on_device_callback: Callable[[SwitcherBase], Any]
) -> None:
"""Start switcher UDP bridge."""
bridge = hass.data[DOMAIN][DATA_BRIDGE] = SwitcherBridge(on_device_callback)
_LOGGER.debug("Starting Switcher bridge")
await bridge.start()
async def async_stop_bridge(hass: HomeAssistant) -> None:
"""Stop switcher UDP bridge."""
bridge: SwitcherBridge = hass.data[DOMAIN].get(DATA_BRIDGE)
if bridge is not None:
_LOGGER.debug("Stopping Switcher bridge")
await bridge.stop()
hass.data[DOMAIN].pop(DATA_BRIDGE)
async def async_has_devices(hass: HomeAssistant) -> bool:
"""Discover Switcher devices."""
_LOGGER.debug("Starting discovery")

View File

@@ -30,6 +30,7 @@ PLATFORMS: Final = [
Platform.BINARY_SENSOR,
Platform.CLIMATE,
Platform.COVER,
Platform.DEVICE_TRACKER,
Platform.LOCK,
Platform.SELECT,
Platform.SENSOR,

View File

@@ -0,0 +1,85 @@
"""Device tracker platform for Teslemetry integration."""
from __future__ import annotations
from homeassistant.components.device_tracker import SourceType
from homeassistant.components.device_tracker.config_entry import TrackerEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .entity import TeslemetryVehicleEntity
from .models import TeslemetryVehicleData
async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
) -> None:
"""Set up the Teslemetry device tracker platform from a config entry."""
async_add_entities(
klass(vehicle)
for klass in (
TeslemetryDeviceTrackerLocationEntity,
TeslemetryDeviceTrackerRouteEntity,
)
for vehicle in entry.runtime_data.vehicles
)
class TeslemetryDeviceTrackerEntity(TeslemetryVehicleEntity, TrackerEntity):
"""Base class for Teslemetry tracker entities."""
lat_key: str
lon_key: str
def __init__(
self,
vehicle: TeslemetryVehicleData,
) -> None:
"""Initialize the device tracker."""
super().__init__(vehicle, self.key)
def _async_update_attrs(self) -> None:
"""Update the attributes of the device tracker."""
self._attr_available = (
self.get(self.lat_key, False) is not None
and self.get(self.lon_key, False) is not None
)
@property
def latitude(self) -> float | None:
"""Return latitude value of the device."""
return self.get(self.lat_key)
@property
def longitude(self) -> float | None:
"""Return longitude value of the device."""
return self.get(self.lon_key)
@property
def source_type(self) -> SourceType:
"""Return the source type of the device tracker."""
return SourceType.GPS
class TeslemetryDeviceTrackerLocationEntity(TeslemetryDeviceTrackerEntity):
"""Vehicle location device tracker class."""
key = "location"
lat_key = "drive_state_latitude"
lon_key = "drive_state_longitude"
class TeslemetryDeviceTrackerRouteEntity(TeslemetryDeviceTrackerEntity):
"""Vehicle navigation device tracker class."""
key = "route"
lat_key = "drive_state_active_route_latitude"
lon_key = "drive_state_active_route_longitude"
@property
def location_name(self) -> str | None:
"""Return a location name for the current location of the device."""
return self.get("drive_state_active_route_destination")

View File

@@ -109,6 +109,7 @@
"off": "mdi:car-seat"
}
},
"components_customer_preferred_export_rule": {
"default": "mdi:transmission-tower",
"state": {
@@ -126,6 +127,14 @@
}
}
},
"device_tracker": {
"location": {
"default": "mdi:map-marker"
},
"route": {
"default": "mdi:routes"
}
},
"cover": {
"charge_state_charge_port_door_open": {
"default": "mdi:ev-plug-ccs2"

View File

@@ -111,6 +111,14 @@
}
}
},
"device_tracker": {
"location": {
"name": "Location"
},
"route": {
"name": "Route"
}
},
"lock": {
"charge_state_charge_port_latch": {
"name": "Charge cable lock"

View File

@@ -13,7 +13,7 @@
"velbus-packet",
"velbus-protocol"
],
"requirements": ["velbus-aio==2024.4.1"],
"requirements": ["velbus-aio==2024.5.1"],
"usb": [
{
"vid": "10CF",

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Awaitable
import logging
from typing import Any
@@ -157,16 +156,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Unload Withings config entry."""
"""Unload vera config entry."""
controller_data: ControllerData = get_controller_data(hass, config_entry)
tasks: list[Awaitable] = [
hass.config_entries.async_forward_entry_unload(config_entry, platform)
for platform in get_configured_platforms(controller_data)
]
tasks.append(hass.async_add_executor_job(controller_data.controller.stop))
await asyncio.gather(*tasks)
await asyncio.gather(
*(
hass.config_entries.async_unload_platforms(
config_entry, get_configured_platforms(controller_data)
),
hass.async_add_executor_job(controller_data.controller.stop),
)
)
return True

View File

@@ -1,6 +1,5 @@
"""Support for Zigbee Home Automation devices."""
import asyncio
import contextlib
import copy
import logging
@@ -238,12 +237,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
websocket_api.async_unload_api(hass)
# our components don't have unload methods so no need to look at return values
await asyncio.gather(
*(
hass.config_entries.async_forward_entry_unload(config_entry, platform)
for platform in PLATFORMS
)
)
await hass.config_entries.async_unload_platforms(config_entry, PLATFORMS)
return True

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Coroutine
from contextlib import suppress
import logging
from typing import Any
@@ -958,14 +957,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
tasks: list[Coroutine] = [
hass.config_entries.async_forward_entry_unload(entry, platform)
platforms = [
platform
for platform, task in driver_events.platform_setup_tasks.items()
if not task.cancel()
]
unload_ok = all(await asyncio.gather(*tasks)) if tasks else True
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
if client.connected and client.driver:
await async_disable_server_logging_if_needed(hass, entry, client.driver)

View File

@@ -3,12 +3,16 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any
import voluptuous as vol
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@@ -116,6 +120,10 @@ class API(ABC):
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
)
for tool in self.async_get_tools():
if tool.name == tool_input.tool_name:
break
@@ -191,7 +199,10 @@ class AssistAPI(API):
async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API."""
prompt = "Call the intent tools to control Home Assistant. Just pass the name to the intent."
prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent."
)
if tool_input.device_id:
device_reg = device_registry.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id)

View File

@@ -1821,7 +1821,7 @@ pyegps==0.2.5
pyenphase==1.20.3
# homeassistant.components.envisalink
pyenvisalink==4.6
pyenvisalink==4.7
# homeassistant.components.ephember
pyephember==0.3.1
@@ -2817,7 +2817,7 @@ vallox-websocket-api==5.1.1
vehicle==2.2.1
# homeassistant.components.velbus
velbus-aio==2024.4.1
velbus-aio==2024.5.1
# homeassistant.components.venstar
venstarcolortouch==0.19

View File

@@ -2185,7 +2185,7 @@ vallox-websocket-api==5.1.1
vehicle==2.2.1
# homeassistant.components.velbus
velbus-aio==2024.4.1
velbus-aio==2024.5.1
# homeassistant.components.venstar
venstarcolortouch==0.19

View File

@@ -117,7 +117,6 @@ NO_IOT_CLASS = [
# https://github.com/home-assistant/developers.home-assistant/pull/1512
NO_DIAGNOSTICS = [
"dlna_dms",
"fronius",
"gdacs",
"geonetnz_quakes",
"google_assistant_sdk",

View File

@@ -2,7 +2,9 @@
from unittest.mock import patch
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant, State
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
@@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None:
) as mock_process,
patch("homeassistant.util.dt.utcnow", return_value=now),
):
intent_response = intent.IntentResponse(language="en")
intent_response.async_set_speech("response text")
mock_process.return_value = conversation.ConversationResult(
response=intent_response,
)
await hass.services.async_call(
"conversation",
"process",

View File

@@ -0,0 +1,80 @@
"""Test for the conversation traces."""
from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@pytest.fixture
async def init_components(hass: HomeAssistant):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
async def test_converation_trace(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
)
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert trace_event.get("data")
assert trace_event["data"].get("text") == "add apples to my shopping list"
assert last_trace.get("result")
assert (
last_trace["result"]
.get("response", {})
.get("speech", {})
.get("plain", {})
.get("speech")
== "Added apples"
)
async def test_converation_trace_error(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
side_effect=HomeAssistantError("Failed to talk to agent"),
),
pytest.raises(HomeAssistantError),
):
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
)
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert last_trace.get("error") == "Failed to talk to agent"

View File

@@ -25,6 +25,7 @@ async def setup_fronius_integration(
"""Create the Fronius integration."""
entry = MockConfigEntry(
domain=DOMAIN,
entry_id="f1e2b9837e8adaed6fa682acaa216fd8",
unique_id=unique_id, # has to match mocked logger unique_id
data={
CONF_HOST: MOCK_HOST,

View File

@@ -0,0 +1,370 @@
# serializer version: 1
# name: test_diagnostics
dict({
'config_entry': dict({
'data': dict({
'host': 'http://fronius',
'is_logger': True,
}),
'disabled_by': None,
'domain': 'fronius',
'entry_id': 'f1e2b9837e8adaed6fa682acaa216fd8',
'minor_version': 1,
'options': dict({
}),
'pref_disable_new_entities': False,
'pref_disable_polling': False,
'source': 'user',
'title': 'Mock Title',
'unique_id': '**REDACTED**',
'version': 1,
}),
'coordinators': dict({
'inverters': dict({
'1': dict({
'current_ac': dict({
'unit': 'A',
'value': 5.19,
}),
'current_dc': dict({
'unit': 'A',
'value': 2.19,
}),
'energy_day': dict({
'unit': 'Wh',
'value': 1113,
}),
'energy_total': dict({
'unit': 'Wh',
'value': 44188000,
}),
'energy_year': dict({
'unit': 'Wh',
'value': 25508798,
}),
'error_code': dict({
'value': 0,
}),
'frequency_ac': dict({
'unit': 'Hz',
'value': 49.94,
}),
'led_color': dict({
'value': 2,
}),
'led_state': dict({
'value': 0,
}),
'power_ac': dict({
'unit': 'W',
'value': 1190,
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'status_code': dict({
'value': 7,
}),
'timestamp': dict({
'value': '2021-10-07T10:01:17+02:00',
}),
'voltage_ac': dict({
'unit': 'V',
'value': 227.9,
}),
'voltage_dc': dict({
'unit': 'V',
'value': 518,
}),
}),
}),
'logger': dict({
'system': dict({
'cash_factor': dict({
'unit': 'EUR/kWh',
'value': 0.07800000160932541,
}),
'co2_factor': dict({
'unit': 'kg/kWh',
'value': 0.5299999713897705,
}),
'delivery_factor': dict({
'unit': 'EUR/kWh',
'value': 0.15000000596046448,
}),
'hardware_platform': dict({
'value': 'wilma',
}),
'hardware_version': dict({
'value': '2.4E',
}),
'product_type': dict({
'value': 'fronius-datamanager-card',
}),
'software_version': dict({
'value': '3.18.7-1',
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'time_zone': dict({
'value': 'CEST',
}),
'time_zone_location': dict({
'value': 'Vienna',
}),
'timestamp': dict({
'value': '2021-10-06T23:56:32+02:00',
}),
'unique_identifier': '**REDACTED**',
'utc_offset': dict({
'value': 7200,
}),
}),
}),
'meter': dict({
'0': dict({
'current_ac_phase_1': dict({
'unit': 'A',
'value': 7.755,
}),
'current_ac_phase_2': dict({
'unit': 'A',
'value': 6.68,
}),
'current_ac_phase_3': dict({
'unit': 'A',
'value': 10.102,
}),
'enable': dict({
'value': 1,
}),
'energy_reactive_ac_consumed': dict({
'unit': 'VArh',
'value': 59960790,
}),
'energy_reactive_ac_produced': dict({
'unit': 'VArh',
'value': 723160,
}),
'energy_real_ac_minus': dict({
'unit': 'Wh',
'value': 35623065,
}),
'energy_real_ac_plus': dict({
'unit': 'Wh',
'value': 15303334,
}),
'energy_real_consumed': dict({
'unit': 'Wh',
'value': 15303334,
}),
'energy_real_produced': dict({
'unit': 'Wh',
'value': 35623065,
}),
'frequency_phase_average': dict({
'unit': 'Hz',
'value': 50,
}),
'manufacturer': dict({
'value': 'Fronius',
}),
'meter_location': dict({
'value': 0,
}),
'model': dict({
'value': 'Smart Meter 63A',
}),
'power_apparent': dict({
'unit': 'VA',
'value': 5592.57,
}),
'power_apparent_phase_1': dict({
'unit': 'VA',
'value': 1772.793,
}),
'power_apparent_phase_2': dict({
'unit': 'VA',
'value': 1527.048,
}),
'power_apparent_phase_3': dict({
'unit': 'VA',
'value': 2333.562,
}),
'power_factor': dict({
'value': 1,
}),
'power_factor_phase_1': dict({
'value': -0.99,
}),
'power_factor_phase_2': dict({
'value': -0.99,
}),
'power_factor_phase_3': dict({
'value': 0.99,
}),
'power_reactive': dict({
'unit': 'VAr',
'value': 2.87,
}),
'power_reactive_phase_1': dict({
'unit': 'VAr',
'value': 51.48,
}),
'power_reactive_phase_2': dict({
'unit': 'VAr',
'value': 115.63,
}),
'power_reactive_phase_3': dict({
'unit': 'VAr',
'value': -164.24,
}),
'power_real': dict({
'unit': 'W',
'value': 5592.57,
}),
'power_real_phase_1': dict({
'unit': 'W',
'value': 1765.55,
}),
'power_real_phase_2': dict({
'unit': 'W',
'value': 1515.8,
}),
'power_real_phase_3': dict({
'unit': 'W',
'value': 2311.22,
}),
'serial': '**REDACTED**',
'visible': dict({
'value': 1,
}),
'voltage_ac_phase_1': dict({
'unit': 'V',
'value': 228.6,
}),
'voltage_ac_phase_2': dict({
'unit': 'V',
'value': 228.6,
}),
'voltage_ac_phase_3': dict({
'unit': 'V',
'value': 231,
}),
'voltage_ac_phase_to_phase_12': dict({
'unit': 'V',
'value': 395.9,
}),
'voltage_ac_phase_to_phase_23': dict({
'unit': 'V',
'value': 398,
}),
'voltage_ac_phase_to_phase_31': dict({
'unit': 'V',
'value': 398,
}),
}),
}),
'ohmpilot': None,
'power_flow': dict({
'power_flow': dict({
'energy_day': dict({
'unit': 'Wh',
'value': 1101.7000732421875,
}),
'energy_total': dict({
'unit': 'Wh',
'value': 44188000,
}),
'energy_year': dict({
'unit': 'Wh',
'value': 25508788,
}),
'meter_location': dict({
'value': 'grid',
}),
'meter_mode': dict({
'value': 'meter',
}),
'power_battery': dict({
'unit': 'W',
'value': None,
}),
'power_grid': dict({
'unit': 'W',
'value': 1703.74,
}),
'power_load': dict({
'unit': 'W',
'value': -2814.74,
}),
'power_photovoltaics': dict({
'unit': 'W',
'value': 1111,
}),
'relative_autonomy': dict({
'unit': '%',
'value': 39.4707859340472,
}),
'relative_self_consumption': dict({
'unit': '%',
'value': 100,
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'timestamp': dict({
'value': '2021-10-07T10:00:43+02:00',
}),
}),
}),
'storage': None,
}),
'inverter_info': dict({
'inverters': list([
dict({
'custom_name': dict({
'value': 'Symo 20',
}),
'device_id': dict({
'value': '1',
}),
'device_type': dict({
'manufacturer': 'Fronius',
'model': 'Symo 20.0-3-M',
'value': 121,
}),
'error_code': dict({
'value': 0,
}),
'pv_power': dict({
'unit': 'W',
'value': 23100,
}),
'show': dict({
'value': 1,
}),
'status_code': dict({
'value': 7,
}),
'unique_id': '**REDACTED**',
}),
]),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'timestamp': dict({
'value': '2021-10-07T13:41:00+02:00',
}),
}),
})
# ---

View File

@@ -0,0 +1,31 @@
"""Tests for the diagnostics data provided by the KNX integration."""
from syrupy import SnapshotAssertion
from homeassistant.core import HomeAssistant
from . import mock_responses, setup_fronius_integration
from tests.components.diagnostics import get_diagnostics_for_config_entry
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
async def test_diagnostics(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
snapshot: SnapshotAssertion,
) -> None:
"""Test diagnostics."""
mock_responses(aioclient_mock)
entry = await setup_fronius_integration(hass)
assert (
await get_diagnostics_for_config_entry(
hass,
hass_client,
entry,
)
== snapshot
)

View File

@@ -1,4 +1,114 @@
# serializer version: 1
# name: test_chat_history
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'1st user request',
),
dict({
}),
),
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
dict({
'parts': '1st user request',
'role': 'user',
}),
dict({
'parts': '1st model response',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'2nd user request',
),
dict({
}),
),
])
# ---
# name: test_default_prompt[config_entry_options0-None]
list([
tuple(
@@ -14,10 +124,10 @@
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
@@ -29,7 +139,10 @@
dict({
'history': list([
dict({
'parts': 'Answer in plain text. Keep it simple and to the point.',
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
@@ -64,10 +177,10 @@
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
@@ -79,7 +192,10 @@
dict({
'history': list([
dict({
'parts': 'Answer in plain text. Keep it simple and to the point.',
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
@@ -114,10 +230,10 @@
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
@@ -130,8 +246,8 @@
'history': list([
dict({
'parts': '''
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
''',
'role': 'user',
}),
@@ -167,10 +283,10 @@
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE',
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
@@ -183,8 +299,8 @@
'history': list([
dict({
'parts': '''
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
''',
'role': 'user',
}),

View File

@@ -2,12 +2,14 @@
from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError
from google.api_core.exceptions import GoogleAPICallError
import google.generativeai.types as genai_types
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -150,6 +152,57 @@ async def test_default_prompt(
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
async def test_chat_history(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the agent keeps track of the chat history."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call = None
chat_response.parts = [mock_part]
chat_response.text = "1st model response"
mock_chat.history = [
{"role": "user", "parts": "prompt"},
{"role": "model", "parts": "Ok"},
{"role": "user", "parts": "1st user request"},
{"role": "model", "parts": "1st model response"},
]
result = await conversation.async_converse(
hass,
"1st user request",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.as_dict()["speech"]["plain"]["speech"]
== "1st model response"
)
chat_response.text = "2nd model response"
result = await conversation.async_converse(
hass,
"2nd user request",
result.conversation_id,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.as_dict()["speech"]["plain"]["speech"]
== "2nd model response"
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
)
@@ -233,6 +286,20 @@ async def test_function_call(
),
)
# Test conversating tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
@patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
@@ -325,7 +392,7 @@ async def test_error_handling(
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = ClientError("some error")
mock_chat.send_message_async.side_effect = GoogleAPICallError("some error")
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
@@ -340,7 +407,28 @@ async def test_error_handling(
async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test response was blocked."""
"""Test blocked response."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
"finish_reason: SAFETY\n"
)
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"The message got blocked by your safety settings"
)
async def test_empty_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test empty response."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
@@ -358,6 +446,32 @@ async def test_blocked_response(
)
async def test_invalid_llm_api(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test handling of invalid llm api."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
)
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Error preparing LLM API: API invalid_llm_api not found"
)
async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:

View File

@@ -529,16 +529,16 @@ async def test_non_unique_triggers(
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].data["some"] == "press1"
assert calls[1].data["some"] == "press2"
all_calls = {calls[0].data["some"], calls[1].data["some"]}
assert all_calls == {"press1", "press2"}
# Trigger second config references to same trigger
# and triggers both attached instances.
async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press")
await hass.async_block_till_done()
assert len(calls) == 2
assert calls[0].data["some"] == "press1"
assert calls[1].data["some"] == "press2"
all_calls = {calls[0].data["some"], calls[1].data["some"]}
assert all_calls == {"press1", "press2"}
# Removing the first trigger will clean up
calls.clear()

View File

@@ -4,6 +4,7 @@ import asyncio
from collections.abc import Generator
from copy import deepcopy
from datetime import datetime, timedelta
from functools import partial
import json
import logging
import socket
@@ -1050,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
await mqtt.async_subscribe(hass, "test-topic", record_calls)
async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
) -> None:
"""Test the subscription of a topic when MQTT config entry is disabled."""
mqtt_mock.connected = True
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
assert mqtt_config_entry.state is ConfigEntryState.LOADED
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
await hass.config_entries.async_set_disabled_by(
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
)
mqtt_mock.connected = False
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
async def test_subscribe_and_resubscribe(
@@ -2912,8 +2934,8 @@ async def test_message_callback_exception_gets_logged(
await mqtt_mock_entry()
@callback
def bad_handler(*args) -> None:
"""Record calls."""
def bad_handler(msg: ReceiveMessage) -> None:
"""Handle callback."""
raise ValueError("This is a bad message callback")
await mqtt.async_subscribe(hass, "test-topic", bad_handler)
@@ -2926,6 +2948,40 @@ async def test_message_callback_exception_gets_logged(
)
@pytest.mark.no_fail_on_log_exception
async def test_message_partial_callback_exception_gets_logged(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test exception raised by message handler."""
await mqtt_mock_entry()
@callback
def bad_handler(msg: ReceiveMessage) -> None:
"""Handle callback."""
raise ValueError("This is a bad message callback")
def parial_handler(
msg_callback: MessageCallbackType,
attributes: set[str],
msg: ReceiveMessage,
) -> None:
"""Partial callback handler."""
msg_callback(msg)
await mqtt.async_subscribe(
hass, "test-topic", partial(parial_handler, bad_handler, {"some_attr"})
)
async_fire_mqtt_message(hass, "test-topic", "test")
await hass.async_block_till_done()
assert (
"Exception in bad_handler when handling msg on 'test-topic':"
" 'test'" in caplog.text
)
async def test_mqtt_ws_subscription(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
@@ -3787,7 +3843,7 @@ async def test_unload_config_entry(
async def test_publish_or_subscribe_without_valid_config_entry(
hass: HomeAssistant, record_calls: MessageCallbackType
) -> None:
"""Test internal publish function with bas use cases."""
"""Test internal publish function with bad use cases."""
with pytest.raises(HomeAssistantError):
await mqtt.async_publish(
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None

View File

@@ -11,8 +11,12 @@ from .const import DEFAULT_FORECAST, DEFAULT_OBSERVATION
@pytest.fixture
def mock_simple_nws():
"""Mock pynws SimpleNWS with default values."""
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws:
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
with (
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_nws.return_value
instance.set_station = AsyncMock(return_value=None)
instance.update_observation = AsyncMock(return_value=None)
@@ -29,7 +33,12 @@ def mock_simple_nws():
@pytest.fixture
def mock_simple_nws_times_out():
"""Mock pynws SimpleNWS that times out."""
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws:
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
with (
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_nws.return_value
instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError)
instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError)

View File

@@ -1,7 +1,6 @@
"""Tests for the NWS weather component."""
from datetime import timedelta
from unittest.mock import patch
import aiohttp
from freezegun.api import FrozenDateTimeFactory
@@ -24,7 +23,6 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from .const import (
@@ -127,47 +125,43 @@ async def test_data_caching_error_observation(
caplog,
) -> None:
"""Test caching of data with errors."""
with (
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_simple_nws.return_value
instance = mock_simple_nws.return_value
entry = MockConfigEntry(
domain=nws.DOMAIN,
data=NWS_CONFIG,
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
entry = MockConfigEntry(
domain=nws.DOMAIN,
data=NWS_CONFIG,
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
state = hass.states.get("weather.abc")
assert state.state == "sunny"
state = hass.states.get("weather.abc")
assert state.state == "sunny"
# data is still valid even when update fails
instance.update_observation.side_effect = NwsNoDataError("Test")
# data is still valid even when update fails
instance.update_observation.side_effect = NwsNoDataError("Test")
freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100))
async_fire_time_changed(hass)
await hass.async_block_till_done()
freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100))
async_fire_time_changed(hass)
await hass.async_block_till_done()
state = hass.states.get("weather.abc")
assert state.state == "sunny"
state = hass.states.get("weather.abc")
assert state.state == "sunny"
assert (
"NWS observation update failed, but data still valid. Last success: "
in caplog.text
)
assert (
"NWS observation update failed, but data still valid. Last success: "
in caplog.text
)
# data is no longer valid after OBSERVATION_VALID_TIME
freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
# data is no longer valid after OBSERVATION_VALID_TIME
freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
state = hass.states.get("weather.abc")
assert state.state == STATE_UNAVAILABLE
state = hass.states.get("weather.abc")
assert state.state == STATE_UNAVAILABLE
assert "Error fetching NWS observation station ABC data: Test" in caplog.text
assert "Error fetching NWS observation station ABC data: Test" in caplog.text
async def test_no_data_error_observation(
@@ -302,26 +296,23 @@ async def test_error_observation(
hass: HomeAssistant, mock_simple_nws, no_sensor
) -> None:
"""Test error during update observation."""
utc_time = dt_util.utcnow()
with patch("homeassistant.components.nws.coordinator.utcnow") as mock_utc:
mock_utc.return_value = utc_time
instance = mock_simple_nws.return_value
# first update fails
instance.update_observation.side_effect = aiohttp.ClientError
instance = mock_simple_nws.return_value
# first update fails
instance.update_observation.side_effect = aiohttp.ClientError
entry = MockConfigEntry(
domain=nws.DOMAIN,
data=NWS_CONFIG,
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
entry = MockConfigEntry(
domain=nws.DOMAIN,
data=NWS_CONFIG,
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
instance.update_observation.assert_called_once()
instance.update_observation.assert_called_once()
state = hass.states.get("weather.abc")
assert state
assert state.state == STATE_UNAVAILABLE
state = hass.states.get("weather.abc")
assert state
assert state.state == STATE_UNAVAILABLE
async def test_new_config_entry(hass: HomeAssistant, no_sensor) -> None:

View File

@@ -6,6 +6,7 @@ from ollama import Message, ResponseError
import pytest
from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
@@ -110,6 +111,19 @@ async def test_chat(
), result
assert result.response.speech["plain"]["speech"] == "test response"
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_message_history_trimming(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component

View File

@@ -9,9 +9,17 @@ import pytest
from homeassistant import config_entries
from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL,
DEFAULT_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_P,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@@ -75,7 +83,7 @@ async def test_options(
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"]["prompt"] == "Speak like a pirate"
assert options["data"]["max_tokens"] == 200
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
@pytest.mark.parametrize(
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": error}
@pytest.mark.parametrize(
("current_options", "new_options", "expected_options"),
[
(
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "none",
CONF_PROMPT: "bla",
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
},
),
(
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
),
],
)
async def test_options_switching(
hass: HomeAssistant,
mock_config_entry,
mock_init_component,
current_options,
new_options,
expected_options,
) -> None:
"""Test the options form."""
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
**current_options,
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
},
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
new_options,
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options

View File

@@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -200,6 +201,20 @@ async def test_function_call(
),
)
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
@patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"

View File

@@ -2,6 +2,7 @@
from unittest.mock import patch
from freezegun import freeze_time
from pyplaato.models.airlock import PlaatoAirlock
from pyplaato.models.device import PlaatoDeviceType
from pyplaato.models.keg import PlaatoKeg
@@ -23,6 +24,7 @@ AIRLOCK_DATA = {}
KEG_DATA = {}
@freeze_time("2024-05-24 12:00:00", tz_offset=0)
async def init_integration(
hass: HomeAssistant, device_type: PlaatoDeviceType
) -> MockConfigEntry:

View File

@@ -492,7 +492,6 @@ async def test_block_set_mode_auth_error(
{ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT},
blocking=True,
)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED

View File

@@ -227,7 +227,6 @@ async def test_block_set_value_auth_error(
{ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30},
blocking=True,
)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED

View File

@@ -618,7 +618,6 @@ async def test_rpc_sleeping_update_entity_service(
service_data={ATTR_ENTITY_ID: entity_id},
blocking=True,
)
await hass.async_block_till_done()
# Entity should be available after update_entity service call
state = hass.states.get(entity_id)
@@ -667,7 +666,6 @@ async def test_block_sleeping_update_entity_service(
service_data={ATTR_ENTITY_ID: entity_id},
blocking=True,
)
await hass.async_block_till_done()
# Entity should be available after update_entity service call
state = hass.states.get(entity_id)

View File

@@ -230,7 +230,6 @@ async def test_block_set_state_auth_error(
{ATTR_ENTITY_ID: "switch.test_name_channel_1"},
blocking=True,
)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED
@@ -374,7 +373,6 @@ async def test_rpc_auth_error(
{ATTR_ENTITY_ID: "switch.test_switch_0"},
blocking=True,
)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED

View File

@@ -207,7 +207,6 @@ async def test_block_update_auth_error(
{ATTR_ENTITY_ID: "update.test_name_firmware_update"},
blocking=True,
)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED
@@ -669,7 +668,6 @@ async def test_rpc_update_auth_error(
blocking=True,
)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()

View File

@@ -18,9 +18,15 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@pytest.fixture
def mock_bridge(request):
"""Return a mocked SwitcherBridge."""
with patch(
"homeassistant.components.switcher_kis.utils.SwitcherBridge", autospec=True
) as bridge_mock:
with (
patch(
"homeassistant.components.switcher_kis.SwitcherBridge", autospec=True
) as bridge_mock,
patch(
"homeassistant.components.switcher_kis.utils.SwitcherBridge",
new=bridge_mock,
),
):
bridge = bridge_mock.return_value
bridge.devices = []

View File

@@ -4,11 +4,7 @@ from datetime import timedelta
import pytest
from homeassistant.components.switcher_kis.const import (
DATA_DEVICE,
DOMAIN,
MAX_UPDATE_INTERVAL_SEC,
)
from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant
@@ -24,15 +20,14 @@ async def test_update_fail(
hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture
) -> None:
"""Test entities state unavailable when updates fail.."""
await init_integration(hass)
entry = await init_integration(hass)
assert mock_bridge
mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES)
await hass.async_block_till_done()
assert mock_bridge.is_running is True
assert len(hass.data[DOMAIN]) == 2
assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2
assert len(entry.runtime_data) == 2
async_fire_time_changed(
hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1)
@@ -77,11 +72,9 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None:
assert entry.state is ConfigEntryState.LOADED
assert mock_bridge.is_running is True
assert len(hass.data[DOMAIN]) == 2
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.NOT_LOADED
assert mock_bridge.is_running is False
assert len(hass.data[DOMAIN]) == 0

Some files were not shown because too many files have changed in this diff Show More