Merge branch 'dev' into jbouwh-mqtt-device-discovery

This commit is contained in:
J. Nick Koston
2024-05-25 11:39:47 -10:00
committed by GitHub
106 changed files with 3208 additions and 1980 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import logging import logging
from typing import Any from typing import Any
@@ -20,6 +21,11 @@ from .models import (
ConversationInput, ConversationInput,
ConversationResult, ConversationResult,
) )
from .trace import (
ConversationTraceEvent,
ConversationTraceEventType,
async_conversation_trace,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -84,15 +90,23 @@ async def async_converse(
language = hass.config.language language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text) _LOGGER.debug("Processing in %s: %s", language, text)
return await method( conversation_input = ConversationInput(
ConversationInput( text=text,
text=text, context=context,
context=context, conversation_id=conversation_id,
conversation_id=conversation_id, device_id=device_id,
device_id=device_id, language=language,
language=language,
)
) )
with async_conversation_trace() as trace:
trace.add_event(
ConversationTraceEvent(
ConversationTraceEventType.ASYNC_PROCESS,
dataclasses.asdict(conversation_input),
)
)
result = await method(conversation_input)
trace.set_result(**result.as_dict())
return result
class AgentManager: class AgentManager:

View File

@@ -0,0 +1,118 @@
"""Debug traces for conversation."""
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field
import enum
from typing import Any
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict
STORED_TRACES = 3
class ConversationTraceEventType(enum.StrEnum):
"""Type of an event emitted during a conversation."""
ASYNC_PROCESS = "async_process"
"""The conversation is started from user input."""
AGENT_DETAIL = "agent_detail"
"""Event detail added by a conversation agent."""
LLM_TOOL_CALL = "llm_tool_call"
"""An LLM Tool call"""
@dataclass(frozen=True)
class ConversationTraceEvent:
"""Event emitted during a conversation."""
event_type: ConversationTraceEventType
data: dict[str, Any] | None = None
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
class ConversationTrace:
"""Stores debug data related to a conversation."""
def __init__(self) -> None:
"""Initialize ConversationTrace."""
self._trace_id = ulid_util.ulid_now()
self._events: list[ConversationTraceEvent] = []
self._error: Exception | None = None
self._result: dict[str, Any] = {}
@property
def trace_id(self) -> str:
"""Identifier for this trace."""
return self._trace_id
def add_event(self, event: ConversationTraceEvent) -> None:
"""Add an event to the trace."""
self._events.append(event)
def set_error(self, ex: Exception) -> None:
"""Set error."""
self._error = ex
def set_result(self, **kwargs: Any) -> None:
"""Set result."""
self._result = {**kwargs}
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this ConversationTrace."""
result: dict[str, Any] = {
"id": self._trace_id,
"events": [asdict(event) for event in self._events],
}
if self._error is not None:
result["error"] = str(self._error) or self._error.__class__.__name__
if self._result is not None:
result["result"] = self._result
return result
_current_trace: ContextVar[ConversationTrace | None] = ContextVar(
"current_trace", default=None
)
_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict(
size_limit=STORED_TRACES
)
def async_conversation_trace_append(
event_type: ConversationTraceEventType, event_data: dict[str, Any]
) -> None:
"""Append a ConversationTraceEvent to the current active trace."""
trace = _current_trace.get()
if not trace:
return
trace.add_event(ConversationTraceEvent(event_type, event_data))
@contextmanager
def async_conversation_trace() -> Generator[ConversationTrace, None]:
"""Create a new active ConversationTrace."""
trace = ConversationTrace()
token = _current_trace.set(trace)
_recent_traces[trace.trace_id] = trace
try:
yield trace
except Exception as ex:
trace.set_error(ex)
raise
finally:
_current_trace.reset(token)
def async_get_traces() -> list[ConversationTrace]:
"""Get the most recent traces."""
return list(_recent_traces.values())
def async_clear_traces() -> None:
"""Clear all traces."""
_recent_traces.clear()

View File

@@ -5,5 +5,5 @@
"documentation": "https://www.home-assistant.io/integrations/envisalink", "documentation": "https://www.home-assistant.io/integrations/envisalink",
"iot_class": "local_push", "iot_class": "local_push",
"loggers": ["pyenvisalink"], "loggers": ["pyenvisalink"],
"requirements": ["pyenvisalink==4.6"] "requirements": ["pyenvisalink==4.7"]
} }

View File

@@ -1,6 +1,5 @@
"""Support for Arduino-compatible Microcontrollers through Firmata.""" """Support for Arduino-compatible Microcontrollers through Firmata."""
import asyncio
from copy import copy from copy import copy
import logging import logging
@@ -212,16 +211,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Shutdown and close a Firmata board for a config entry.""" """Shutdown and close a Firmata board for a config entry."""
_LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME]) _LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME])
results: list[bool] = []
unload_entries = [] if platforms := [
for conf, platform in CONF_PLATFORM_MAP.items(): platform
if conf in config_entry.data: for conf, platform in CONF_PLATFORM_MAP.items()
unload_entries.append( if conf in config_entry.data
hass.config_entries.async_forward_entry_unload(config_entry, platform) ]:
) results.append(
results = [] await hass.config_entries.async_unload_platforms(config_entry, platforms)
if unload_entries: )
results = await asyncio.gather(*unload_entries)
results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset()) results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset())
return False not in results return False not in results

View File

@@ -11,12 +11,13 @@ from .const import (
CONF_DAMPING_EVENING, CONF_DAMPING_EVENING,
CONF_DAMPING_MORNING, CONF_DAMPING_MORNING,
CONF_MODULES_POWER, CONF_MODULES_POWER,
DOMAIN,
) )
from .coordinator import ForecastSolarDataUpdateCoordinator from .coordinator import ForecastSolarDataUpdateCoordinator
PLATFORMS = [Platform.SENSOR] PLATFORMS = [Platform.SENSOR]
type ForecastSolarConfigEntry = ConfigEntry[ForecastSolarDataUpdateCoordinator]
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Migrate old config entry.""" """Migrate old config entry."""
@@ -36,12 +37,14 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(
hass: HomeAssistant, entry: ForecastSolarConfigEntry
) -> bool:
"""Set up Forecast.Solar from a config entry.""" """Set up Forecast.Solar from a config entry."""
coordinator = ForecastSolarDataUpdateCoordinator(hass, entry) coordinator = ForecastSolarDataUpdateCoordinator(hass, entry)
await coordinator.async_config_entry_first_refresh() await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
@@ -52,11 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None: async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:

View File

@@ -4,15 +4,11 @@ from __future__ import annotations
from typing import Any from typing import Any
from forecast_solar import Estimate
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .const import DOMAIN from . import ForecastSolarConfigEntry
TO_REDACT = { TO_REDACT = {
CONF_API_KEY, CONF_API_KEY,
@@ -22,10 +18,10 @@ TO_REDACT = {
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ForecastSolarConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
coordinator: DataUpdateCoordinator[Estimate] = hass.data[DOMAIN][entry.entry_id] coordinator = entry.runtime_data
return { return {
"entry": { "entry": {

View File

@@ -4,19 +4,21 @@ from __future__ import annotations
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import DOMAIN from .coordinator import ForecastSolarDataUpdateCoordinator
async def async_get_solar_forecast( async def async_get_solar_forecast(
hass: HomeAssistant, config_entry_id: str hass: HomeAssistant, config_entry_id: str
) -> dict[str, dict[str, float | int]] | None: ) -> dict[str, dict[str, float | int]] | None:
"""Get solar forecast for a config entry ID.""" """Get solar forecast for a config entry ID."""
if (coordinator := hass.data[DOMAIN].get(config_entry_id)) is None: if (
entry := hass.config_entries.async_get_entry(config_entry_id)
) is None or not isinstance(entry.runtime_data, ForecastSolarDataUpdateCoordinator):
return None return None
return { return {
"wh_hours": { "wh_hours": {
timestamp.isoformat(): val timestamp.isoformat(): val
for timestamp, val in coordinator.data.wh_period.items() for timestamp, val in entry.runtime_data.data.wh_period.items()
} }
} }

View File

@@ -16,7 +16,6 @@ from homeassistant.components.sensor import (
SensorEntityDescription, SensorEntityDescription,
SensorStateClass, SensorStateClass,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import UnitOfEnergy, UnitOfPower from homeassistant.const import UnitOfEnergy, UnitOfPower
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
@@ -24,6 +23,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import ForecastSolarConfigEntry
from .const import DOMAIN from .const import DOMAIN
from .coordinator import ForecastSolarDataUpdateCoordinator from .coordinator import ForecastSolarDataUpdateCoordinator
@@ -133,10 +133,12 @@ SENSORS: tuple[ForecastSolarSensorEntityDescription, ...] = (
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback hass: HomeAssistant,
entry: ForecastSolarConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Defer sensor setup to the shared sensor module.""" """Defer sensor setup to the shared sensor module."""
coordinator: ForecastSolarDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id] coordinator = entry.runtime_data
async_add_entities( async_add_entities(
ForecastSolarSensorEntity( ForecastSolarSensorEntity(

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
from typing import Final from typing import Any, Final
from homeassistant.components.button import ( from homeassistant.components.button import (
ButtonDeviceClass, ButtonDeviceClass,
@@ -30,7 +30,7 @@ _LOGGER = logging.getLogger(__name__)
class FritzButtonDescription(ButtonEntityDescription): class FritzButtonDescription(ButtonEntityDescription):
"""Class to describe a Button entity.""" """Class to describe a Button entity."""
press_action: Callable press_action: Callable[[AvmWrapper], Any]
BUTTONS: Final = [ BUTTONS: Final = [

View File

@@ -57,9 +57,6 @@ ERROR_UPNP_NOT_CONFIGURED = "upnp_not_configured"
ERROR_UNKNOWN = "unknown_error" ERROR_UNKNOWN = "unknown_error"
FRITZ_SERVICES = "fritz_services" FRITZ_SERVICES = "fritz_services"
SERVICE_REBOOT = "reboot"
SERVICE_RECONNECT = "reconnect"
SERVICE_CLEANUP = "cleanup"
SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password" SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password"
SWITCH_TYPE_DEFLECTION = "CallDeflection" SWITCH_TYPE_DEFLECTION = "CallDeflection"

View File

@@ -46,9 +46,6 @@ from .const import (
DEFAULT_USERNAME, DEFAULT_USERNAME,
DOMAIN, DOMAIN,
FRITZ_EXCEPTIONS, FRITZ_EXCEPTIONS,
SERVICE_CLEANUP,
SERVICE_REBOOT,
SERVICE_RECONNECT,
SERVICE_SET_GUEST_WIFI_PW, SERVICE_SET_GUEST_WIFI_PW,
MeshRoles, MeshRoles,
) )
@@ -730,30 +727,6 @@ class FritzBoxTools(DataUpdateCoordinator[UpdateCoordinatorDataType]):
) )
try: try:
if service_call.service == SERVICE_REBOOT:
_LOGGER.warning(
'Service "fritz.reboot" is deprecated, please use the corresponding'
" button entity instead"
)
await self.async_trigger_reboot()
return
if service_call.service == SERVICE_RECONNECT:
_LOGGER.warning(
'Service "fritz.reconnect" is deprecated, please use the'
" corresponding button entity instead"
)
await self.async_trigger_reconnect()
return
if service_call.service == SERVICE_CLEANUP:
_LOGGER.warning(
'Service "fritz.cleanup" is deprecated, please use the'
" corresponding button entity instead"
)
await self.async_trigger_cleanup(config_entry)
return
if service_call.service == SERVICE_SET_GUEST_WIFI_PW: if service_call.service == SERVICE_SET_GUEST_WIFI_PW:
await self.async_trigger_set_guest_password( await self.async_trigger_set_guest_password(
service_call.data.get("password"), service_call.data.get("password"),

View File

@@ -11,14 +11,7 @@ from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import async_extract_config_entry_ids from homeassistant.helpers.service import async_extract_config_entry_ids
from .const import ( from .const import DOMAIN, FRITZ_SERVICES, SERVICE_SET_GUEST_WIFI_PW
DOMAIN,
FRITZ_SERVICES,
SERVICE_CLEANUP,
SERVICE_REBOOT,
SERVICE_RECONNECT,
SERVICE_SET_GUEST_WIFI_PW,
)
from .coordinator import AvmWrapper from .coordinator import AvmWrapper
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -32,9 +25,6 @@ SERVICE_SCHEMA_SET_GUEST_WIFI_PW = vol.Schema(
) )
SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [ SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [
(SERVICE_CLEANUP, None),
(SERVICE_REBOOT, None),
(SERVICE_RECONNECT, None),
(SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW), (SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW),
] ]

View File

@@ -1,31 +1,3 @@
reconnect:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
reboot:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
cleanup:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
set_guest_wifi_password: set_guest_wifi_password:
fields: fields:
device_id: device_id:

View File

@@ -144,42 +144,12 @@
} }
}, },
"services": { "services": {
"reconnect": {
"name": "[%key:component::fritz::entity::button::reconnect::name%]",
"description": "Reconnects your FRITZ!Box internet connection.",
"fields": {
"device_id": {
"name": "Fritz!Box Device",
"description": "Select the Fritz!Box to reconnect."
}
}
},
"reboot": {
"name": "Reboot",
"description": "Reboots your FRITZ!Box.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"description": "Select the Fritz!Box to reboot."
}
}
},
"cleanup": {
"name": "Remove stale device tracker entities",
"description": "Remove FRITZ!Box stale device_tracker entities.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"description": "Select the Fritz!Box to check."
}
}
},
"set_guest_wifi_password": { "set_guest_wifi_password": {
"name": "Set guest Wi-Fi password", "name": "Set guest Wi-Fi password",
"description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.", "description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.",
"fields": { "fields": {
"device_id": { "device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]", "name": "Fritz!Box Device",
"description": "Select the Fritz!Box to configure." "description": "Select the Fritz!Box to configure."
}, },
"password": { "password": {

View File

@@ -0,0 +1,46 @@
"""Diagnostics support for Fronius."""
from typing import Any
from homeassistant.components.diagnostics import async_redact_data
from homeassistant.core import HomeAssistant
from . import FroniusConfigEntry
TO_REDACT = {"unique_id", "unique_identifier", "serial"}
async def async_get_config_entry_diagnostics(
hass: HomeAssistant, config_entry: FroniusConfigEntry
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
diag: dict[str, Any] = {}
solar_net = config_entry.runtime_data
fronius = solar_net.fronius
diag["config_entry"] = config_entry.as_dict()
diag["inverter_info"] = await fronius.inverter_info()
diag["coordinators"] = {"inverters": {}}
for inv in solar_net.inverter_coordinators:
diag["coordinators"]["inverters"] |= inv.data
diag["coordinators"]["logger"] = (
solar_net.logger_coordinator.data if solar_net.logger_coordinator else None
)
diag["coordinators"]["meter"] = (
solar_net.meter_coordinator.data if solar_net.meter_coordinator else None
)
diag["coordinators"]["ohmpilot"] = (
solar_net.ohmpilot_coordinator.data if solar_net.ohmpilot_coordinator else None
)
diag["coordinators"]["power_flow"] = (
solar_net.power_flow_coordinator.data
if solar_net.power_flow_coordinator
else None
)
diag["coordinators"]["storage"] = (
solar_net.storage_coordinator.data if solar_net.storage_coordinator else None
)
return async_redact_data(diag, TO_REDACT)

View File

@@ -181,8 +181,7 @@ async def google_generative_ai_config_option_schema(
schema = { schema = {
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)}, description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(), ): TemplateSelector(),
vol.Optional( vol.Optional(
CONF_LLM_HASS_API, CONF_LLM_HASS_API,

View File

@@ -22,4 +22,4 @@ CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold" CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold" CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold" CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE" RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"

View File

@@ -5,13 +5,14 @@ from __future__ import annotations
from typing import Any, Literal from typing import Any, Literal
import google.ai.generativelanguage as glm import google.ai.generativelanguage as glm
from google.api_core.exceptions import ClientError from google.api_core.exceptions import GoogleAPICallError
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.types as genai_types import google.generativeai.types as genai_types
import voluptuous as vol import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -205,15 +206,6 @@ class GoogleGenerativeAIConversationEntity(
messages = [{}, {}] messages = [{}, {}]
try: try:
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api: if llm_api:
empty_tool_input = llm.ToolInput( empty_tool_input = llm.ToolInput(
tool_name="", tool_name="",
@@ -226,9 +218,24 @@ class GoogleGenerativeAIConversationEntity(
device_id=user_input.device_id, device_id=user_input.device_id,
) )
prompt = ( api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
else:
api_prompt = llm.PROMPT_NO_API_CONFIGURED
prompt = "\n".join(
(
template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
) )
)
except TemplateError as err: except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err) LOGGER.error("Error rendering prompt: %s", err)
@@ -244,6 +251,9 @@ class GoogleGenerativeAIConversationEntity(
messages[1] = {"role": "model", "parts": "Ok"} messages[1] = {"role": "model", "parts": "Ok"}
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
chat = model.start_chat(history=messages) chat = model.start_chat(history=messages)
chat_request = user_input.text chat_request = user_input.text
@@ -252,15 +262,25 @@ class GoogleGenerativeAIConversationEntity(
try: try:
chat_response = await chat.send_message_async(chat_request) chat_response = await chat.send_message_async(chat_request)
except ( except (
ClientError, GoogleAPICallError,
ValueError, ValueError,
genai_types.BlockedPromptException, genai_types.BlockedPromptException,
genai_types.StopCandidateException, genai_types.StopCandidateException,
) as err: ) as err:
LOGGER.error("Error sending message: %s", err) LOGGER.error("Error sending message: %s %s", type(err), err)
if isinstance(
err, genai_types.StopCandidateException
) and "finish_reason: SAFETY\n" in str(err):
error = "The message got blocked by your safety settings"
else:
error = (
f"Sorry, I had a problem talking to Google Generative AI: {err}"
)
intent_response.async_set_error( intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}", error,
) )
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id

View File

@@ -1,6 +1,6 @@
{ {
"domain": "integration", "domain": "integration",
"name": "Integration - Riemann sum integral", "name": "Integral",
"after_dependencies": ["counter"], "after_dependencies": ["counter"],
"codeowners": ["@dgomes"], "codeowners": ["@dgomes"],
"config_flow": true, "config_flow": true,

View File

@@ -1,5 +1,5 @@
{ {
"title": "Integration - Riemann sum integral sensor", "title": "Integral sensor",
"config": { "config": {
"step": { "step": {
"user": { "user": {

View File

@@ -21,7 +21,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
CONF_VALIDATOR = "validator" CONF_VALIDATOR = "validator"
CONF_SECRET = "secret" CONF_SECRET = "secret"
URL = "/api/meraki" URL = "/api/meraki"
VERSION = "2.0" ACCEPTED_VERSIONS = ["2.0", "2.1"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class MerakiView(HomeAssistantView):
if data["secret"] != self.secret: if data["secret"] != self.secret:
_LOGGER.error("Invalid Secret received from Meraki") _LOGGER.error("Invalid Secret received from Meraki")
return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY) return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY)
if data["version"] != VERSION: if data["version"] not in ACCEPTED_VERSIONS:
_LOGGER.error("Invalid API version: %s", data["version"]) _LOGGER.error("Invalid API version: %s", data["version"])
return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY) return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY)
_LOGGER.debug("Valid Secret") _LOGGER.debug("Valid Secret")

View File

@@ -6,6 +6,6 @@
"documentation": "https://www.home-assistant.io/integrations/minecraft_server", "documentation": "https://www.home-assistant.io/integrations/minecraft_server",
"iot_class": "local_polling", "iot_class": "local_polling",
"loggers": ["dnspython", "mcstatus"], "loggers": ["dnspython", "mcstatus"],
"quality_scale": "gold", "quality_scale": "platinum",
"requirements": ["mcstatus==11.1.1"] "requirements": ["mcstatus==11.1.1"]
} }

View File

@@ -39,6 +39,7 @@ from .client import ( # noqa: F401
MQTT, MQTT,
async_publish, async_publish,
async_subscribe, async_subscribe,
async_subscribe_internal,
publish, publish,
subscribe, subscribe,
) )
@@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
def collect_msg(msg: ReceiveMessage) -> None: def collect_msg(msg: ReceiveMessage) -> None:
messages.append((msg.topic, str(msg.payload).replace("\n", ""))) messages.append((msg.topic, str(msg.payload).replace("\n", "")))
unsub = await async_subscribe(hass, call.data["topic"], collect_msg) unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
def write_dump() -> None: def write_dump() -> None:
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp: with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
@@ -459,7 +460,7 @@ async def websocket_subscribe(
# Perform UTF-8 decoding directly in callback routine # Perform UTF-8 decoding directly in callback routine
qos: int = msg.get("qos", DEFAULT_QOS) qos: int = msg.get("qos", DEFAULT_QOS)
connection.subscriptions[msg["id"]] = await async_subscribe( connection.subscriptions[msg["id"]] = async_subscribe_internal(
hass, msg["topic"], forward_messages, encoding=None, qos=qos hass, msg["topic"], forward_messages, encoding=None, qos=qos
) )
@@ -522,24 +523,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
mqtt_client = mqtt_data.client mqtt_client = mqtt_data.client
# Unload publish and dump services. # Unload publish and dump services.
hass.services.async_remove( hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
DOMAIN, hass.services.async_remove(DOMAIN, SERVICE_DUMP)
SERVICE_PUBLISH,
)
hass.services.async_remove(
DOMAIN,
SERVICE_DUMP,
)
# Stop the discovery # Stop the discovery
await discovery.async_stop(hass) await discovery.async_stop(hass)
# Unload the platforms # Unload the platforms
await asyncio.gather( await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
*(
hass.config_entries.async_forward_entry_unload(entry, component)
for component in mqtt_data.platforms_loaded
)
)
mqtt_data.platforms_loaded = set() mqtt_data.platforms_loaded = set()
await asyncio.sleep(0) await asyncio.sleep(0)
# Unsubscribe reload dispatchers # Unsubscribe reload dispatchers

View File

@@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_alarm_disarm(self, code: str | None = None) -> None: async def async_alarm_disarm(self, code: str | None = None) -> None:
"""Send disarm command. """Send disarm command.

View File

@@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_: Any) -> None: def _value_is_expired(self, *_: Any) -> None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from base64 import b64decode from base64 import b64decode
from functools import partial
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_QOS, CONF_TOPIC from .const import CONF_QOS, CONF_TOPIC
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ReceiveMessage from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@@ -97,27 +97,31 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
@callback
def _image_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config[CONF_TOPIC], "topic": self._config[CONF_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._image_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": None, "encoding": None,
} }
@@ -126,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_camera_image( async def async_camera_image(
self, width: int | None = None, height: int | None = None self, width: int | None = None, height: int | None = None

View File

@@ -77,7 +77,6 @@ from .const import (
) )
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
AsyncMessageCallbackType,
MessageCallbackType, MessageCallbackType,
MqttData, MqttData,
PublishMessage, PublishMessage,
@@ -184,7 +183,7 @@ async def async_publish(
async def async_subscribe( async def async_subscribe(
hass: HomeAssistant, hass: HomeAssistant,
topic: str, topic: str,
msg_callback: AsyncMessageCallbackType | MessageCallbackType, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
@@ -192,13 +191,25 @@ async def async_subscribe(
Call the return value to unsubscribe. Call the return value to unsubscribe.
""" """
if not mqtt_config_entry_enabled(hass): return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe", @callback
translation_domain=DOMAIN, def async_subscribe_internal(
translation_placeholders={"topic": topic}, hass: HomeAssistant,
) topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
This function is internal to the MQTT integration
and may change at any time. It should not be considered
a stable API.
Call the return value to unsubscribe.
"""
try: try:
mqtt_data = hass.data[DATA_MQTT] mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc: except KeyError as exc:
@@ -209,12 +220,15 @@ async def async_subscribe(
translation_domain=DOMAIN, translation_domain=DOMAIN,
translation_placeholders={"topic": topic}, translation_placeholders={"topic": topic},
) from exc ) from exc
return await mqtt_data.client.async_subscribe( client = mqtt_data.client
topic, if not client.connected and not mqtt_config_entry_enabled(hass):
msg_callback, raise HomeAssistantError(
qos, f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
encoding, translation_key="mqtt_not_setup_cannot_subscribe",
) translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return client.async_subscribe(topic, msg_callback, qos, encoding)
@bind_hass @bind_hass
@@ -429,10 +443,10 @@ class MQTT:
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict( self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
list set
) )
self._wildcard_subscriptions: list[Subscription] = [] self._wildcard_subscriptions: set[Subscription] = set()
# _retained_topics prevents a Subscription from receiving a # _retained_topics prevents a Subscription from receiving a
# retained message more than once per topic. This prevents flooding # retained message more than once per topic. This prevents flooding
# already active subscribers when new subscribers subscribe to a topic # already active subscribers when new subscribers subscribe to a topic
@@ -452,7 +466,7 @@ class MQTT:
self._should_reconnect: bool = True self._should_reconnect: bool = True
self._available_future: asyncio.Future[bool] | None = None self._available_future: asyncio.Future[bool] | None = None
self._max_qos: dict[str, int] = {} # topic, max qos self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._unsubscribe_debouncer = EnsureJobAfterCooldown( self._unsubscribe_debouncer = EnsureJobAfterCooldown(
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
@@ -789,9 +803,9 @@ class MQTT:
The caller is responsible clearing the cache of _matching_subscriptions. The caller is responsible clearing the cache of _matching_subscriptions.
""" """
if subscription.is_simple_match: if subscription.is_simple_match:
self._simple_subscriptions[subscription.topic].append(subscription) self._simple_subscriptions[subscription.topic].add(subscription)
else: else:
self._wildcard_subscriptions.append(subscription) self._wildcard_subscriptions.add(subscription)
@callback @callback
def _async_untrack_subscription(self, subscription: Subscription) -> None: def _async_untrack_subscription(self, subscription: Subscription) -> None:
@@ -820,8 +834,8 @@ class MQTT:
"""Queue requested subscriptions.""" """Queue requested subscriptions."""
for subscription in subscriptions: for subscription in subscriptions:
topic, qos = subscription topic, qos = subscription
max_qos = max(qos, self._max_qos.setdefault(topic, qos)) if (max_qos := self._max_qos[topic]) < qos:
self._max_qos[topic] = max_qos self._max_qos[topic] = (max_qos := qos)
self._pending_subscriptions[topic] = max_qos self._pending_subscriptions[topic] = max_qos
# Cancel any pending unsubscribe since we are subscribing now # Cancel any pending unsubscribe since we are subscribing now
if topic in self._pending_unsubscribes: if topic in self._pending_unsubscribes:
@@ -832,26 +846,29 @@ class MQTT:
def _exception_message( def _exception_message(
self, self,
msg_callback: AsyncMessageCallbackType | MessageCallbackType, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
msg: ReceiveMessage, msg: ReceiveMessage,
) -> str: ) -> str:
"""Return a string with the exception message.""" """Return a string with the exception message."""
# if msg_callback is a partial we return the name of the first argument
if isinstance(msg_callback, partial):
call_back_name = getattr(msg_callback.args[0], "__name__") # type: ignore[unreachable]
else:
call_back_name = getattr(msg_callback, "__name__")
return ( return (
f"Exception in {msg_callback.__name__} when handling msg on " f"Exception in {call_back_name} when handling msg on "
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe] f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
) )
async def async_subscribe( @callback
def async_subscribe(
self, self,
topic: str, topic: str,
msg_callback: AsyncMessageCallbackType | MessageCallbackType, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int, qos: int,
encoding: str | None = None, encoding: str | None = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Set up a subscription to a topic with the provided qos. """Set up a subscription to a topic with the provided qos."""
This method is a coroutine.
"""
if not isinstance(topic, str): if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!") raise HomeAssistantError("Topic needs to be a string!")
@@ -877,18 +894,18 @@ class MQTT:
if self.connected: if self.connected:
self._async_queue_subscriptions(((topic, qos),)) self._async_queue_subscriptions(((topic, qos),))
@callback return partial(self._async_remove, subscription)
def async_remove() -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(topic)
return async_remove @callback
def _async_remove(self, subscription: Subscription) -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(subscription.topic)
@callback @callback
def _async_unsubscribe(self, topic: str) -> None: def _async_unsubscribe(self, topic: str) -> None:
@@ -1257,9 +1274,7 @@ class MQTT:
last_discovery = self._mqtt_data.last_discovery last_discovery = self._mqtt_data.last_discovery
last_subscribe = now if self._pending_subscriptions else self._last_subscribe last_subscribe = now if self._pending_subscriptions else self._last_subscribe
wait_until = max( wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
while now < wait_until: while now < wait_until:
await asyncio.sleep(wait_until - now) await asyncio.sleep(wait_until - now)
now = time.monotonic() now = time.monotonic()
@@ -1267,9 +1282,7 @@ class MQTT:
last_subscribe = ( last_subscribe = (
now if self._pending_subscriptions else self._last_subscribe now if self._pending_subscriptions else self._last_subscribe
) )
wait_until = max( wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]: def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:

View File

@@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _publish(self, topic: str, payload: PublishPayloadType) -> None: async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
if self._topic[topic] is not None: if self._topic[topic] is not None:

View File

@@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_cover(self, **kwargs: Any) -> None: async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the cover up. """Move the cover up.

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -32,13 +33,7 @@ from homeassistant.helpers.typing import ConfigType
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC
from .debug_info import log_messages from .mixins import CONF_JSON_ATTRS_TOPIC, MqttEntity, async_setup_entity_entry_helper
from .mixins import (
CONF_JSON_ATTRS_TOPIC,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_subscribe_topic from .util import valid_subscribe_topic
@@ -119,33 +114,31 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
config.get(CONF_VALUE_TEMPLATE), entity=self config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _tracker_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_HOME]:
self._location_name = STATE_HOME
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
self._location_name = STATE_NOT_HOME
elif payload == self._config[CONF_PAYLOAD_RESET]:
self._location_name = None
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, str)
self._location_name = msg.payload
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_location_name"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_HOME]:
self._location_name = STATE_HOME
elif payload == self._config[CONF_PAYLOAD_NOT_HOME]:
self._location_name = STATE_NOT_HOME
elif payload == self._config[CONF_PAYLOAD_RESET]:
self._location_name = None
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, str)
self._location_name = msg.payload
state_topic: str | None = self._config.get(CONF_STATE_TOPIC) state_topic: str | None = self._config.get(CONF_STATE_TOPIC)
if state_topic is None: if state_topic is None:
return return
@@ -155,7 +148,12 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
{ {
"state_topic": { "state_topic": {
"topic": state_topic, "topic": state_topic,
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._tracker_message_received,
{"_location_name"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
} }
}, },
@@ -168,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property @property
def latitude(self) -> float | None: def latitude(self) -> float | None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -31,7 +32,6 @@ from .const import (
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@@ -113,90 +113,91 @@ class MqttEvent(MqttEntity, EventEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _event_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if msg.retain:
_LOGGER.debug(
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
msg.payload,
msg.topic,
)
return
event_attributes: dict[str, Any] = {}
event_type: str
try:
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if (
not payload
or payload is PayloadSentinel.DEFAULT
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
):
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
try:
event_attributes = json_loads_object(payload)
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
_LOGGER.debug(
(
"JSON event data detected after processing payload '%s' on"
" topic %s, type %s, attributes %s"
),
payload,
msg.topic,
event_type,
event_attributes,
)
except KeyError:
_LOGGER.warning(
("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
payload,
msg.topic,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid JSON event payload detected, "
"value after processing payload"
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
try:
self._trigger_event(event_type, event_attributes)
except ValueError:
_LOGGER.warning(
"Invalid event type %s for %s received on topic %s, payload %s",
event_type,
self.entity_id,
msg.topic,
payload,
)
return
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {} topics: dict[str, dict[str, Any]] = {}
@callback
@log_messages(self.hass, self.entity_id)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
if msg.retain:
_LOGGER.debug(
"Ignoring event trigger from replayed retained payload '%s' on topic %s",
msg.payload,
msg.topic,
)
return
event_attributes: dict[str, Any] = {}
event_type: str
try:
payload = self._template(msg.payload, PayloadSentinel.DEFAULT)
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
if (
not payload
or payload is PayloadSentinel.DEFAULT
or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON)
):
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
try:
event_attributes = json_loads_object(payload)
event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE))
_LOGGER.debug(
(
"JSON event data detected after processing payload '%s' on"
" topic %s, type %s, attributes %s"
),
payload,
msg.topic,
event_type,
event_attributes,
)
except KeyError:
_LOGGER.warning(
(
"`event_type` missing in JSON event payload, "
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid JSON event payload detected, "
"value after processing payload"
" '%s' on topic %s"
),
payload,
msg.topic,
)
return
try:
self._trigger_event(event_type, event_attributes)
except ValueError:
_LOGGER.warning(
"Invalid event type %s for %s received on topic %s, payload %s",
event_type,
self.entity_id,
msg.topic,
payload,
)
return
mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self)
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._event_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -207,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import math import math
from typing import Any from typing import Any
@@ -49,12 +50,7 @@ from .const import (
CONF_STATE_VALUE_TEMPLATE, CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@@ -338,137 +334,142 @@ class MqttFan(MqttEntity, FanEntity):
for key, tpl in value_templates.items() for key, tpl in value_templates.items()
} }
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
@callback
def _percentage_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload
)
if not rendered_percentage_payload:
_LOGGER.debug("Ignoring empty speed from '%s'", msg.topic)
return
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
self._attr_percentage = None
return
try:
percentage = ranged_value_to_percentage(
self._speed_range, int(rendered_percentage_payload)
)
except ValueError:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
if percentage < 0 or percentage > 100:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
self._attr_percentage = percentage
@callback
def _preset_mode_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode."""
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None
return
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return
if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload,
msg.topic,
preset_mode,
)
return
self._attr_preset_mode = preset_mode
@callback
def _oscillation_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic)
return
if payload == self._payload["OSCILLATE_ON_PAYLOAD"]:
self._attr_oscillating = True
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
self._attr_oscillating = False
@callback
def _direction_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the direction."""
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
if not direction:
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
return
self._attr_current_direction = str(direction)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, Any] = {} topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool: def add_subscribe_topic(
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
) -> bool:
"""Add a topic to subscribe to.""" """Add a topic to subscribe to."""
if has_topic := self._topic[topic] is not None: if has_topic := self._topic[topic] is not None:
topics[topic] = { topics[topic] = {
"topic": self._topic[topic], "topic": self._topic[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
return has_topic return has_topic
@callback add_subscribe_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
@log_messages(self.hass, self.entity_id) add_subscribe_topic(
@write_state_on_attr_change(self, {"_attr_is_on"}) CONF_PERCENTAGE_STATE_TOPIC, self._percentage_received, {"_attr_percentage"}
def state_received(msg: ReceiveMessage) -> None: )
"""Handle new received MQTT message.""" add_subscribe_topic(
payload = self._value_templates[CONF_STATE](msg.payload) CONF_PRESET_MODE_STATE_TOPIC,
if not payload: self._preset_mode_received,
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic) {"_attr_preset_mode"},
return )
if payload == self._payload["STATE_ON"]: if add_subscribe_topic(
self._attr_is_on = True CONF_OSCILLATION_STATE_TOPIC,
elif payload == self._payload["STATE_OFF"]: self._oscillation_received,
self._attr_is_on = False {"_attr_oscillating"},
elif payload == PAYLOAD_NONE: ):
self._attr_is_on = None
add_subscribe_topic(CONF_STATE_TOPIC, state_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_percentage"})
def percentage_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload
)
if not rendered_percentage_payload:
_LOGGER.debug("Ignoring empty speed from '%s'", msg.topic)
return
if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]:
self._attr_percentage = None
return
try:
percentage = ranged_value_to_percentage(
self._speed_range, int(rendered_percentage_payload)
)
except ValueError:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
if percentage < 0 or percentage > 100:
_LOGGER.warning(
(
"'%s' received on topic %s. '%s' is not a valid speed within"
" the speed range"
),
msg.payload,
msg.topic,
rendered_percentage_payload,
)
return
self._attr_percentage = percentage
add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_preset_mode"})
def preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode."""
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]:
self._attr_preset_mode = None
return
if not preset_mode:
_LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic)
return
if not self.preset_modes or preset_mode not in self.preset_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid preset mode",
msg.payload,
msg.topic,
preset_mode,
)
return
self._attr_preset_mode = preset_mode
add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_oscillating"})
def oscillation_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic)
return
if payload == self._payload["OSCILLATE_ON_PAYLOAD"]:
self._attr_oscillating = True
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
self._attr_oscillating = False
if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received):
self._attr_oscillating = False self._attr_oscillating = False
add_subscribe_topic(
@callback CONF_DIRECTION_STATE_TOPIC,
@log_messages(self.hass, self.entity_id) self._direction_received,
@write_state_on_attr_change(self, {"_attr_current_direction"}) {"_attr_current_direction"},
def direction_received(msg: ReceiveMessage) -> None: )
"""Handle new received MQTT message for the direction."""
direction = self._value_templates[ATTR_DIRECTION](msg.payload)
if not direction:
_LOGGER.debug("Ignoring empty direction from '%s'", msg.topic)
return
self._attr_current_direction = str(direction)
add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -476,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property @property
def is_on(self) -> bool | None: def is_on(self) -> bool | None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -51,12 +52,7 @@ from .const import (
CONF_STATE_VALUE_TEMPLATE, CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -284,164 +280,166 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
topics: dict[str, dict[str, Any]], topics: dict[str, dict[str, Any]],
topic: str, topic: str,
msg_callback: Callable[[ReceiveMessage], None], msg_callback: Callable[[ReceiveMessage], None],
tracked_attributes: set[str],
) -> None: ) -> None:
"""Add a subscription.""" """Add a subscription."""
qos: int = self._config[CONF_QOS] qos: int = self._config[CONF_QOS]
if topic in self._topic and self._topic[topic] is not None: if topic in self._topic and self._topic[topic] is not None:
topics[topic] = { topics[topic] = {
"topic": self._topic[topic], "topic": self._topic[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": qos, "qos": qos,
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
@callback
def _action_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
if not action_payload or action_payload == PAYLOAD_NONE:
_LOGGER.debug("Ignoring empty action from '%s'", msg.topic)
return
try:
self._attr_action = HumidifierAction(str(action_payload))
except ValueError:
_LOGGER.error(
"'%s' received on topic %s. '%s' is not a valid action",
msg.payload,
msg.topic,
action_payload,
)
return
@callback
def _current_humidity_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the current humidity."""
rendered_current_humidity_payload = self._value_templates[
ATTR_CURRENT_HUMIDITY
](msg.payload)
if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_current_humidity = None
return
if not rendered_current_humidity_payload:
_LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic)
return
try:
current_humidity = round(float(rendered_current_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
if current_humidity < 0 or current_humidity > 100:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
self._attr_current_humidity = current_humidity
@callback
def _target_humidity_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the target humidity."""
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
msg.payload
)
if not rendered_target_humidity_payload:
_LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic)
return
if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_target_humidity = None
return
try:
target_humidity = round(float(rendered_target_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
if (
target_humidity < self._attr_min_humidity
or target_humidity > self._attr_max_humidity
):
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
self._attr_target_humidity = target_humidity
@callback
def _mode_received(self, msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for mode."""
mode = str(self._value_templates[ATTR_MODE](msg.payload))
if mode == self._payload["MODE_RESET"]:
self._attr_mode = None
return
if not mode:
_LOGGER.debug("Ignoring empty mode from '%s'", msg.topic)
return
if not self.available_modes or mode not in self.available_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid mode",
msg.payload,
msg.topic,
mode,
)
return
self._attr_mode = mode
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, Any] = {} topics: dict[str, Any] = {}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state from '%s'", msg.topic)
return
if payload == self._payload["STATE_ON"]:
self._attr_is_on = True
elif payload == self._payload["STATE_OFF"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
self.add_subscription(topics, CONF_STATE_TOPIC, state_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_action"})
def action_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message."""
action_payload = self._value_templates[ATTR_ACTION](msg.payload)
if not action_payload or action_payload == PAYLOAD_NONE:
_LOGGER.debug("Ignoring empty action from '%s'", msg.topic)
return
try:
self._attr_action = HumidifierAction(str(action_payload))
except ValueError:
_LOGGER.error(
"'%s' received on topic %s. '%s' is not a valid action",
msg.payload,
msg.topic,
action_payload,
)
return
self.add_subscription(topics, CONF_ACTION_TOPIC, action_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_current_humidity"})
def current_humidity_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the current humidity."""
rendered_current_humidity_payload = self._value_templates[
ATTR_CURRENT_HUMIDITY
](msg.payload)
if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_current_humidity = None
return
if not rendered_current_humidity_payload:
_LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic)
return
try:
current_humidity = round(float(rendered_current_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
if current_humidity < 0 or current_humidity > 100:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid humidity",
msg.payload,
msg.topic,
rendered_current_humidity_payload,
)
return
self._attr_current_humidity = current_humidity
self.add_subscription( self.add_subscription(
topics, CONF_CURRENT_HUMIDITY_TOPIC, current_humidity_received topics, CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}
) )
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_target_humidity"})
def target_humidity_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the target humidity."""
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
msg.payload
)
if not rendered_target_humidity_payload:
_LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic)
return
if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]:
self._attr_target_humidity = None
return
try:
target_humidity = round(float(rendered_target_humidity_payload))
except ValueError:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
if (
target_humidity < self._attr_min_humidity
or target_humidity > self._attr_max_humidity
):
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid target humidity",
msg.payload,
msg.topic,
rendered_target_humidity_payload,
)
return
self._attr_target_humidity = target_humidity
self.add_subscription( self.add_subscription(
topics, CONF_TARGET_HUMIDITY_STATE_TOPIC, target_humidity_received topics, CONF_ACTION_TOPIC, self._action_received, {"_attr_action"}
)
self.add_subscription(
topics,
CONF_CURRENT_HUMIDITY_TOPIC,
self._current_humidity_received,
{"_attr_current_humidity"},
)
self.add_subscription(
topics,
CONF_TARGET_HUMIDITY_STATE_TOPIC,
self._target_humidity_received,
{"_attr_target_humidity"},
)
self.add_subscription(
topics, CONF_MODE_STATE_TOPIC, self._mode_received, {"_attr_mode"}
) )
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_mode"})
def mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for mode."""
mode = str(self._value_templates[ATTR_MODE](msg.payload))
if mode == self._payload["MODE_RESET"]:
self._attr_mode = None
return
if not mode:
_LOGGER.debug("Ignoring empty mode from '%s'", msg.topic)
return
if not self.available_modes or mode not in self.available_modes:
_LOGGER.warning(
"'%s' received on topic %s. '%s' is not a valid mode",
msg.payload,
msg.topic,
mode,
)
return
self._attr_mode = mode
self.add_subscription(topics, CONF_MODE_STATE_TOPIC, mode_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -449,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on the entity. """Turn on the entity.

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from base64 import b64decode from base64 import b64decode
import binascii import binascii
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_ENCODING, CONF_QOS from .const import CONF_ENCODING, CONF_QOS
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@@ -143,6 +143,45 @@ class MqttImage(MqttEntity, ImageEntity):
config.get(CONF_URL_TEMPLATE), entity=self config.get(CONF_URL_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _image_data_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err:
_LOGGER.error(
"Error processing image data received at topic %s: %s",
msg.topic,
err,
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
@callback
def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
url = cv.url(self._url_template(msg.payload))
self._attr_image_url = url
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
except vol.Invalid:
_LOGGER.error(
"Invalid image URL '%s' received at topic %s",
msg.payload,
msg.topic,
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@@ -159,56 +198,15 @@ class MqttImage(MqttEntity, ImageEntity):
if has_topic := self._topic[topic] is not None: if has_topic := self._topic[topic] is not None:
topics[topic] = { topics[topic] = {
"topic": self._topic[topic], "topic": self._topic[topic],
"msg_callback": msg_callback, "msg_callback": partial(self._message_callback, msg_callback, None),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": encoding, "encoding": encoding,
} }
return has_topic return has_topic
@callback add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
@log_messages(self.hass, self.entity_id) add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
def image_data_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload)
else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes)
self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err:
_LOGGER.error(
"Error processing image data received at topic %s: %s",
msg.topic,
err,
)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@callback
@log_messages(self.hass, self.entity_id)
def image_from_url_request_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
try:
url = cv.url(self._url_template(msg.payload))
self._attr_image_url = url
except MqttValueTemplateException as exc:
_LOGGER.warning(exc)
return
except vol.Invalid:
_LOGGER.error(
"Invalid image URL '%s' received at topic %s",
msg.payload,
msg.topic,
)
self._attr_image_last_updated = dt_util.utcnow()
self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -216,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_image(self) -> bytes | None: async def async_image(self) -> bytes | None:
"""Return bytes of image.""" """Return bytes of image."""

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import contextlib import contextlib
from functools import partial
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -31,12 +32,7 @@ from .const import (
DEFAULT_OPTIMISTIC, DEFAULT_OPTIMISTIC,
DEFAULT_RETAIN, DEFAULT_RETAIN,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -150,57 +146,59 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self
).async_render ).async_render
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload:
_LOGGER.debug(
"Invalid empty activity payload from topic %s, for entity %s",
msg.topic,
self.entity_id,
)
return
if payload.lower() == "none":
self._attr_activity = None
return
try:
self._attr_activity = LawnMowerActivity(payload)
except ValueError:
_LOGGER.error(
"Invalid activity for %s: '%s' (valid activities: %s)",
self.entity_id,
payload,
[option.value for option in LawnMowerActivity],
)
return
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_activity"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload:
_LOGGER.debug(
"Invalid empty activity payload from topic %s, for entity %s",
msg.topic,
self.entity_id,
)
return
if payload.lower() == "none":
self._attr_activity = None
return
try:
self._attr_activity = LawnMowerActivity(payload)
except ValueError:
_LOGGER.error(
"Invalid activity for %s: '%s' (valid activities: %s)",
self.entity_id,
payload,
[option.value for option in LawnMowerActivity],
)
return
if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None: if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._attr_assumed_state = True self._attr_assumed_state = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
CONF_ACTIVITY_STATE_TOPIC: { CONF_ACTIVITY_STATE_TOPIC: {
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC), "topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
"qos": self._config[CONF_QOS], self._message_callback,
"encoding": self._config[CONF_ENCODING] or None, self._message_received,
} {"_attr_activity"},
}, ),
) "entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and ( if self._attr_assumed_state and (
last_state := await self.async_get_last_state() last_state := await self.async_get_last_state()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@@ -53,8 +54,7 @@ from ..const import (
CONF_STATE_VALUE_TEMPLATE, CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from ..debug_info import log_messages from ..mixins import MqttEntity
from ..mixins import MqttEntity, write_state_on_attr_change
from ..models import ( from ..models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@@ -378,263 +378,248 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
attr: bool = getattr(self, f"_optimistic_{attribute}") attr: bool = getattr(self, f"_optimistic_{attribute}")
return attr return attr
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.NONE
)
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
if payload == self._payload["on"]:
self._attr_is_on = True
elif payload == self._payload["off"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
@callback
def _brightness_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for the brightness."""
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic)
return
device_value = float(payload)
if device_value == 0:
_LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic)
return
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
self._attr_brightness = min(round(percent_bright * 255), 255)
@callback
def _rgbx_received(
self,
msg: ReceiveMessage,
template: str,
color_mode: ColorMode,
convert_color: Callable[..., tuple[int, ...]],
) -> tuple[int, ...] | None:
"""Process MQTT messages for RGBW and RGBWW."""
payload = self._value_templates[template](msg.payload, PayloadSentinel.DEFAULT)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty %s message from '%s'", color_mode, msg.topic)
return None
color = tuple(int(val) for val in str(payload).split(","))
if self._optimistic_color_mode:
self._attr_color_mode = color_mode
if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None:
rgb = convert_color(*color)
brightness = max(rgb)
if brightness == 0:
_LOGGER.debug(
"Ignoring %s message with zero rgb brightness from '%s'",
color_mode,
msg.topic,
)
return None
self._attr_brightness = brightness
# Normalize the color to 100% brightness
color = tuple(
min(round(channel / brightness * 255), 255) for channel in color
)
return color
@callback
def _rgb_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGB."""
rgb = self._rgbx_received(
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
)
if rgb is None:
return
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
@callback
def _rgbw_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBW."""
rgbw = self._rgbx_received(
msg,
CONF_RGBW_VALUE_TEMPLATE,
ColorMode.RGBW,
color_util.color_rgbw_to_rgb,
)
if rgbw is None:
return
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
@callback
def _rgbww_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBWW."""
@callback
def _converter(
r: int, g: int, b: int, cw: int, ww: int
) -> tuple[int, int, int]:
min_kelvin = color_util.color_temperature_mired_to_kelvin(self.max_mireds)
max_kelvin = color_util.color_temperature_mired_to_kelvin(self.min_mireds)
return color_util.color_rgbww_to_rgb(
r, g, b, cw, ww, min_kelvin, max_kelvin
)
rgbww = self._rgbx_received(
msg,
CONF_RGBWW_VALUE_TEMPLATE,
ColorMode.RGBWW,
_converter,
)
if rgbww is None:
return
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
@callback
def _color_mode_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color mode."""
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic)
return
self._attr_color_mode = ColorMode(str(payload))
@callback
def _color_temp_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color temperature."""
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic)
return
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.COLOR_TEMP
self._attr_color_temp = int(payload)
@callback
def _effect_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for effect."""
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic)
return
self._attr_effect = str(payload)
@callback
def _hs_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for hs color."""
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
return
try:
hs_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.HS
self._attr_hs_color = cast(tuple[float, float], hs_color)
except ValueError:
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
@callback
def _xy_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for xy color."""
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic)
return
xy_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.XY
self._attr_xy_color = cast(tuple[float, float], xy_color)
def _prepare_subscribe_topics(self) -> None: # noqa: C901 def _prepare_subscribe_topics(self) -> None: # noqa: C901
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {} topics: dict[str, dict[str, Any]] = {}
def add_topic(topic: str, msg_callback: MessageCallbackType) -> None: def add_topic(
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
) -> None:
"""Add a topic.""" """Add a topic."""
if self._topic[topic] is not None: if self._topic[topic] is not None:
topics[topic] = { topics[topic] = {
"topic": self._topic[topic], "topic": self._topic[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@callback add_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
@log_messages(self.hass, self.entity_id) add_topic(
@write_state_on_attr_change(self, {"_attr_is_on"}) CONF_BRIGHTNESS_STATE_TOPIC, self._brightness_received, {"_attr_brightness"}
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.NONE
)
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
if payload == self._payload["on"]:
self._attr_is_on = True
elif payload == self._payload["off"]:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
if self._topic[CONF_STATE_TOPIC] is not None:
topics[CONF_STATE_TOPIC] = {
"topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_brightness"})
def brightness_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for the brightness."""
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic)
return
device_value = float(payload)
if device_value == 0:
_LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic)
return
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
self._attr_brightness = min(round(percent_bright * 255), 255)
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
@callback
def _rgbx_received(
msg: ReceiveMessage,
template: str,
color_mode: ColorMode,
convert_color: Callable[..., tuple[int, ...]],
) -> tuple[int, ...] | None:
"""Handle new MQTT messages for RGBW and RGBWW."""
payload = self._value_templates[template](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug(
"Ignoring empty %s message from '%s'", color_mode, msg.topic
)
return None
color = tuple(int(val) for val in str(payload).split(","))
if self._optimistic_color_mode:
self._attr_color_mode = color_mode
if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None:
rgb = convert_color(*color)
brightness = max(rgb)
if brightness == 0:
_LOGGER.debug(
"Ignoring %s message with zero rgb brightness from '%s'",
color_mode,
msg.topic,
)
return None
self._attr_brightness = brightness
# Normalize the color to 100% brightness
color = tuple(
min(round(channel / brightness * 255), 255) for channel in color
)
return color
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}
) )
def rgb_received(msg: ReceiveMessage) -> None: add_topic(
"""Handle new MQTT messages for RGB.""" CONF_RGB_STATE_TOPIC,
rgb = _rgbx_received( self._rgb_received,
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"},
)
if rgb is None:
return
self._attr_rgb_color = cast(tuple[int, int, int], rgb)
add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}
) )
def rgbw_received(msg: ReceiveMessage) -> None: add_topic(
"""Handle new MQTT messages for RGBW.""" CONF_RGBW_STATE_TOPIC,
rgbw = _rgbx_received( self._rgbw_received,
msg, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"},
CONF_RGBW_VALUE_TEMPLATE, )
ColorMode.RGBW, add_topic(
color_util.color_rgbw_to_rgb, CONF_RGBWW_STATE_TOPIC,
) self._rgbww_received,
if rgbw is None: {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"},
return )
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw) add_topic(
CONF_COLOR_MODE_STATE_TOPIC, self._color_mode_received, {"_attr_color_mode"}
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received) )
add_topic(
@callback CONF_COLOR_TEMP_STATE_TOPIC,
@log_messages(self.hass, self.entity_id) self._color_temp_received,
@write_state_on_attr_change( {"_attr_color_mode", "_attr_color_temp"},
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"} )
add_topic(CONF_EFFECT_STATE_TOPIC, self._effect_received, {"_attr_effect"})
add_topic(
CONF_HS_STATE_TOPIC,
self._hs_received,
{"_attr_color_mode", "_attr_hs_color"},
)
add_topic(
CONF_XY_STATE_TOPIC,
self._xy_received,
{"_attr_color_mode", "_attr_xy_color"},
) )
def rgbww_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBWW."""
@callback
def _converter(
r: int, g: int, b: int, cw: int, ww: int
) -> tuple[int, int, int]:
min_kelvin = color_util.color_temperature_mired_to_kelvin(
self.max_mireds
)
max_kelvin = color_util.color_temperature_mired_to_kelvin(
self.min_mireds
)
return color_util.color_rgbww_to_rgb(
r, g, b, cw, ww, min_kelvin, max_kelvin
)
rgbww = _rgbx_received(
msg,
CONF_RGBWW_VALUE_TEMPLATE,
ColorMode.RGBWW,
_converter,
)
if rgbww is None:
return
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode"})
def color_mode_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color mode."""
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic)
return
self._attr_color_mode = ColorMode(str(payload))
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"})
def color_temp_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color temperature."""
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic)
return
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.COLOR_TEMP
self._attr_color_temp = int(payload)
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_effect"})
def effect_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for effect."""
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic)
return
self._attr_effect = str(payload)
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"})
def hs_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for hs color."""
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
return
try:
hs_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.HS
self._attr_hs_color = cast(tuple[float, float], hs_color)
except ValueError:
_LOGGER.warning("Failed to parse hs state update: '%s'", payload)
add_topic(CONF_HS_STATE_TOPIC, hs_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"})
def xy_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for xy color."""
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic)
return
xy_color = tuple(float(val) for val in str(payload).split(",", 2))
if self._optimistic_color_mode:
self._attr_color_mode = ColorMode.XY
self._attr_xy_color = cast(tuple[float, float], xy_color)
add_topic(CONF_XY_STATE_TOPIC, xy_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -642,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
def restore_state( def restore_state(

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
@@ -66,8 +67,7 @@ from ..const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
DOMAIN as MQTT_DOMAIN, DOMAIN as MQTT_DOMAIN,
) )
from ..debug_info import log_messages from ..mixins import MqttEntity
from ..mixins import MqttEntity, write_state_on_attr_change
from ..models import ReceiveMessage from ..models import ReceiveMessage
from ..schemas import MQTT_ENTITY_COMMON_SCHEMA from ..schemas import MQTT_ENTITY_COMMON_SCHEMA
from ..util import valid_subscribe_topic from ..util import valid_subscribe_topic
@@ -414,118 +414,121 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
self.entity_id, self.entity_id,
) )
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
values = json_loads_object(msg.payload)
if values["state"] == "ON":
self._attr_is_on = True
elif values["state"] == "OFF":
self._attr_is_on = False
elif values["state"] is None:
self._attr_is_on = None
if (
self._deprecated_color_handling
and color_supported(self.supported_color_modes)
and "color" in values
):
# Deprecated color handling
if values["color"] is None:
self._attr_hs_color = None
else:
self._update_color(values)
if not self._deprecated_color_handling and "color_mode" in values:
self._update_color(values)
if brightness_supported(self.supported_color_modes):
try:
if brightness := values["brightness"]:
if TYPE_CHECKING:
assert isinstance(brightness, float)
self._attr_brightness = color_util.value_to_brightness(
(1, self._config[CONF_BRIGHTNESS_SCALE]), brightness
)
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except KeyError:
pass
except (TypeError, ValueError):
_LOGGER.warning(
"Invalid brightness value '%s' received for entity %s",
values["brightness"],
self.entity_id,
)
if (
self._deprecated_color_handling
and self.supported_color_modes
and ColorMode.COLOR_TEMP in self.supported_color_modes
):
# Deprecated color handling
try:
if values["color_temp"] is None:
self._attr_color_temp = None
else:
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
except KeyError:
pass
except ValueError:
_LOGGER.warning(
"Invalid color temp value '%s' received for entity %s",
values["color_temp"],
self.entity_id,
)
# Allow to switch back to color_temp
if "color" not in values:
self._attr_hs_color = None
if self.supported_features and LightEntityFeature.EFFECT:
with suppress(KeyError):
self._attr_effect = cast(str, values["effect"])
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback #
@log_messages(self.hass, self.entity_id) if self._topic[CONF_STATE_TOPIC] is None:
@write_state_on_attr_change( return
self,
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{ {
"_attr_brightness", CONF_STATE_TOPIC: {
"_attr_color_temp", "topic": self._topic[CONF_STATE_TOPIC],
"_attr_effect", "msg_callback": partial(
"_attr_hs_color", self._message_callback,
"_attr_is_on", self._state_received,
"_attr_rgb_color", {
"_attr_rgbw_color", "_attr_brightness",
"_attr_rgbww_color", "_attr_color_temp",
"_attr_xy_color", "_attr_effect",
"color_mode", "_attr_hs_color",
"_attr_is_on",
"_attr_rgb_color",
"_attr_rgbw_color",
"_attr_rgbww_color",
"_attr_xy_color",
"color_mode",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
}, },
) )
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
values = json_loads_object(msg.payload)
if values["state"] == "ON":
self._attr_is_on = True
elif values["state"] == "OFF":
self._attr_is_on = False
elif values["state"] is None:
self._attr_is_on = None
if (
self._deprecated_color_handling
and color_supported(self.supported_color_modes)
and "color" in values
):
# Deprecated color handling
if values["color"] is None:
self._attr_hs_color = None
else:
self._update_color(values)
if not self._deprecated_color_handling and "color_mode" in values:
self._update_color(values)
if brightness_supported(self.supported_color_modes):
try:
if brightness := values["brightness"]:
if TYPE_CHECKING:
assert isinstance(brightness, float)
self._attr_brightness = color_util.value_to_brightness(
(1, self._config[CONF_BRIGHTNESS_SCALE]), brightness
)
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except KeyError:
pass
except (TypeError, ValueError):
_LOGGER.warning(
"Invalid brightness value '%s' received for entity %s",
values["brightness"],
self.entity_id,
)
if (
self._deprecated_color_handling
and self.supported_color_modes
and ColorMode.COLOR_TEMP in self.supported_color_modes
):
# Deprecated color handling
try:
if values["color_temp"] is None:
self._attr_color_temp = None
else:
self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type]
except KeyError:
pass
except ValueError:
_LOGGER.warning(
"Invalid color temp value '%s' received for entity %s",
values["color_temp"],
self.entity_id,
)
# Allow to switch back to color_temp
if "color" not in values:
self._attr_hs_color = None
if self.supported_features and LightEntityFeature.EFFECT:
with suppress(KeyError):
self._attr_effect = cast(str, values["effect"])
if self._topic[CONF_STATE_TOPIC] is not None:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
if self._optimistic and last_state: if self._optimistic and last_state:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -44,8 +45,7 @@ from ..const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from ..debug_info import log_messages from ..mixins import MqttEntity
from ..mixins import MqttEntity, write_state_on_attr_change
from ..models import ( from ..models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -188,107 +188,107 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
# Support for ct + hs, prioritize hs # Support for ct + hs, prioritize hs
self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP
@callback
def _state_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
if state == STATE_ON:
self._attr_is_on = True
elif state == STATE_OFF:
self._attr_is_on = False
elif state == PAYLOAD_NONE:
self._attr_is_on = None
else:
_LOGGER.warning("Invalid state value received")
if CONF_BRIGHTNESS_TEMPLATE in self._config:
try:
if brightness := int(
self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload)
):
self._attr_brightness = brightness
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except ValueError:
_LOGGER.warning("Invalid brightness value received from %s", msg.topic)
if CONF_COLOR_TEMP_TEMPLATE in self._config:
try:
color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE](
msg.payload
)
self._attr_color_temp = (
int(color_temp) if color_temp != "None" else None
)
except ValueError:
_LOGGER.warning("Invalid color temperature value received")
if (
CONF_RED_TEMPLATE in self._config
and CONF_GREEN_TEMPLATE in self._config
and CONF_BLUE_TEMPLATE in self._config
):
try:
red = self._value_templates[CONF_RED_TEMPLATE](msg.payload)
green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload)
blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload)
if red == "None" and green == "None" and blue == "None":
self._attr_hs_color = None
else:
self._attr_hs_color = color_util.color_RGB_to_hs(
int(red), int(green), int(blue)
)
self._update_color_mode()
except ValueError:
_LOGGER.warning("Invalid color value received")
if CONF_EFFECT_TEMPLATE in self._config:
effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload))
if (
effect_list := self._config[CONF_EFFECT_LIST]
) and effect in effect_list:
self._attr_effect = effect
else:
_LOGGER.warning("Unsupported effect value received")
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback if self._topics[CONF_STATE_TOPIC] is None:
@log_messages(self.hass, self.entity_id) return
@write_state_on_attr_change(
self, self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{ {
"_attr_brightness", "state_topic": {
"_attr_color_mode", "topic": self._topics[CONF_STATE_TOPIC],
"_attr_color_temp", "msg_callback": partial(
"_attr_effect", self._message_callback,
"_attr_hs_color", self._state_received,
"_attr_is_on", {
"_attr_brightness",
"_attr_color_mode",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
}, },
) )
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
if state == STATE_ON:
self._attr_is_on = True
elif state == STATE_OFF:
self._attr_is_on = False
elif state == PAYLOAD_NONE:
self._attr_is_on = None
else:
_LOGGER.warning("Invalid state value received")
if CONF_BRIGHTNESS_TEMPLATE in self._config:
try:
if brightness := int(
self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload)
):
self._attr_brightness = brightness
else:
_LOGGER.debug(
"Ignoring zero brightness value for entity %s",
self.entity_id,
)
except ValueError:
_LOGGER.warning(
"Invalid brightness value received from %s", msg.topic
)
if CONF_COLOR_TEMP_TEMPLATE in self._config:
try:
color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE](
msg.payload
)
self._attr_color_temp = (
int(color_temp) if color_temp != "None" else None
)
except ValueError:
_LOGGER.warning("Invalid color temperature value received")
if (
CONF_RED_TEMPLATE in self._config
and CONF_GREEN_TEMPLATE in self._config
and CONF_BLUE_TEMPLATE in self._config
):
try:
red = self._value_templates[CONF_RED_TEMPLATE](msg.payload)
green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload)
blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload)
if red == "None" and green == "None" and blue == "None":
self._attr_hs_color = None
else:
self._attr_hs_color = color_util.color_RGB_to_hs(
int(red), int(green), int(blue)
)
self._update_color_mode()
except ValueError:
_LOGGER.warning("Invalid color value received")
if CONF_EFFECT_TEMPLATE in self._config:
effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload))
if (
effect_list := self._config[CONF_EFFECT_LIST]
) and effect in effect_list:
self._attr_effect = effect
else:
_LOGGER.warning("Unsupported effect value received")
if self._topics[CONF_STATE_TOPIC] is not None:
self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass,
self._sub_state,
{
"state_topic": {
"topic": self._topics[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
if self._optimistic and last_state: if self._optimistic and last_state:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import re import re
from typing import Any from typing import Any
@@ -36,12 +37,7 @@ from .const import (
CONF_STATE_OPENING, CONF_STATE_OPENING,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -186,57 +182,58 @@ class MqttLock(MqttEntity, LockEntity):
self._valid_states = [config[state] for state in STATE_CONFIG_KEYS] self._valid_states = [config[state] for state in STATE_CONFIG_KEYS]
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new lock state messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_RESET]:
# Reset the state to `unknown`
self._attr_is_locked = None
elif payload in self._valid_states:
self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED]
self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING]
self._attr_is_open = payload == self._config[CONF_STATE_OPEN]
self._attr_is_opening = payload == self._config[CONF_STATE_OPENING]
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]]
topics: dict[str, dict[str, Any]] = {}
qos: int = self._config[CONF_QOS] qos: int = self._config[CONF_QOS]
encoding: str | None = self._config[CONF_ENCODING] or None encoding: str | None = self._config[CONF_ENCODING] or None
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
{
"_attr_is_jammed",
"_attr_is_locked",
"_attr_is_locking",
"_attr_is_open",
"_attr_is_opening",
"_attr_is_unlocking",
},
)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new lock state messages."""
payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload == self._config[CONF_PAYLOAD_RESET]:
# Reset the state to `unknown`
self._attr_is_locked = None
elif payload in self._valid_states:
self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED]
self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING]
self._attr_is_open = payload == self._config[CONF_STATE_OPEN]
self._attr_is_opening = payload == self._config[CONF_STATE_OPENING]
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: return
topics[CONF_STATE_TOPIC] = { topics = {
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._message_received,
{
"_attr_is_jammed",
"_attr_is_locked",
"_attr_is_locking",
"_attr_is_open",
"_attr_is_opening",
"_attr_is_unlocking",
},
),
"entity_id": self.entity_id,
CONF_QOS: qos, CONF_QOS: qos,
CONF_ENCODING: encoding, CONF_ENCODING: encoding,
} }
}
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
@@ -246,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_lock(self, **kwargs: Any) -> None: async def async_lock(self, **kwargs: Any) -> None:
"""Lock the device. """Lock the device.

View File

@@ -114,7 +114,7 @@ from .models import (
from .subscription import ( from .subscription import (
EntitySubscription, EntitySubscription,
async_prepare_subscribe_topics, async_prepare_subscribe_topics,
async_subscribe_topics, async_subscribe_topics_internal,
async_unsubscribe_topics, async_unsubscribe_topics,
) )
from .util import mqtt_config_entry_enabled from .util import mqtt_config_entry_enabled
@@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
"""Subscribe MQTT events.""" """Subscribe MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._attributes_prepare_subscribe_topics() self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics() self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None: def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
@@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None: async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
await self._attributes_subscribe_topics() self._attributes_subscribe_topics()
def _attributes_prepare_subscribe_topics(self) -> None: def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
}, },
) )
async def _attributes_subscribe_topics(self) -> None: @callback
def _attributes_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state) async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
@@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
"""Subscribe MQTT events.""" """Subscribe MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._availability_prepare_subscribe_topics() self._availability_prepare_subscribe_topics()
await self._availability_subscribe_topics() self._availability_subscribe_topics()
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect) async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
) )
@@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None: async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
await self._availability_subscribe_topics() self._availability_subscribe_topics()
def _availability_setup_from_config(self, config: ConfigType) -> None: def _availability_setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup.""" """(Re)Setup."""
@@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
self._available[topic] = False self._available[topic] = False
self._available_latest = False self._available_latest = False
async def _availability_subscribe_topics(self) -> None: @callback
def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state) async_subscribe_topics_internal(self.hass, self._availability_sub_state)
@callback @callback
def async_mqtt_connect(self) -> None: def async_mqtt_connect(self) -> None:
@@ -1254,13 +1256,15 @@ class MqttEntity(
def _message_callback( def _message_callback(
self, self,
msg_callback: MessageCallbackType, msg_callback: MessageCallbackType,
attributes: set[str], attributes: set[str] | None,
msg: ReceiveMessage, msg: ReceiveMessage,
) -> None: ) -> None:
"""Process the message callback.""" """Process the message callback."""
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( if attributes is not None:
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
) (attribute, getattr(self, attribute, UNDEFINED))
for attribute in attributes
)
mqtt_data = self.hass.data[DATA_MQTT] mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][ messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
msg.subscribed_topic msg.subscribed_topic
@@ -1274,7 +1278,7 @@ class MqttEntity(
_LOGGER.warning(exc) _LOGGER.warning(exc)
return return
if self._attrs_have_changed(attrs_snapshot): if attributes is not None and self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from ast import literal_eval from ast import literal_eval
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Coroutine from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import StrEnum from enum import StrEnum
import logging import logging
@@ -70,7 +70,6 @@ class ReceiveMessage:
timestamp: float timestamp: float
type AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
type MessageCallbackType = Callable[[ReceiveMessage], None] type MessageCallbackType = Callable[[ReceiveMessage], None]

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -41,12 +42,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -165,64 +161,66 @@ class MqttNumber(MqttEntity, RestoreNumber):
self._attr_native_step = config[CONF_STEP] self._attr_native_step = config[CONF_STEP]
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
num_value: int | float | None
payload = str(self._value_template(msg.payload))
if not payload.strip():
_LOGGER.debug("Ignoring empty state update from '%s'", msg.topic)
return
try:
if payload == self._config[CONF_PAYLOAD_RESET]:
num_value = None
elif payload.isnumeric():
num_value = int(payload)
else:
num_value = float(payload)
except ValueError:
_LOGGER.warning("Payload '%s' is not a Number", msg.payload)
return
if num_value is not None and (
num_value < self.min_value or num_value > self.max_value
):
_LOGGER.error(
"Invalid value for %s: %s (range %s - %s)",
self.entity_id,
num_value,
self.min_value,
self.max_value,
)
return
self._attr_native_value = num_value
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_native_value"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
num_value: int | float | None
payload = str(self._value_template(msg.payload))
if not payload.strip():
_LOGGER.debug("Ignoring empty state update from '%s'", msg.topic)
return
try:
if payload == self._config[CONF_PAYLOAD_RESET]:
num_value = None
elif payload.isnumeric():
num_value = int(payload)
else:
num_value = float(payload)
except ValueError:
_LOGGER.warning("Payload '%s' is not a Number", msg.payload)
return
if num_value is not None and (
num_value < self.min_value or num_value > self.max_value
):
_LOGGER.error(
"Invalid value for %s: %s (range %s - %s)",
self.entity_id,
num_value,
self.min_value,
self.max_value,
)
return
self._attr_native_value = num_value
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._attr_assumed_state = True self._attr_assumed_state = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
"qos": self._config[CONF_QOS], self._message_callback,
"encoding": self._config[CONF_ENCODING] or None, self._message_received,
} {"_attr_native_value"},
}, ),
) "entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and ( if self._attr_assumed_state and (
last_number_data := await self.async_get_last_number_data() last_number_data := await self.async_get_last_number_data()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -27,12 +28,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -113,56 +109,58 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
config.get(CONF_VALUE_TEMPLATE), entity=self config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload.lower() == "none":
self._attr_current_option = None
return
if payload not in self.options:
_LOGGER.error(
"Invalid option for %s: '%s' (valid options: %s)",
self.entity_id,
payload,
self.options,
)
return
self._attr_current_option = payload
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_current_option"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages."""
payload = str(self._value_template(msg.payload))
if not payload.strip(): # No output from template, ignore
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
if payload.lower() == "none":
self._attr_current_option = None
return
if payload not in self.options:
_LOGGER.error(
"Invalid option for %s: '%s' (valid options: %s)",
self.entity_id,
payload,
self.options,
)
return
self._attr_current_option = payload
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._attr_assumed_state = True self._attr_assumed_state = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
"qos": self._config[CONF_QOS], self._message_callback,
"encoding": self._config[CONF_ENCODING] or None, self._message_received,
} {"_attr_current_option"},
}, ),
) "entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and ( if self._attr_assumed_state and (
last_state := await self.async_get_last_state() last_state := await self.async_get_last_state()

View File

@@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_: datetime) -> None: def _value_is_expired(self, *_: datetime) -> None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@@ -48,12 +49,7 @@ from .const import (
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -205,92 +201,94 @@ class MqttSiren(MqttEntity, SirenEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None: @callback
"""(Re)Subscribe to topics.""" def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
@callback payload = self._value_template(msg.payload)
@log_messages(self.hass, self.entity_id) if not payload or payload == PAYLOAD_EMPTY_JSON:
@write_state_on_attr_change(self, {"_attr_is_on", "_extra_attributes"}) _LOGGER.debug(
def state_message_received(msg: ReceiveMessage) -> None: "Ignoring empty payload '%s' after rendering for topic %s",
"""Handle new MQTT state messages.""" payload,
payload = self._value_template(msg.payload) msg.topic,
if not payload or payload == PAYLOAD_EMPTY_JSON: )
return
json_payload: dict[str, Any] = {}
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
json_payload = {STATE: payload}
else:
try:
json_payload = json_loads_object(payload)
_LOGGER.debug( _LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s", (
payload, "JSON payload detected after processing payload '%s' on"
" topic %s"
),
json_payload,
msg.topic,
)
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid (JSON) payload detected after processing payload"
" '%s' on topic %s"
),
json_payload,
msg.topic, msg.topic,
) )
return return
json_payload: dict[str, Any] = {} if STATE in json_payload:
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]: if json_payload[STATE] == self._state_on:
json_payload = {STATE: payload} self._attr_is_on = True
else: if json_payload[STATE] == self._state_off:
try: self._attr_is_on = False
json_payload = json_loads_object(payload) if json_payload[STATE] == PAYLOAD_NONE:
_LOGGER.debug( self._attr_is_on = None
( del json_payload[STATE]
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
json_payload,
msg.topic,
)
except JSON_DECODE_EXCEPTIONS:
_LOGGER.warning(
(
"No valid (JSON) payload detected after processing payload"
" '%s' on topic %s"
),
json_payload,
msg.topic,
)
return
if STATE in json_payload:
if json_payload[STATE] == self._state_on:
self._attr_is_on = True
if json_payload[STATE] == self._state_off:
self._attr_is_on = False
if json_payload[STATE] == PAYLOAD_NONE:
self._attr_is_on = None
del json_payload[STATE]
if json_payload: if json_payload:
# process attributes # process attributes
try: try:
params: SirenTurnOnServiceParameters params: SirenTurnOnServiceParameters
params = vol.All(TURN_ON_SCHEMA)(json_payload) params = vol.All(TURN_ON_SCHEMA)(json_payload)
except vol.MultipleInvalid as invalid_siren_parameters: except vol.MultipleInvalid as invalid_siren_parameters:
_LOGGER.warning( _LOGGER.warning(
"Unable to update siren state attributes from payload '%s': %s", "Unable to update siren state attributes from payload '%s': %s",
json_payload, json_payload,
invalid_siren_parameters, invalid_siren_parameters,
) )
return return
# To be able to track changes to self._extra_attributes we assign # To be able to track changes to self._extra_attributes we assign
# a fresh copy to make the original tracked reference immutable. # a fresh copy to make the original tracked reference immutable.
self._extra_attributes = dict(self._extra_attributes) self._extra_attributes = dict(self._extra_attributes)
self._update(process_turn_on_params(self, params)) self._update(process_turn_on_params(self, params))
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
CONF_STATE_TOPIC: { CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received, "msg_callback": partial(
"qos": self._config[CONF_QOS], self._message_callback,
"encoding": self._config[CONF_ENCODING] or None, self._state_message_received,
} {"_attr_is_on", "_extra_attributes"},
}, ),
) "entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property @property
def extra_state_attributes(self) -> dict[str, Any] | None: def extra_state_attributes(self) -> dict[str, Any] | None:

View File

@@ -2,14 +2,15 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from .. import mqtt
from . import debug_info from . import debug_info
from .client import async_subscribe_internal
from .const import DEFAULT_QOS from .const import DEFAULT_QOS
from .models import MessageCallbackType from .models import MessageCallbackType
@@ -21,7 +22,7 @@ class EntitySubscription:
hass: HomeAssistant hass: HomeAssistant
topic: str | None topic: str | None
message_callback: MessageCallbackType message_callback: MessageCallbackType
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None should_subscribe: bool | None
unsubscribe_callback: Callable[[], None] | None unsubscribe_callback: Callable[[], None] | None
qos: int = 0 qos: int = 0
encoding: str = "utf-8" encoding: str = "utf-8"
@@ -53,15 +54,16 @@ class EntitySubscription:
self.hass, self.message_callback, self.topic, self.entity_id self.hass, self.message_callback, self.topic, self.entity_id
) )
self.subscribe_task = mqtt.async_subscribe( self.should_subscribe = True
hass, self.topic, self.message_callback, self.qos, self.encoding
)
async def subscribe(self) -> None: @callback
def subscribe(self) -> None:
"""Subscribe to a topic.""" """Subscribe to a topic."""
if not self.subscribe_task: if not self.should_subscribe or not self.topic:
return return
self.unsubscribe_callback = await self.subscribe_task self.unsubscribe_callback = async_subscribe_internal(
self.hass, self.topic, self.message_callback, self.qos, self.encoding
)
def _should_resubscribe(self, other: EntitySubscription | None) -> bool: def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
"""Check if we should re-subscribe to the topic using the old state.""" """Check if we should re-subscribe to the topic using the old state."""
@@ -79,6 +81,7 @@ class EntitySubscription:
) )
@callback
def async_prepare_subscribe_topics( def async_prepare_subscribe_topics(
hass: HomeAssistant, hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None, new_state: dict[str, EntitySubscription] | None,
@@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS), qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"), encoding=value.get("encoding", "utf-8"),
hass=hass, hass=hass,
subscribe_task=None, should_subscribe=None,
entity_id=value.get("entity_id", None), entity_id=value.get("entity_id", None),
) )
# Get the current subscription state # Get the current subscription state
@@ -135,12 +138,29 @@ async def async_subscribe_topics(
sub_state: dict[str, EntitySubscription], sub_state: dict[str, EntitySubscription],
) -> None: ) -> None:
"""(Re)Subscribe to a set of MQTT topics.""" """(Re)Subscribe to a set of MQTT topics."""
async_subscribe_topics_internal(hass, sub_state)
@callback
def async_subscribe_topics_internal(
hass: HomeAssistant,
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics.
This function is internal to the MQTT integration and should not be called
from outside the integration.
"""
for sub in sub_state.values(): for sub in sub_state.values():
await sub.subscribe() sub.subscribe()
def async_unsubscribe_topics( if TYPE_CHECKING:
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]: def async_unsubscribe_topics(
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
return async_prepare_subscribe_topics(hass, sub_state, {}) ) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
@@ -36,12 +37,7 @@ from .const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttValueTemplate, ReceiveMessage from .models import MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@@ -118,42 +114,44 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
if payload == self._state_on:
self._attr_is_on = True
elif payload == self._state_off:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
if payload == self._state_on:
self._attr_is_on = True
elif payload == self._state_off:
self._attr_is_on = False
elif payload == PAYLOAD_NONE:
self._attr_is_on = None
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
CONF_STATE_TOPIC: { CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received, "msg_callback": partial(
"qos": self._config[CONF_QOS], self._message_callback,
"encoding": self._config[CONF_ENCODING] or None, self._state_message_received,
} {"_attr_is_on"},
}, ),
) "entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
},
)
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()): if self._optimistic and (last_state := await self.async_get_last_state()):
self._attr_is_on = last_state.state == STATE_ON self._attr_is_on = last_state.state == STATE_ON

View File

@@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
} }
}, },
) )
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_tear_down(self) -> None: async def async_tear_down(self) -> None:
"""Cleanup tag scanner.""" """Cleanup tag scanner."""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import re import re
from typing import Any from typing import Any
@@ -34,12 +35,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity):
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
self._attr_assumed_state = bool(self._optimistic) self._attr_assumed_state = bool(self._optimistic)
@callback
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = str(self._value_template(msg.payload))
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return
self._attr_native_value = payload
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, Any] = {} topics: dict[str, Any] = {}
def add_subscription( def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None: ) -> None:
if self._config.get(topic) is not None: if self._config.get(topic) is not None:
topics[topic] = { topics[topic] = {
"topic": self._config[topic], "topic": self._config[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@callback add_subscription(
@log_messages(self.hass, self.entity_id) topics,
@write_state_on_attr_change(self, {"_attr_native_value"}) CONF_STATE_TOPIC,
def handle_state_message_received(msg: ReceiveMessage) -> None: self._handle_state_message_received,
"""Handle receiving state message via MQTT.""" {"_attr_native_value"},
payload = str(self._value_template(msg.payload)) )
if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return
self._attr_native_value = payload
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -193,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_set_value(self, value: str) -> None: async def async_set_value(self, value: str) -> None:
"""Change the text.""" """Change the text."""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from functools import partial
import logging import logging
from typing import Any, TypedDict, cast from typing import Any, TypedDict, cast
@@ -32,12 +33,7 @@ from .const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic, valid_subscribe_topic from .util import valid_publish_topic, valid_subscribe_topic
@@ -141,25 +137,104 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
).async_render_with_possible_json_value, ).async_render_with_possible_json_value,
} }
@callback
def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON:
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
json_payload: _MqttUpdatePayloadType = {}
try:
rendered_json_payload = json_loads(payload)
if isinstance(rendered_json_payload, dict):
_LOGGER.debug(
(
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
rendered_json_payload,
msg.topic,
)
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
else:
_LOGGER.debug(
(
"Non-dictionary JSON payload detected after processing"
" payload '%s' on topic %s"
),
payload,
msg.topic,
)
json_payload = {"installed_version": str(payload)}
except JSON_DECODE_EXCEPTIONS:
_LOGGER.debug(
(
"No valid (JSON) payload detected after processing payload '%s'"
" on topic %s"
),
payload,
msg.topic,
)
json_payload["installed_version"] = str(payload)
if "installed_version" in json_payload:
self._attr_installed_version = json_payload["installed_version"]
if "latest_version" in json_payload:
self._attr_latest_version = json_payload["latest_version"]
if "title" in json_payload:
self._attr_title = json_payload["title"]
if "release_summary" in json_payload:
self._attr_release_summary = json_payload["release_summary"]
if "release_url" in json_payload:
self._attr_release_url = json_payload["release_url"]
if "entity_picture" in json_payload:
self._entity_picture = json_payload["entity_picture"]
@callback
def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT."""
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
if isinstance(latest_version, str) and latest_version != "":
self._attr_latest_version = latest_version
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, Any] = {} topics: dict[str, Any] = {}
def add_subscription( def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None: ) -> None:
if self._config.get(topic) is not None: if self._config.get(topic) is not None:
topics[topic] = { topics[topic] = {
"topic": self._config[topic], "topic": self._config[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@callback add_subscription(
@log_messages(self.hass, self.entity_id) topics,
@write_state_on_attr_change( CONF_STATE_TOPIC,
self, self._handle_state_message_received,
{ {
"_attr_installed_version", "_attr_installed_version",
"_attr_latest_version", "_attr_latest_version",
@@ -169,84 +244,11 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
"_entity_picture", "_entity_picture",
}, },
) )
def handle_state_message_received(msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON:
_LOGGER.debug(
"Ignoring empty payload '%s' after rendering for topic %s",
payload,
msg.topic,
)
return
json_payload: _MqttUpdatePayloadType = {}
try:
rendered_json_payload = json_loads(payload)
if isinstance(rendered_json_payload, dict):
_LOGGER.debug(
(
"JSON payload detected after processing payload '%s' on"
" topic %s"
),
rendered_json_payload,
msg.topic,
)
json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload)
else:
_LOGGER.debug(
(
"Non-dictionary JSON payload detected after processing"
" payload '%s' on topic %s"
),
payload,
msg.topic,
)
json_payload = {"installed_version": str(payload)}
except JSON_DECODE_EXCEPTIONS:
_LOGGER.debug(
(
"No valid (JSON) payload detected after processing payload '%s'"
" on topic %s"
),
payload,
msg.topic,
)
json_payload["installed_version"] = str(payload)
if "installed_version" in json_payload:
self._attr_installed_version = json_payload["installed_version"]
if "latest_version" in json_payload:
self._attr_latest_version = json_payload["latest_version"]
if "title" in json_payload:
self._attr_title = json_payload["title"]
if "release_summary" in json_payload:
self._attr_release_summary = json_payload["release_summary"]
if "release_url" in json_payload:
self._attr_release_url = json_payload["release_url"]
if "entity_picture" in json_payload:
self._entity_picture = json_payload["entity_picture"]
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(self, {"_attr_latest_version"})
def handle_latest_version_received(msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT."""
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
if isinstance(latest_version, str) and latest_version != "":
self._attr_latest_version = latest_version
add_subscription( add_subscription(
topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received topics,
CONF_LATEST_VERSION_TOPIC,
self._handle_latest_version_received,
{"_attr_latest_version"},
) )
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
@@ -255,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_install( async def async_install(
self, version: str | None, backup: bool, **kwargs: Any self, version: str | None, backup: bool, **kwargs: Any

View File

@@ -8,6 +8,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@@ -49,12 +50,7 @@ from .const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
DOMAIN, DOMAIN,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ReceiveMessage from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic from .util import valid_publish_topic
@@ -322,31 +318,32 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0) self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0)
self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0))) self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0)))
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle state MQTT message."""
payload = json_loads_object(msg.payload)
if STATE in payload and (
(state := payload[STATE]) in POSSIBLE_STATES or state is None
):
self._attr_state = (
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
)
del payload[STATE]
self._update_state_attributes(payload)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics: dict[str, Any] = {} topics: dict[str, Any] = {}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self, {"_attr_battery_level", "_attr_fan_speed", "_attr_state"}
)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle state MQTT message."""
payload = json_loads_object(msg.payload)
if STATE in payload and (
(state := payload[STATE]) in POSSIBLE_STATES or state is None
):
self._attr_state = (
POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None
)
del payload[STATE]
self._update_state_attributes(payload)
if state_topic := self._config.get(CONF_STATE_TOPIC): if state_topic := self._config.get(CONF_STATE_TOPIC):
topics["state_position_topic"] = { topics["state_position_topic"] = {
"topic": state_topic, "topic": state_topic,
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_battery_level", "_attr_fan_speed", "_attr_state"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -356,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None: async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
"""Publish a command.""" """Publish a command."""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import suppress from contextlib import suppress
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -61,12 +62,7 @@ from .const import (
DEFAULT_RETAIN, DEFAULT_RETAIN,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic, valid_subscribe_topic from .util import valid_publish_topic, valid_subscribe_topic
@@ -302,65 +298,63 @@ class MqttValve(MqttEntity, ValveEntity):
return return
self._update_state(state) self._update_state(state)
@callback
def _state_message_received(self, msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
payload_dict: Any = None
position_payload: Any = payload
state_payload: Any = payload
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
with suppress(*JSON_DECODE_EXCEPTIONS):
payload_dict = json_loads(payload)
if isinstance(payload_dict, dict):
if self.reports_position and "position" not in payload_dict:
_LOGGER.warning(
"Missing required `position` attribute in json payload "
"on topic '%s', got: %s",
msg.topic,
payload,
)
return
if not self.reports_position and "state" not in payload_dict:
_LOGGER.warning(
"Missing required `state` attribute in json payload "
" on topic '%s', got: %s",
msg.topic,
payload,
)
return
position_payload = payload_dict.get("position")
state_payload = payload_dict.get("state")
if self._config[CONF_REPORTS_POSITION]:
self._process_position_valve_update(msg, position_payload, state_payload)
else:
self._process_binary_valve_update(msg, state_payload)
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
topics = {} topics = {}
@callback
@log_messages(self.hass, self.entity_id)
@write_state_on_attr_change(
self,
{
"_attr_current_valve_position",
"_attr_is_closed",
"_attr_is_closing",
"_attr_is_opening",
},
)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages."""
payload = self._value_template(msg.payload)
payload_dict: Any = None
position_payload: Any = payload
state_payload: Any = payload
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", msg.topic)
return
with suppress(*JSON_DECODE_EXCEPTIONS):
payload_dict = json_loads(payload)
if isinstance(payload_dict, dict):
if self.reports_position and "position" not in payload_dict:
_LOGGER.warning(
"Missing required `position` attribute in json payload "
"on topic '%s', got: %s",
msg.topic,
payload,
)
return
if not self.reports_position and "state" not in payload_dict:
_LOGGER.warning(
"Missing required `state` attribute in json payload "
" on topic '%s', got: %s",
msg.topic,
payload,
)
return
position_payload = payload_dict.get("position")
state_payload = payload_dict.get("state")
if self._config[CONF_REPORTS_POSITION]:
self._process_position_valve_update(
msg, position_payload, state_payload
)
else:
self._process_binary_valve_update(msg, state_payload)
if self._config.get(CONF_STATE_TOPIC): if self._config.get(CONF_STATE_TOPIC):
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{
"_attr_current_valve_position",
"_attr_is_closed",
"_attr_is_closing",
"_attr_is_opening",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -371,7 +365,7 @@ class MqttValve(MqttEntity, ValveEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_valve(self) -> None: async def async_open_valve(self) -> None:
"""Move the valve up. """Move the valve up.

View File

@@ -9,6 +9,7 @@ from typing import Literal
import ollama import ollama
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
@@ -138,6 +139,11 @@ class OllamaConversationEntity(
ollama.Message(role=MessageRole.USER.value, content=user_input.text) ollama.Message(role=MessageRole.USER.value, content=user_input.text)
) )
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": message_history.messages},
)
# Get response # Get response
try: try:
response = await client.chat( response = await client.chat(

View File

@@ -31,14 +31,15 @@ from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_PROMPT, CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT, DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
} }
) )
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT,
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry( return self.async_create_entry(
title="ChatGPT", title="ChatGPT",
data=user_input, data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, options=RECOMMENDED_OPTIONS,
) )
return self.async_show_form( return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
def __init__(self, config_entry: ConfigEntry) -> None: def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow.""" """Initialize options flow."""
self.config_entry = config_entry self.config_entry = config_entry
self.last_rendered_recommended = config_entry.options.get(
CONF_RECOMMENDED, False
)
async def async_step_init( async def async_step_init(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Manage the options.""" """Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
if user_input is not None: if user_input is not None:
if user_input[CONF_LLM_HASS_API] == "none": if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
user_input.pop(CONF_LLM_HASS_API) if user_input[CONF_LLM_HASS_API] == "none":
return self.async_create_entry(title="", data=user_input) user_input.pop(CONF_LLM_HASS_API)
schema = openai_config_option_schema(self.hass, self.config_entry.options) return self.async_create_entry(title="", data=user_input)
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
schema = openai_config_option_schema(self.hass, options)
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="init",
data_schema=vol.Schema(schema), data_schema=vol.Schema(schema),
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):
def openai_config_option_schema( def openai_config_option_schema(
hass: HomeAssistant, hass: HomeAssistant,
options: MappingProxyType[str, Any], options: dict[str, Any] | MappingProxyType[str, Any],
) -> dict: ) -> dict:
"""Return a schema for OpenAI completion options.""" """Return a schema for OpenAI completion options."""
apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
SelectOptionDict( SelectOptionDict(
label="No control", label="No control",
value="none", value="none",
) )
] ]
apis.extend( hass_apis.extend(
SelectOptionDict( SelectOptionDict(
label=api.name, label=api.name,
value=api.id, value=api.id,
@@ -144,38 +167,46 @@ def openai_config_option_schema(
for api in llm.async_get_apis(hass) for api in llm.async_get_apis(hass)
) )
return { schema = {
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)}, description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(), ): TemplateSelector(),
vol.Optional( vol.Optional(
CONF_LLM_HASS_API, CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)}, description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none", default="none",
): SelectSelector(SelectSelectorConfig(options=apis)), ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Optional( vol.Required(
CONF_CHAT_MODEL, CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
description={ ): bool,
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
} }
if options.get(CONF_RECOMMENDED):
return schema
schema.update(
{
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=RECOMMENDED_CHAT_MODEL,
): str,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=RECOMMENDED_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=RECOMMENDED_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
}
)
return schema

View File

@@ -4,13 +4,15 @@ import logging
DOMAIN = "openai_conversation" DOMAIN = "openai_conversation"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
CONF_RECOMMENDED = "recommended"
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point.""" DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-4o" RECOMMENDED_CHAT_MODEL = "gpt-4o"
CONF_MAX_TOKENS = "max_tokens" CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150 RECOMMENDED_MAX_TOKENS = 150
CONF_TOP_P = "top_p" CONF_TOP_P = "top_p"
DEFAULT_TOP_P = 1.0 RECOMMENDED_TOP_P = 1.0
CONF_TEMPERATURE = "temperature" CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 1.0 RECOMMENDED_TEMPERATURE = 1.0

View File

@@ -8,6 +8,7 @@ import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -22,13 +23,13 @@ from .const import (
CONF_PROMPT, CONF_PROMPT,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT, DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
) )
# Max number of back and forth with the LLM to generate a response # Max number of back and forth with the LLM to generate a response
@@ -97,15 +98,14 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API): if options.get(CONF_LLM_HASS_API):
try: try:
llm_api = llm.async_get_api( llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError as err: except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err) LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error( intent_response.async_set_error(
@@ -117,26 +117,12 @@ class OpenAIConversationEntity(
) )
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history: if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
messages = self.history[conversation_id] messages = self.history[conversation_id]
else: else:
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()
try: try:
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api: if llm_api:
empty_tool_input = llm.ToolInput( empty_tool_input = llm.ToolInput(
tool_name="", tool_name="",
@@ -149,11 +135,24 @@ class OpenAIConversationEntity(
device_id=user_input.device_id, device_id=user_input.device_id,
) )
prompt = ( api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
await llm_api.async_get_api_prompt(empty_tool_input)
+ "\n" else:
+ prompt api_prompt = llm.PROMPT_NO_API_CONFIGURED
prompt = "\n".join(
(
template.Template(
options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
) )
)
except TemplateError as err: except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err) LOGGER.error("Error rendering prompt: %s", err)
@@ -170,7 +169,10 @@ class OpenAIConversationEntity(
messages.append({"role": "user", "content": user_input.text}) messages.append({"role": "user", "content": user_input.text})
LOGGER.debug("Prompt for %s: %s", model, messages) LOGGER.debug("Prompt: %s", messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
client = self.hass.data[DOMAIN][self.entry.entry_id] client = self.hass.data[DOMAIN][self.entry.entry_id]
@@ -178,12 +180,12 @@ class OpenAIConversationEntity(
for _iteration in range(MAX_TOOL_ITERATIONS): for _iteration in range(MAX_TOOL_ITERATIONS):
try: try:
result = await client.chat.completions.create( result = await client.chat.completions.create(
model=model, model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages, messages=messages,
tools=tools, tools=tools,
max_tokens=max_tokens, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=top_p, top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=temperature, temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=conversation_id, user=conversation_id,
) )
except openai.OpenAIError as err: except openai.OpenAIError as err:

View File

@@ -22,7 +22,8 @@
"max_tokens": "Maximum tokens to return in response", "max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature", "temperature": "Temperature",
"top_p": "Top P", "top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
}, },
"data_description": { "data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template." "prompt": "Instruct how the LLM should respond. This can be a template."

View File

@@ -30,7 +30,6 @@ from .util import (
PLATFORMS = [Platform.MEDIA_PLAYER] PLATFORMS = [Platform.MEDIA_PLAYER]
__all__ = [ __all__ = [
"async_browse_media", "async_browse_media",
"DOMAIN", "DOMAIN",
@@ -50,7 +49,10 @@ class HomeAssistantSpotifyData:
session: OAuth2Session session: OAuth2Session
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: type SpotifyConfigEntry = ConfigEntry[HomeAssistantSpotifyData]
async def async_setup_entry(hass: HomeAssistant, entry: SpotifyConfigEntry) -> bool:
"""Set up Spotify from a config entry.""" """Set up Spotify from a config entry."""
implementation = await async_get_config_entry_implementation(hass, entry) implementation = await async_get_config_entry_implementation(hass, entry)
session = OAuth2Session(hass, entry, implementation) session = OAuth2Session(hass, entry, implementation)
@@ -100,8 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
await device_coordinator.async_config_entry_first_refresh() await device_coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {}) entry.runtime_data = HomeAssistantSpotifyData(
hass.data[DOMAIN][entry.entry_id] = HomeAssistantSpotifyData(
client=spotify, client=spotify,
current_user=current_user, current_user=current_user,
devices=device_coordinator, devices=device_coordinator,
@@ -117,6 +118,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Spotify config entry.""" """Unload Spotify config entry."""
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
del hass.data[DOMAIN][entry.entry_id]
return unload_ok

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from enum import StrEnum from enum import StrEnum
from functools import partial from functools import partial
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
from spotipy import Spotify from spotipy import Spotify
import yarl import yarl
@@ -22,6 +22,9 @@ from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES
from .util import fetch_image_url from .util import fetch_image_url
if TYPE_CHECKING:
from . import HomeAssistantSpotifyData
BROWSE_LIMIT = 48 BROWSE_LIMIT = 48
@@ -140,21 +143,21 @@ async def async_browse_media(
# Check if caller is requesting the root nodes # Check if caller is requesting the root nodes
if media_content_type is None and media_content_id is None: if media_content_type is None and media_content_id is None:
children = [] config_entries = hass.config_entries.async_entries(
for config_entry_id in hass.data[DOMAIN]: DOMAIN, include_disabled=False, include_ignore=False
config_entry = hass.config_entries.async_get_entry(config_entry_id) )
assert config_entry is not None children = [
children.append( BrowseMedia(
BrowseMedia( title=config_entry.title,
title=config_entry.title, media_class=MediaClass.APP,
media_class=MediaClass.APP, media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry.entry_id}",
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry_id}", media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
media_content_type=f"{MEDIA_PLAYER_PREFIX}library", thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png", can_play=False,
can_play=False, can_expand=True,
can_expand=True,
)
) )
for config_entry in config_entries
]
return BrowseMedia( return BrowseMedia(
title="Spotify", title="Spotify",
media_class=MediaClass.APP, media_class=MediaClass.APP,
@@ -171,9 +174,15 @@ async def async_browse_media(
# Check for config entry specifier, and extract Spotify URI # Check for config entry specifier, and extract Spotify URI
parsed_url = yarl.URL(media_content_id) parsed_url = yarl.URL(media_content_id)
if (info := hass.data[DOMAIN].get(parsed_url.host)) is None:
if (
parsed_url.host is None
or (entry := hass.config_entries.async_get_entry(parsed_url.host)) is None
or not isinstance(entry.runtime_data, HomeAssistantSpotifyData)
):
raise BrowseError("Invalid Spotify account specified") raise BrowseError("Invalid Spotify account specified")
media_content_id = parsed_url.name media_content_id = parsed_url.name
info = entry.runtime_data
result = await async_browse_media_internal( result = await async_browse_media_internal(
hass, hass,

View File

@@ -22,7 +22,6 @@ from homeassistant.components.media_player import (
MediaType, MediaType,
RepeatMode, RepeatMode,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ID from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -30,7 +29,7 @@ from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from . import HomeAssistantSpotifyData from . import HomeAssistantSpotifyData, SpotifyConfigEntry
from .browse_media import async_browse_media_internal from .browse_media import async_browse_media_internal
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES
from .util import fetch_image_url from .util import fetch_image_url
@@ -70,12 +69,12 @@ SPOTIFY_DJ_PLAYLIST = {"uri": "spotify:playlist:37i9dQZF1EYkqdzj48dyYq", "name":
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
entry: ConfigEntry, entry: SpotifyConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Spotify based on a config entry.""" """Set up Spotify based on a config entry."""
spotify = SpotifyMediaPlayer( spotify = SpotifyMediaPlayer(
hass.data[DOMAIN][entry.entry_id], entry.runtime_data,
entry.data[CONF_ID], entry.data[CONF_ID],
entry.title, entry.title,
) )

View File

@@ -4,15 +4,14 @@ from __future__ import annotations
import logging import logging
from aioswitcher.bridge import SwitcherBridge
from aioswitcher.device import SwitcherBase from aioswitcher.device import SwitcherBase
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from .const import DATA_DEVICE, DOMAIN
from .coordinator import SwitcherDataUpdateCoordinator from .coordinator import SwitcherDataUpdateCoordinator
from .utils import async_start_bridge, async_stop_bridge
PLATFORMS = [ PLATFORMS = [
Platform.BUTTON, Platform.BUTTON,
@@ -25,20 +24,20 @@ PLATFORMS = [
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: type SwitcherConfigEntry = ConfigEntry[dict[str, SwitcherDataUpdateCoordinator]]
async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
"""Set up Switcher from a config entry.""" """Set up Switcher from a config entry."""
hass.data.setdefault(DOMAIN, {})
hass.data[DOMAIN][DATA_DEVICE] = {}
@callback @callback
def on_device_data_callback(device: SwitcherBase) -> None: def on_device_data_callback(device: SwitcherBase) -> None:
"""Use as a callback for device data.""" """Use as a callback for device data."""
coordinators = entry.runtime_data
# Existing device update device data # Existing device update device data
if device.device_id in hass.data[DOMAIN][DATA_DEVICE]: if coordinator := coordinators.get(device.device_id):
coordinator: SwitcherDataUpdateCoordinator = hass.data[DOMAIN][DATA_DEVICE][
device.device_id
]
coordinator.async_set_updated_data(device) coordinator.async_set_updated_data(device)
return return
@@ -52,18 +51,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
device.device_type.hex_rep, device.device_type.hex_rep,
) )
coordinator = hass.data[DOMAIN][DATA_DEVICE][device.device_id] = ( coordinator = SwitcherDataUpdateCoordinator(hass, entry, device)
SwitcherDataUpdateCoordinator(hass, entry, device)
)
coordinator.async_setup() coordinator.async_setup()
coordinators[device.device_id] = coordinator
# Must be ready before dispatcher is called # Must be ready before dispatcher is called
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
await async_start_bridge(hass, on_device_data_callback) entry.runtime_data = {}
bridge = SwitcherBridge(on_device_data_callback)
await bridge.start()
async def stop_bridge(event: Event) -> None: async def stop_bridge(event: Event | None = None) -> None:
await async_stop_bridge(hass) await bridge.stop()
entry.async_on_unload(stop_bridge)
entry.async_on_unload( entry.async_on_unload(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge)
@@ -72,12 +74,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
await async_stop_bridge(hass) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(DATA_DEVICE)
return unload_ok

View File

@@ -15,7 +15,6 @@ from aioswitcher.api.remotes import SwitcherBreezeRemote
from aioswitcher.device import DeviceCategory from aioswitcher.device import DeviceCategory
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -25,6 +24,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import SwitcherConfigEntry
from .const import SIGNAL_DEVICE_ADD from .const import SIGNAL_DEVICE_ADD
from .coordinator import SwitcherDataUpdateCoordinator from .coordinator import SwitcherDataUpdateCoordinator
from .utils import get_breeze_remote_manager from .utils import get_breeze_remote_manager
@@ -78,7 +78,7 @@ THERMOSTAT_BUTTONS = [
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: SwitcherConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Switcher button from config entry.""" """Set up Switcher button from config entry."""

View File

@@ -25,7 +25,6 @@ from homeassistant.components.climate import (
ClimateEntityFeature, ClimateEntityFeature,
HVACMode, HVACMode,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -35,6 +34,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import SwitcherConfigEntry
from .const import SIGNAL_DEVICE_ADD from .const import SIGNAL_DEVICE_ADD
from .coordinator import SwitcherDataUpdateCoordinator from .coordinator import SwitcherDataUpdateCoordinator
from .utils import get_breeze_remote_manager from .utils import get_breeze_remote_manager
@@ -61,7 +61,7 @@ HA_TO_DEVICE_FAN = {value: key for key, value in DEVICE_FAN_TO_HA.items()}
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: SwitcherConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Switcher climate from config entry.""" """Set up Switcher climate from config entry."""

View File

@@ -2,9 +2,6 @@
DOMAIN = "switcher_kis" DOMAIN = "switcher_kis"
DATA_BRIDGE = "bridge"
DATA_DEVICE = "device"
DISCOVERY_TIME_SEC = 12 DISCOVERY_TIME_SEC = 12
SIGNAL_DEVICE_ADD = "switcher_device_add" SIGNAL_DEVICE_ADD = "switcher_device_add"

View File

@@ -6,24 +6,23 @@ from dataclasses import asdict
from typing import Any from typing import Any
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import DATA_DEVICE, DOMAIN from . import SwitcherConfigEntry
TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"} TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"}
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: SwitcherConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
devices = hass.data[DOMAIN][DATA_DEVICE] coordinators = entry.runtime_data
return async_redact_data( return async_redact_data(
{ {
"entry": entry.as_dict(), "entry": entry.as_dict(),
"devices": [asdict(devices[d].data) for d in devices], "devices": [asdict(coordinators[d].data) for d in coordinators],
}, },
TO_REDACT, TO_REDACT,
) )

View File

@@ -3,9 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable
import logging import logging
from typing import Any
from aioswitcher.api.remotes import SwitcherBreezeRemoteManager from aioswitcher.api.remotes import SwitcherBreezeRemoteManager
from aioswitcher.bridge import SwitcherBase, SwitcherBridge from aioswitcher.bridge import SwitcherBase, SwitcherBridge
@@ -13,29 +11,11 @@ from aioswitcher.bridge import SwitcherBase, SwitcherBridge
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import singleton from homeassistant.helpers import singleton
from .const import DATA_BRIDGE, DISCOVERY_TIME_SEC, DOMAIN from .const import DISCOVERY_TIME_SEC
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_start_bridge(
hass: HomeAssistant, on_device_callback: Callable[[SwitcherBase], Any]
) -> None:
"""Start switcher UDP bridge."""
bridge = hass.data[DOMAIN][DATA_BRIDGE] = SwitcherBridge(on_device_callback)
_LOGGER.debug("Starting Switcher bridge")
await bridge.start()
async def async_stop_bridge(hass: HomeAssistant) -> None:
"""Stop switcher UDP bridge."""
bridge: SwitcherBridge = hass.data[DOMAIN].get(DATA_BRIDGE)
if bridge is not None:
_LOGGER.debug("Stopping Switcher bridge")
await bridge.stop()
hass.data[DOMAIN].pop(DATA_BRIDGE)
async def async_has_devices(hass: HomeAssistant) -> bool: async def async_has_devices(hass: HomeAssistant) -> bool:
"""Discover Switcher devices.""" """Discover Switcher devices."""
_LOGGER.debug("Starting discovery") _LOGGER.debug("Starting discovery")

View File

@@ -30,6 +30,7 @@ PLATFORMS: Final = [
Platform.BINARY_SENSOR, Platform.BINARY_SENSOR,
Platform.CLIMATE, Platform.CLIMATE,
Platform.COVER, Platform.COVER,
Platform.DEVICE_TRACKER,
Platform.LOCK, Platform.LOCK,
Platform.SELECT, Platform.SELECT,
Platform.SENSOR, Platform.SENSOR,

View File

@@ -0,0 +1,85 @@
"""Device tracker platform for Teslemetry integration."""
from __future__ import annotations
from homeassistant.components.device_tracker import SourceType
from homeassistant.components.device_tracker.config_entry import TrackerEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .entity import TeslemetryVehicleEntity
from .models import TeslemetryVehicleData
async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
) -> None:
"""Set up the Teslemetry device tracker platform from a config entry."""
async_add_entities(
klass(vehicle)
for klass in (
TeslemetryDeviceTrackerLocationEntity,
TeslemetryDeviceTrackerRouteEntity,
)
for vehicle in entry.runtime_data.vehicles
)
class TeslemetryDeviceTrackerEntity(TeslemetryVehicleEntity, TrackerEntity):
"""Base class for Teslemetry tracker entities."""
lat_key: str
lon_key: str
def __init__(
self,
vehicle: TeslemetryVehicleData,
) -> None:
"""Initialize the device tracker."""
super().__init__(vehicle, self.key)
def _async_update_attrs(self) -> None:
"""Update the attributes of the device tracker."""
self._attr_available = (
self.get(self.lat_key, False) is not None
and self.get(self.lon_key, False) is not None
)
@property
def latitude(self) -> float | None:
"""Return latitude value of the device."""
return self.get(self.lat_key)
@property
def longitude(self) -> float | None:
"""Return longitude value of the device."""
return self.get(self.lon_key)
@property
def source_type(self) -> SourceType:
"""Return the source type of the device tracker."""
return SourceType.GPS
class TeslemetryDeviceTrackerLocationEntity(TeslemetryDeviceTrackerEntity):
"""Vehicle location device tracker class."""
key = "location"
lat_key = "drive_state_latitude"
lon_key = "drive_state_longitude"
class TeslemetryDeviceTrackerRouteEntity(TeslemetryDeviceTrackerEntity):
"""Vehicle navigation device tracker class."""
key = "route"
lat_key = "drive_state_active_route_latitude"
lon_key = "drive_state_active_route_longitude"
@property
def location_name(self) -> str | None:
"""Return a location name for the current location of the device."""
return self.get("drive_state_active_route_destination")

View File

@@ -109,6 +109,7 @@
"off": "mdi:car-seat" "off": "mdi:car-seat"
} }
}, },
"components_customer_preferred_export_rule": { "components_customer_preferred_export_rule": {
"default": "mdi:transmission-tower", "default": "mdi:transmission-tower",
"state": { "state": {
@@ -126,6 +127,14 @@
} }
} }
}, },
"device_tracker": {
"location": {
"default": "mdi:map-marker"
},
"route": {
"default": "mdi:routes"
}
},
"cover": { "cover": {
"charge_state_charge_port_door_open": { "charge_state_charge_port_door_open": {
"default": "mdi:ev-plug-ccs2" "default": "mdi:ev-plug-ccs2"

View File

@@ -111,6 +111,14 @@
} }
} }
}, },
"device_tracker": {
"location": {
"name": "Location"
},
"route": {
"name": "Route"
}
},
"lock": { "lock": {
"charge_state_charge_port_latch": { "charge_state_charge_port_latch": {
"name": "Charge cable lock" "name": "Charge cable lock"

View File

@@ -13,7 +13,7 @@
"velbus-packet", "velbus-packet",
"velbus-protocol" "velbus-protocol"
], ],
"requirements": ["velbus-aio==2024.4.1"], "requirements": ["velbus-aio==2024.5.1"],
"usb": [ "usb": [
{ {
"vid": "10CF", "vid": "10CF",

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable
import logging import logging
from typing import Any from typing import Any
@@ -157,16 +156,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Unload Withings config entry.""" """Unload vera config entry."""
controller_data: ControllerData = get_controller_data(hass, config_entry) controller_data: ControllerData = get_controller_data(hass, config_entry)
await asyncio.gather(
tasks: list[Awaitable] = [ *(
hass.config_entries.async_forward_entry_unload(config_entry, platform) hass.config_entries.async_unload_platforms(
for platform in get_configured_platforms(controller_data) config_entry, get_configured_platforms(controller_data)
] ),
tasks.append(hass.async_add_executor_job(controller_data.controller.stop)) hass.async_add_executor_job(controller_data.controller.stop),
await asyncio.gather(*tasks) )
)
return True return True

View File

@@ -1,6 +1,5 @@
"""Support for Zigbee Home Automation devices.""" """Support for Zigbee Home Automation devices."""
import asyncio
import contextlib import contextlib
import copy import copy
import logging import logging
@@ -238,12 +237,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
websocket_api.async_unload_api(hass) websocket_api.async_unload_api(hass)
# our components don't have unload methods so no need to look at return values # our components don't have unload methods so no need to look at return values
await asyncio.gather( await hass.config_entries.async_unload_platforms(config_entry, PLATFORMS)
*(
hass.config_entries.async_forward_entry_unload(config_entry, platform)
for platform in PLATFORMS
)
)
return True return True

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Coroutine
from contextlib import suppress from contextlib import suppress
import logging import logging
from typing import Any from typing import Any
@@ -958,14 +957,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
client: ZwaveClient = entry.runtime_data[DATA_CLIENT] client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS] driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
platforms = [
tasks: list[Coroutine] = [ platform
hass.config_entries.async_forward_entry_unload(entry, platform)
for platform, task in driver_events.platform_setup_tasks.items() for platform, task in driver_events.platform_setup_tasks.items()
if not task.cancel() if not task.cancel()
] ]
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
unload_ok = all(await asyncio.gather(*tasks)) if tasks else True
if client.connected and client.driver: if client.connected and client.driver:
await async_disable_server_logging_if_needed(hass, entry, client.driver) await async_disable_server_logging_if_needed(hass, entry, client.driver)

View File

@@ -3,12 +3,16 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import asdict, dataclass
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -116,6 +120,10 @@ class API(ABC):
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response.""" """Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
)
for tool in self.async_get_tools(): for tool in self.async_get_tools():
if tool.name == tool_input.tool_name: if tool.name == tool_input.tool_name:
break break
@@ -191,7 +199,10 @@ class AssistAPI(API):
async def async_get_api_prompt(self, tool_input: ToolInput) -> str: async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API.""" """Return the prompt for the API."""
prompt = "Call the intent tools to control Home Assistant. Just pass the name to the intent." prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent."
)
if tool_input.device_id: if tool_input.device_id:
device_reg = device_registry.async_get(self.hass) device_reg = device_registry.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id) device = device_reg.async_get(tool_input.device_id)

View File

@@ -1821,7 +1821,7 @@ pyegps==0.2.5
pyenphase==1.20.3 pyenphase==1.20.3
# homeassistant.components.envisalink # homeassistant.components.envisalink
pyenvisalink==4.6 pyenvisalink==4.7
# homeassistant.components.ephember # homeassistant.components.ephember
pyephember==0.3.1 pyephember==0.3.1
@@ -2817,7 +2817,7 @@ vallox-websocket-api==5.1.1
vehicle==2.2.1 vehicle==2.2.1
# homeassistant.components.velbus # homeassistant.components.velbus
velbus-aio==2024.4.1 velbus-aio==2024.5.1
# homeassistant.components.venstar # homeassistant.components.venstar
venstarcolortouch==0.19 venstarcolortouch==0.19

View File

@@ -2185,7 +2185,7 @@ vallox-websocket-api==5.1.1
vehicle==2.2.1 vehicle==2.2.1
# homeassistant.components.velbus # homeassistant.components.velbus
velbus-aio==2024.4.1 velbus-aio==2024.5.1
# homeassistant.components.venstar # homeassistant.components.venstar
venstarcolortouch==0.19 venstarcolortouch==0.19

View File

@@ -117,7 +117,6 @@ NO_IOT_CLASS = [
# https://github.com/home-assistant/developers.home-assistant/pull/1512 # https://github.com/home-assistant/developers.home-assistant/pull/1512
NO_DIAGNOSTICS = [ NO_DIAGNOSTICS = [
"dlna_dms", "dlna_dms",
"fronius",
"gdacs", "gdacs",
"geonetnz_quakes", "geonetnz_quakes",
"google_assistant_sdk", "google_assistant_sdk",

View File

@@ -2,7 +2,9 @@
from unittest.mock import patch from unittest.mock import patch
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant, State from homeassistant.core import Context, HomeAssistant, State
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None:
) as mock_process, ) as mock_process,
patch("homeassistant.util.dt.utcnow", return_value=now), patch("homeassistant.util.dt.utcnow", return_value=now),
): ):
intent_response = intent.IntentResponse(language="en")
intent_response.async_set_speech("response text")
mock_process.return_value = conversation.ConversationResult(
response=intent_response,
)
await hass.services.async_call( await hass.services.async_call(
"conversation", "conversation",
"process", "process",

View File

@@ -0,0 +1,80 @@
"""Test for the conversation traces."""
from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@pytest.fixture
async def init_components(hass: HomeAssistant):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
async def test_converation_trace(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
)
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert trace_event.get("data")
assert trace_event["data"].get("text") == "add apples to my shopping list"
assert last_trace.get("result")
assert (
last_trace["result"]
.get("response", {})
.get("speech", {})
.get("plain", {})
.get("speech")
== "Added apples"
)
async def test_converation_trace_error(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
side_effect=HomeAssistantError("Failed to talk to agent"),
),
pytest.raises(HomeAssistantError),
):
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
)
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert last_trace.get("error") == "Failed to talk to agent"

View File

@@ -25,6 +25,7 @@ async def setup_fronius_integration(
"""Create the Fronius integration.""" """Create the Fronius integration."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
entry_id="f1e2b9837e8adaed6fa682acaa216fd8",
unique_id=unique_id, # has to match mocked logger unique_id unique_id=unique_id, # has to match mocked logger unique_id
data={ data={
CONF_HOST: MOCK_HOST, CONF_HOST: MOCK_HOST,

View File

@@ -0,0 +1,370 @@
# serializer version: 1
# name: test_diagnostics
dict({
'config_entry': dict({
'data': dict({
'host': 'http://fronius',
'is_logger': True,
}),
'disabled_by': None,
'domain': 'fronius',
'entry_id': 'f1e2b9837e8adaed6fa682acaa216fd8',
'minor_version': 1,
'options': dict({
}),
'pref_disable_new_entities': False,
'pref_disable_polling': False,
'source': 'user',
'title': 'Mock Title',
'unique_id': '**REDACTED**',
'version': 1,
}),
'coordinators': dict({
'inverters': dict({
'1': dict({
'current_ac': dict({
'unit': 'A',
'value': 5.19,
}),
'current_dc': dict({
'unit': 'A',
'value': 2.19,
}),
'energy_day': dict({
'unit': 'Wh',
'value': 1113,
}),
'energy_total': dict({
'unit': 'Wh',
'value': 44188000,
}),
'energy_year': dict({
'unit': 'Wh',
'value': 25508798,
}),
'error_code': dict({
'value': 0,
}),
'frequency_ac': dict({
'unit': 'Hz',
'value': 49.94,
}),
'led_color': dict({
'value': 2,
}),
'led_state': dict({
'value': 0,
}),
'power_ac': dict({
'unit': 'W',
'value': 1190,
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'status_code': dict({
'value': 7,
}),
'timestamp': dict({
'value': '2021-10-07T10:01:17+02:00',
}),
'voltage_ac': dict({
'unit': 'V',
'value': 227.9,
}),
'voltage_dc': dict({
'unit': 'V',
'value': 518,
}),
}),
}),
'logger': dict({
'system': dict({
'cash_factor': dict({
'unit': 'EUR/kWh',
'value': 0.07800000160932541,
}),
'co2_factor': dict({
'unit': 'kg/kWh',
'value': 0.5299999713897705,
}),
'delivery_factor': dict({
'unit': 'EUR/kWh',
'value': 0.15000000596046448,
}),
'hardware_platform': dict({
'value': 'wilma',
}),
'hardware_version': dict({
'value': '2.4E',
}),
'product_type': dict({
'value': 'fronius-datamanager-card',
}),
'software_version': dict({
'value': '3.18.7-1',
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'time_zone': dict({
'value': 'CEST',
}),
'time_zone_location': dict({
'value': 'Vienna',
}),
'timestamp': dict({
'value': '2021-10-06T23:56:32+02:00',
}),
'unique_identifier': '**REDACTED**',
'utc_offset': dict({
'value': 7200,
}),
}),
}),
'meter': dict({
'0': dict({
'current_ac_phase_1': dict({
'unit': 'A',
'value': 7.755,
}),
'current_ac_phase_2': dict({
'unit': 'A',
'value': 6.68,
}),
'current_ac_phase_3': dict({
'unit': 'A',
'value': 10.102,
}),
'enable': dict({
'value': 1,
}),
'energy_reactive_ac_consumed': dict({
'unit': 'VArh',
'value': 59960790,
}),
'energy_reactive_ac_produced': dict({
'unit': 'VArh',
'value': 723160,
}),
'energy_real_ac_minus': dict({
'unit': 'Wh',
'value': 35623065,
}),
'energy_real_ac_plus': dict({
'unit': 'Wh',
'value': 15303334,
}),
'energy_real_consumed': dict({
'unit': 'Wh',
'value': 15303334,
}),
'energy_real_produced': dict({
'unit': 'Wh',
'value': 35623065,
}),
'frequency_phase_average': dict({
'unit': 'Hz',
'value': 50,
}),
'manufacturer': dict({
'value': 'Fronius',
}),
'meter_location': dict({
'value': 0,
}),
'model': dict({
'value': 'Smart Meter 63A',
}),
'power_apparent': dict({
'unit': 'VA',
'value': 5592.57,
}),
'power_apparent_phase_1': dict({
'unit': 'VA',
'value': 1772.793,
}),
'power_apparent_phase_2': dict({
'unit': 'VA',
'value': 1527.048,
}),
'power_apparent_phase_3': dict({
'unit': 'VA',
'value': 2333.562,
}),
'power_factor': dict({
'value': 1,
}),
'power_factor_phase_1': dict({
'value': -0.99,
}),
'power_factor_phase_2': dict({
'value': -0.99,
}),
'power_factor_phase_3': dict({
'value': 0.99,
}),
'power_reactive': dict({
'unit': 'VAr',
'value': 2.87,
}),
'power_reactive_phase_1': dict({
'unit': 'VAr',
'value': 51.48,
}),
'power_reactive_phase_2': dict({
'unit': 'VAr',
'value': 115.63,
}),
'power_reactive_phase_3': dict({
'unit': 'VAr',
'value': -164.24,
}),
'power_real': dict({
'unit': 'W',
'value': 5592.57,
}),
'power_real_phase_1': dict({
'unit': 'W',
'value': 1765.55,
}),
'power_real_phase_2': dict({
'unit': 'W',
'value': 1515.8,
}),
'power_real_phase_3': dict({
'unit': 'W',
'value': 2311.22,
}),
'serial': '**REDACTED**',
'visible': dict({
'value': 1,
}),
'voltage_ac_phase_1': dict({
'unit': 'V',
'value': 228.6,
}),
'voltage_ac_phase_2': dict({
'unit': 'V',
'value': 228.6,
}),
'voltage_ac_phase_3': dict({
'unit': 'V',
'value': 231,
}),
'voltage_ac_phase_to_phase_12': dict({
'unit': 'V',
'value': 395.9,
}),
'voltage_ac_phase_to_phase_23': dict({
'unit': 'V',
'value': 398,
}),
'voltage_ac_phase_to_phase_31': dict({
'unit': 'V',
'value': 398,
}),
}),
}),
'ohmpilot': None,
'power_flow': dict({
'power_flow': dict({
'energy_day': dict({
'unit': 'Wh',
'value': 1101.7000732421875,
}),
'energy_total': dict({
'unit': 'Wh',
'value': 44188000,
}),
'energy_year': dict({
'unit': 'Wh',
'value': 25508788,
}),
'meter_location': dict({
'value': 'grid',
}),
'meter_mode': dict({
'value': 'meter',
}),
'power_battery': dict({
'unit': 'W',
'value': None,
}),
'power_grid': dict({
'unit': 'W',
'value': 1703.74,
}),
'power_load': dict({
'unit': 'W',
'value': -2814.74,
}),
'power_photovoltaics': dict({
'unit': 'W',
'value': 1111,
}),
'relative_autonomy': dict({
'unit': '%',
'value': 39.4707859340472,
}),
'relative_self_consumption': dict({
'unit': '%',
'value': 100,
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'timestamp': dict({
'value': '2021-10-07T10:00:43+02:00',
}),
}),
}),
'storage': None,
}),
'inverter_info': dict({
'inverters': list([
dict({
'custom_name': dict({
'value': 'Symo 20',
}),
'device_id': dict({
'value': '1',
}),
'device_type': dict({
'manufacturer': 'Fronius',
'model': 'Symo 20.0-3-M',
'value': 121,
}),
'error_code': dict({
'value': 0,
}),
'pv_power': dict({
'unit': 'W',
'value': 23100,
}),
'show': dict({
'value': 1,
}),
'status_code': dict({
'value': 7,
}),
'unique_id': '**REDACTED**',
}),
]),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'timestamp': dict({
'value': '2021-10-07T13:41:00+02:00',
}),
}),
})
# ---

View File

@@ -0,0 +1,31 @@
"""Tests for the diagnostics data provided by the KNX integration."""
from syrupy import SnapshotAssertion
from homeassistant.core import HomeAssistant
from . import mock_responses, setup_fronius_integration
from tests.components.diagnostics import get_diagnostics_for_config_entry
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
async def test_diagnostics(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
snapshot: SnapshotAssertion,
) -> None:
"""Test diagnostics."""
mock_responses(aioclient_mock)
entry = await setup_fronius_integration(hass)
assert (
await get_diagnostics_for_config_entry(
hass,
hass_client,
entry,
)
== snapshot
)

View File

@@ -1,4 +1,114 @@
# serializer version: 1 # serializer version: 1
# name: test_chat_history
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'1st user request',
),
dict({
}),
),
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
dict({
'parts': '1st user request',
'role': 'user',
}),
dict({
'parts': '1st model response',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'2nd user request',
),
dict({
}),
),
])
# ---
# name: test_default_prompt[config_entry_options0-None] # name: test_default_prompt[config_entry_options0-None]
list([ list([
tuple( tuple(
@@ -14,10 +124,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -29,7 +139,10 @@
dict({ dict({
'history': list([ 'history': list([
dict({ dict({
'parts': 'Answer in plain text. Keep it simple and to the point.', 'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user', 'role': 'user',
}), }),
dict({ dict({
@@ -64,10 +177,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -79,7 +192,10 @@
dict({ dict({
'history': list([ 'history': list([
dict({ dict({
'parts': 'Answer in plain text. Keep it simple and to the point.', 'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user', 'role': 'user',
}), }),
dict({ dict({
@@ -114,10 +230,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -130,8 +246,8 @@
'history': list([ 'history': list([
dict({ dict({
'parts': ''' 'parts': '''
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
''', ''',
'role': 'user', 'role': 'user',
}), }),
@@ -167,10 +283,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -183,8 +299,8 @@
'history': list([ 'history': list([
dict({ dict({
'parts': ''' 'parts': '''
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
''', ''',
'role': 'user', 'role': 'user',
}), }),

View File

@@ -2,12 +2,14 @@
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError from google.api_core.exceptions import GoogleAPICallError
import google.generativeai.types as genai_types
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -150,6 +152,57 @@ async def test_default_prompt(
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
async def test_chat_history(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the agent keeps track of the chat history."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call = None
chat_response.parts = [mock_part]
chat_response.text = "1st model response"
mock_chat.history = [
{"role": "user", "parts": "prompt"},
{"role": "model", "parts": "Ok"},
{"role": "user", "parts": "1st user request"},
{"role": "model", "parts": "1st model response"},
]
result = await conversation.async_converse(
hass,
"1st user request",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.as_dict()["speech"]["plain"]["speech"]
== "1st model response"
)
chat_response.text = "2nd model response"
result = await conversation.async_converse(
hass,
"2nd user request",
result.conversation_id,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.as_dict()["speech"]["plain"]["speech"]
== "2nd model response"
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
@patch( @patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
) )
@@ -233,6 +286,20 @@ async def test_function_call(
), ),
) )
# Test conversating tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
@patch( @patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
@@ -325,7 +392,7 @@ async def test_error_handling(
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = ClientError("some error") mock_chat.send_message_async.side_effect = GoogleAPICallError("some error")
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
) )
@@ -340,7 +407,28 @@ async def test_error_handling(
async def test_blocked_response( async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None: ) -> None:
"""Test response was blocked.""" """Test blocked response."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
"finish_reason: SAFETY\n"
)
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"The message got blocked by your safety settings"
)
async def test_empty_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test empty response."""
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
@@ -358,6 +446,32 @@ async def test_blocked_response(
) )
async def test_invalid_llm_api(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test handling of invalid llm api."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
)
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Error preparing LLM API: API invalid_llm_api not found"
)
async def test_template_error( async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:

View File

@@ -529,16 +529,16 @@ async def test_non_unique_triggers(
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 2 assert len(calls) == 2
assert calls[0].data["some"] == "press1" all_calls = {calls[0].data["some"], calls[1].data["some"]}
assert calls[1].data["some"] == "press2" assert all_calls == {"press1", "press2"}
# Trigger second config references to same trigger # Trigger second config references to same trigger
# and triggers both attached instances. # and triggers both attached instances.
async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press") async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 2 assert len(calls) == 2
assert calls[0].data["some"] == "press1" all_calls = {calls[0].data["some"], calls[1].data["some"]}
assert calls[1].data["some"] == "press2" assert all_calls == {"press1", "press2"}
# Removing the first trigger will clean up # Removing the first trigger will clean up
calls.clear() calls.clear()

View File

@@ -4,6 +4,7 @@ import asyncio
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
import json import json
import logging import logging
import socket import socket
@@ -1050,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
await mqtt.async_subscribe(hass, "test-topic", record_calls) await mqtt.async_subscribe(hass, "test-topic", record_calls)
async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
) -> None:
"""Test the subscription of a topic when MQTT config entry is disabled."""
mqtt_mock.connected = True
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
assert mqtt_config_entry.state is ConfigEntryState.LOADED
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
await hass.config_entries.async_set_disabled_by(
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
)
mqtt_mock.connected = False
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
async def test_subscribe_and_resubscribe( async def test_subscribe_and_resubscribe(
@@ -2912,8 +2934,8 @@ async def test_message_callback_exception_gets_logged(
await mqtt_mock_entry() await mqtt_mock_entry()
@callback @callback
def bad_handler(*args) -> None: def bad_handler(msg: ReceiveMessage) -> None:
"""Record calls.""" """Handle callback."""
raise ValueError("This is a bad message callback") raise ValueError("This is a bad message callback")
await mqtt.async_subscribe(hass, "test-topic", bad_handler) await mqtt.async_subscribe(hass, "test-topic", bad_handler)
@@ -2926,6 +2948,40 @@ async def test_message_callback_exception_gets_logged(
) )
@pytest.mark.no_fail_on_log_exception
async def test_message_partial_callback_exception_gets_logged(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test exception raised by message handler."""
await mqtt_mock_entry()
@callback
def bad_handler(msg: ReceiveMessage) -> None:
"""Handle callback."""
raise ValueError("This is a bad message callback")
def parial_handler(
msg_callback: MessageCallbackType,
attributes: set[str],
msg: ReceiveMessage,
) -> None:
"""Partial callback handler."""
msg_callback(msg)
await mqtt.async_subscribe(
hass, "test-topic", partial(parial_handler, bad_handler, {"some_attr"})
)
async_fire_mqtt_message(hass, "test-topic", "test")
await hass.async_block_till_done()
assert (
"Exception in bad_handler when handling msg on 'test-topic':"
" 'test'" in caplog.text
)
async def test_mqtt_ws_subscription( async def test_mqtt_ws_subscription(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
@@ -3787,7 +3843,7 @@ async def test_unload_config_entry(
async def test_publish_or_subscribe_without_valid_config_entry( async def test_publish_or_subscribe_without_valid_config_entry(
hass: HomeAssistant, record_calls: MessageCallbackType hass: HomeAssistant, record_calls: MessageCallbackType
) -> None: ) -> None:
"""Test internal publish function with bas use cases.""" """Test internal publish function with bad use cases."""
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await mqtt.async_publish( await mqtt.async_publish(
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None

View File

@@ -11,8 +11,12 @@ from .const import DEFAULT_FORECAST, DEFAULT_OBSERVATION
@pytest.fixture @pytest.fixture
def mock_simple_nws(): def mock_simple_nws():
"""Mock pynws SimpleNWS with default values.""" """Mock pynws SimpleNWS with default values."""
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws: with (
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_nws.return_value instance = mock_nws.return_value
instance.set_station = AsyncMock(return_value=None) instance.set_station = AsyncMock(return_value=None)
instance.update_observation = AsyncMock(return_value=None) instance.update_observation = AsyncMock(return_value=None)
@@ -29,7 +33,12 @@ def mock_simple_nws():
@pytest.fixture @pytest.fixture
def mock_simple_nws_times_out(): def mock_simple_nws_times_out():
"""Mock pynws SimpleNWS that times out.""" """Mock pynws SimpleNWS that times out."""
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws: # set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
with (
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_nws.return_value instance = mock_nws.return_value
instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError) instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError)
instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError) instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError)

View File

@@ -1,7 +1,6 @@
"""Tests for the NWS weather component.""" """Tests for the NWS weather component."""
from datetime import timedelta from datetime import timedelta
from unittest.mock import patch
import aiohttp import aiohttp
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@@ -24,7 +23,6 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from .const import ( from .const import (
@@ -127,47 +125,43 @@ async def test_data_caching_error_observation(
caplog, caplog,
) -> None: ) -> None:
"""Test caching of data with errors.""" """Test caching of data with errors."""
with ( instance = mock_simple_nws.return_value
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_simple_nws.return_value
entry = MockConfigEntry( entry = MockConfigEntry(
domain=nws.DOMAIN, domain=nws.DOMAIN,
data=NWS_CONFIG, data=NWS_CONFIG,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("weather.abc") state = hass.states.get("weather.abc")
assert state.state == "sunny" assert state.state == "sunny"
# data is still valid even when update fails # data is still valid even when update fails
instance.update_observation.side_effect = NwsNoDataError("Test") instance.update_observation.side_effect = NwsNoDataError("Test")
freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100)) freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100))
async_fire_time_changed(hass) async_fire_time_changed(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("weather.abc") state = hass.states.get("weather.abc")
assert state.state == "sunny" assert state.state == "sunny"
assert ( assert (
"NWS observation update failed, but data still valid. Last success: " "NWS observation update failed, but data still valid. Last success: "
in caplog.text in caplog.text
) )
# data is no longer valid after OBSERVATION_VALID_TIME # data is no longer valid after OBSERVATION_VALID_TIME
freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1)) freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1))
async_fire_time_changed(hass) async_fire_time_changed(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("weather.abc") state = hass.states.get("weather.abc")
assert state.state == STATE_UNAVAILABLE assert state.state == STATE_UNAVAILABLE
assert "Error fetching NWS observation station ABC data: Test" in caplog.text assert "Error fetching NWS observation station ABC data: Test" in caplog.text
async def test_no_data_error_observation( async def test_no_data_error_observation(
@@ -302,26 +296,23 @@ async def test_error_observation(
hass: HomeAssistant, mock_simple_nws, no_sensor hass: HomeAssistant, mock_simple_nws, no_sensor
) -> None: ) -> None:
"""Test error during update observation.""" """Test error during update observation."""
utc_time = dt_util.utcnow() instance = mock_simple_nws.return_value
with patch("homeassistant.components.nws.coordinator.utcnow") as mock_utc: # first update fails
mock_utc.return_value = utc_time instance.update_observation.side_effect = aiohttp.ClientError
instance = mock_simple_nws.return_value
# first update fails
instance.update_observation.side_effect = aiohttp.ClientError
entry = MockConfigEntry( entry = MockConfigEntry(
domain=nws.DOMAIN, domain=nws.DOMAIN,
data=NWS_CONFIG, data=NWS_CONFIG,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
instance.update_observation.assert_called_once() instance.update_observation.assert_called_once()
state = hass.states.get("weather.abc") state = hass.states.get("weather.abc")
assert state assert state
assert state.state == STATE_UNAVAILABLE assert state.state == STATE_UNAVAILABLE
async def test_new_config_entry(hass: HomeAssistant, no_sensor) -> None: async def test_new_config_entry(hass: HomeAssistant, no_sensor) -> None:

View File

@@ -6,6 +6,7 @@ from ollama import Message, ResponseError
import pytest import pytest
from homeassistant.components import conversation, ollama from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
@@ -110,6 +111,19 @@ async def test_chat(
), result ), result
assert result.response.speech["plain"]["speech"] == "test response" assert result.response.speech["plain"]["speech"] == "test response"
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_message_history_trimming( async def test_message_history_trimming(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component

View File

@@ -9,9 +9,17 @@ import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.openai_conversation.const import ( from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
DEFAULT_CHAT_MODEL, CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_P,
) )
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@@ -75,7 +83,7 @@ async def test_options(
assert options["type"] is FlowResultType.CREATE_ENTRY assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"]["prompt"] == "Speak like a pirate" assert options["data"]["prompt"] == "Speak like a pirate"
assert options["data"]["max_tokens"] == 200 assert options["data"]["max_tokens"] == 200
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": error} assert result2["errors"] == {"base": error}
@pytest.mark.parametrize(
("current_options", "new_options", "expected_options"),
[
(
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "none",
CONF_PROMPT: "bla",
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
},
),
(
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
),
],
)
async def test_options_switching(
hass: HomeAssistant,
mock_config_entry,
mock_init_component,
current_options,
new_options,
expected_options,
) -> None:
"""Test the options form."""
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
**current_options,
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
},
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
new_options,
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options

View File

@@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -200,6 +201,20 @@ async def test_function_call(
), ),
) )
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
@patch( @patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"

View File

@@ -2,6 +2,7 @@
from unittest.mock import patch from unittest.mock import patch
from freezegun import freeze_time
from pyplaato.models.airlock import PlaatoAirlock from pyplaato.models.airlock import PlaatoAirlock
from pyplaato.models.device import PlaatoDeviceType from pyplaato.models.device import PlaatoDeviceType
from pyplaato.models.keg import PlaatoKeg from pyplaato.models.keg import PlaatoKeg
@@ -23,6 +24,7 @@ AIRLOCK_DATA = {}
KEG_DATA = {} KEG_DATA = {}
@freeze_time("2024-05-24 12:00:00", tz_offset=0)
async def init_integration( async def init_integration(
hass: HomeAssistant, device_type: PlaatoDeviceType hass: HomeAssistant, device_type: PlaatoDeviceType
) -> MockConfigEntry: ) -> MockConfigEntry:

View File

@@ -492,7 +492,6 @@ async def test_block_set_mode_auth_error(
{ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT}, {ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED

View File

@@ -227,7 +227,6 @@ async def test_block_set_value_auth_error(
{ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30}, {ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED

View File

@@ -618,7 +618,6 @@ async def test_rpc_sleeping_update_entity_service(
service_data={ATTR_ENTITY_ID: entity_id}, service_data={ATTR_ENTITY_ID: entity_id},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
# Entity should be available after update_entity service call # Entity should be available after update_entity service call
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
@@ -667,7 +666,6 @@ async def test_block_sleeping_update_entity_service(
service_data={ATTR_ENTITY_ID: entity_id}, service_data={ATTR_ENTITY_ID: entity_id},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
# Entity should be available after update_entity service call # Entity should be available after update_entity service call
state = hass.states.get(entity_id) state = hass.states.get(entity_id)

View File

@@ -230,7 +230,6 @@ async def test_block_set_state_auth_error(
{ATTR_ENTITY_ID: "switch.test_name_channel_1"}, {ATTR_ENTITY_ID: "switch.test_name_channel_1"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
@@ -374,7 +373,6 @@ async def test_rpc_auth_error(
{ATTR_ENTITY_ID: "switch.test_switch_0"}, {ATTR_ENTITY_ID: "switch.test_switch_0"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED

View File

@@ -207,7 +207,6 @@ async def test_block_update_auth_error(
{ATTR_ENTITY_ID: "update.test_name_firmware_update"}, {ATTR_ENTITY_ID: "update.test_name_firmware_update"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
@@ -669,7 +668,6 @@ async def test_rpc_update_auth_error(
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress()

View File

@@ -18,9 +18,15 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@pytest.fixture @pytest.fixture
def mock_bridge(request): def mock_bridge(request):
"""Return a mocked SwitcherBridge.""" """Return a mocked SwitcherBridge."""
with patch( with (
"homeassistant.components.switcher_kis.utils.SwitcherBridge", autospec=True patch(
) as bridge_mock: "homeassistant.components.switcher_kis.SwitcherBridge", autospec=True
) as bridge_mock,
patch(
"homeassistant.components.switcher_kis.utils.SwitcherBridge",
new=bridge_mock,
),
):
bridge = bridge_mock.return_value bridge = bridge_mock.return_value
bridge.devices = [] bridge.devices = []

View File

@@ -4,11 +4,7 @@ from datetime import timedelta
import pytest import pytest
from homeassistant.components.switcher_kis.const import ( from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC
DATA_DEVICE,
DOMAIN,
MAX_UPDATE_INTERVAL_SEC,
)
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import STATE_UNAVAILABLE from homeassistant.const import STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -24,15 +20,14 @@ async def test_update_fail(
hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test entities state unavailable when updates fail..""" """Test entities state unavailable when updates fail.."""
await init_integration(hass) entry = await init_integration(hass)
assert mock_bridge assert mock_bridge
mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES) mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_bridge.is_running is True assert mock_bridge.is_running is True
assert len(hass.data[DOMAIN]) == 2 assert len(entry.runtime_data) == 2
assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2
async_fire_time_changed( async_fire_time_changed(
hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1) hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1)
@@ -77,11 +72,9 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None:
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
assert mock_bridge.is_running is True assert mock_bridge.is_running is True
assert len(hass.data[DOMAIN]) == 2
await hass.config_entries.async_unload(entry.entry_id) await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.state is ConfigEntryState.NOT_LOADED assert entry.state is ConfigEntryState.NOT_LOADED
assert mock_bridge.is_running is False assert mock_bridge.is_running is False
assert len(hass.data[DOMAIN]) == 0

Some files were not shown because too many files have changed in this diff Show More