Merge branch 'dev' into jbouwh-mqtt-device-discovery

This commit is contained in:
J. Nick Koston
2024-05-25 11:39:47 -10:00
committed by GitHub
106 changed files with 3208 additions and 1980 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
import logging import logging
from typing import Any from typing import Any
@@ -20,6 +21,11 @@ from .models import (
ConversationInput, ConversationInput,
ConversationResult, ConversationResult,
) )
from .trace import (
ConversationTraceEvent,
ConversationTraceEventType,
async_conversation_trace,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -84,15 +90,23 @@ async def async_converse(
language = hass.config.language language = hass.config.language
_LOGGER.debug("Processing in %s: %s", language, text) _LOGGER.debug("Processing in %s: %s", language, text)
return await method( conversation_input = ConversationInput(
ConversationInput(
text=text, text=text,
context=context, context=context,
conversation_id=conversation_id, conversation_id=conversation_id,
device_id=device_id, device_id=device_id,
language=language, language=language,
) )
with async_conversation_trace() as trace:
trace.add_event(
ConversationTraceEvent(
ConversationTraceEventType.ASYNC_PROCESS,
dataclasses.asdict(conversation_input),
) )
)
result = await method(conversation_input)
trace.set_result(**result.as_dict())
return result
class AgentManager: class AgentManager:

View File

@@ -0,0 +1,118 @@
"""Debug traces for conversation."""
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field
import enum
from typing import Any
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.limited_size_dict import LimitedSizeDict
STORED_TRACES = 3
class ConversationTraceEventType(enum.StrEnum):
"""Type of an event emitted during a conversation."""
ASYNC_PROCESS = "async_process"
"""The conversation is started from user input."""
AGENT_DETAIL = "agent_detail"
"""Event detail added by a conversation agent."""
LLM_TOOL_CALL = "llm_tool_call"
"""An LLM Tool call"""
@dataclass(frozen=True)
class ConversationTraceEvent:
"""Event emitted during a conversation."""
event_type: ConversationTraceEventType
data: dict[str, Any] | None = None
timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat())
class ConversationTrace:
"""Stores debug data related to a conversation."""
def __init__(self) -> None:
"""Initialize ConversationTrace."""
self._trace_id = ulid_util.ulid_now()
self._events: list[ConversationTraceEvent] = []
self._error: Exception | None = None
self._result: dict[str, Any] = {}
@property
def trace_id(self) -> str:
"""Identifier for this trace."""
return self._trace_id
def add_event(self, event: ConversationTraceEvent) -> None:
"""Add an event to the trace."""
self._events.append(event)
def set_error(self, ex: Exception) -> None:
"""Set error."""
self._error = ex
def set_result(self, **kwargs: Any) -> None:
"""Set result."""
self._result = {**kwargs}
def as_dict(self) -> dict[str, Any]:
"""Return dictionary version of this ConversationTrace."""
result: dict[str, Any] = {
"id": self._trace_id,
"events": [asdict(event) for event in self._events],
}
if self._error is not None:
result["error"] = str(self._error) or self._error.__class__.__name__
if self._result is not None:
result["result"] = self._result
return result
_current_trace: ContextVar[ConversationTrace | None] = ContextVar(
"current_trace", default=None
)
_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict(
size_limit=STORED_TRACES
)
def async_conversation_trace_append(
event_type: ConversationTraceEventType, event_data: dict[str, Any]
) -> None:
"""Append a ConversationTraceEvent to the current active trace."""
trace = _current_trace.get()
if not trace:
return
trace.add_event(ConversationTraceEvent(event_type, event_data))
@contextmanager
def async_conversation_trace() -> Generator[ConversationTrace, None]:
"""Create a new active ConversationTrace."""
trace = ConversationTrace()
token = _current_trace.set(trace)
_recent_traces[trace.trace_id] = trace
try:
yield trace
except Exception as ex:
trace.set_error(ex)
raise
finally:
_current_trace.reset(token)
def async_get_traces() -> list[ConversationTrace]:
"""Get the most recent traces."""
return list(_recent_traces.values())
def async_clear_traces() -> None:
"""Clear all traces."""
_recent_traces.clear()

View File

@@ -5,5 +5,5 @@
"documentation": "https://www.home-assistant.io/integrations/envisalink", "documentation": "https://www.home-assistant.io/integrations/envisalink",
"iot_class": "local_push", "iot_class": "local_push",
"loggers": ["pyenvisalink"], "loggers": ["pyenvisalink"],
"requirements": ["pyenvisalink==4.6"] "requirements": ["pyenvisalink==4.7"]
} }

View File

@@ -1,6 +1,5 @@
"""Support for Arduino-compatible Microcontrollers through Firmata.""" """Support for Arduino-compatible Microcontrollers through Firmata."""
import asyncio
from copy import copy from copy import copy
import logging import logging
@@ -212,16 +211,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Shutdown and close a Firmata board for a config entry.""" """Shutdown and close a Firmata board for a config entry."""
_LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME]) _LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME])
results: list[bool] = []
unload_entries = [] if platforms := [
for conf, platform in CONF_PLATFORM_MAP.items(): platform
if conf in config_entry.data: for conf, platform in CONF_PLATFORM_MAP.items()
unload_entries.append( if conf in config_entry.data
hass.config_entries.async_forward_entry_unload(config_entry, platform) ]:
results.append(
await hass.config_entries.async_unload_platforms(config_entry, platforms)
) )
results = []
if unload_entries:
results = await asyncio.gather(*unload_entries)
results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset()) results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset())
return False not in results return False not in results

View File

@@ -11,12 +11,13 @@ from .const import (
CONF_DAMPING_EVENING, CONF_DAMPING_EVENING,
CONF_DAMPING_MORNING, CONF_DAMPING_MORNING,
CONF_MODULES_POWER, CONF_MODULES_POWER,
DOMAIN,
) )
from .coordinator import ForecastSolarDataUpdateCoordinator from .coordinator import ForecastSolarDataUpdateCoordinator
PLATFORMS = [Platform.SENSOR] PLATFORMS = [Platform.SENSOR]
type ForecastSolarConfigEntry = ConfigEntry[ForecastSolarDataUpdateCoordinator]
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Migrate old config entry.""" """Migrate old config entry."""
@@ -36,12 +37,14 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(
hass: HomeAssistant, entry: ForecastSolarConfigEntry
) -> bool:
"""Set up Forecast.Solar from a config entry.""" """Set up Forecast.Solar from a config entry."""
coordinator = ForecastSolarDataUpdateCoordinator(hass, entry) coordinator = ForecastSolarDataUpdateCoordinator(hass, entry)
await coordinator.async_config_entry_first_refresh() await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
@@ -52,11 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None: async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:

View File

@@ -4,15 +4,11 @@ from __future__ import annotations
from typing import Any from typing import Any
from forecast_solar import Estimate
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from .const import DOMAIN from . import ForecastSolarConfigEntry
TO_REDACT = { TO_REDACT = {
CONF_API_KEY, CONF_API_KEY,
@@ -22,10 +18,10 @@ TO_REDACT = {
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: ForecastSolarConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
coordinator: DataUpdateCoordinator[Estimate] = hass.data[DOMAIN][entry.entry_id] coordinator = entry.runtime_data
return { return {
"entry": { "entry": {

View File

@@ -4,19 +4,21 @@ from __future__ import annotations
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import DOMAIN from .coordinator import ForecastSolarDataUpdateCoordinator
async def async_get_solar_forecast( async def async_get_solar_forecast(
hass: HomeAssistant, config_entry_id: str hass: HomeAssistant, config_entry_id: str
) -> dict[str, dict[str, float | int]] | None: ) -> dict[str, dict[str, float | int]] | None:
"""Get solar forecast for a config entry ID.""" """Get solar forecast for a config entry ID."""
if (coordinator := hass.data[DOMAIN].get(config_entry_id)) is None: if (
entry := hass.config_entries.async_get_entry(config_entry_id)
) is None or not isinstance(entry.runtime_data, ForecastSolarDataUpdateCoordinator):
return None return None
return { return {
"wh_hours": { "wh_hours": {
timestamp.isoformat(): val timestamp.isoformat(): val
for timestamp, val in coordinator.data.wh_period.items() for timestamp, val in entry.runtime_data.data.wh_period.items()
} }
} }

View File

@@ -16,7 +16,6 @@ from homeassistant.components.sensor import (
SensorEntityDescription, SensorEntityDescription,
SensorStateClass, SensorStateClass,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import UnitOfEnergy, UnitOfPower from homeassistant.const import UnitOfEnergy, UnitOfPower
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
@@ -24,6 +23,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import ForecastSolarConfigEntry
from .const import DOMAIN from .const import DOMAIN
from .coordinator import ForecastSolarDataUpdateCoordinator from .coordinator import ForecastSolarDataUpdateCoordinator
@@ -133,10 +133,12 @@ SENSORS: tuple[ForecastSolarSensorEntityDescription, ...] = (
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback hass: HomeAssistant,
entry: ForecastSolarConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Defer sensor setup to the shared sensor module.""" """Defer sensor setup to the shared sensor module."""
coordinator: ForecastSolarDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id] coordinator = entry.runtime_data
async_add_entities( async_add_entities(
ForecastSolarSensorEntity( ForecastSolarSensorEntity(

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
from typing import Final from typing import Any, Final
from homeassistant.components.button import ( from homeassistant.components.button import (
ButtonDeviceClass, ButtonDeviceClass,
@@ -30,7 +30,7 @@ _LOGGER = logging.getLogger(__name__)
class FritzButtonDescription(ButtonEntityDescription): class FritzButtonDescription(ButtonEntityDescription):
"""Class to describe a Button entity.""" """Class to describe a Button entity."""
press_action: Callable press_action: Callable[[AvmWrapper], Any]
BUTTONS: Final = [ BUTTONS: Final = [

View File

@@ -57,9 +57,6 @@ ERROR_UPNP_NOT_CONFIGURED = "upnp_not_configured"
ERROR_UNKNOWN = "unknown_error" ERROR_UNKNOWN = "unknown_error"
FRITZ_SERVICES = "fritz_services" FRITZ_SERVICES = "fritz_services"
SERVICE_REBOOT = "reboot"
SERVICE_RECONNECT = "reconnect"
SERVICE_CLEANUP = "cleanup"
SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password" SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password"
SWITCH_TYPE_DEFLECTION = "CallDeflection" SWITCH_TYPE_DEFLECTION = "CallDeflection"

View File

@@ -46,9 +46,6 @@ from .const import (
DEFAULT_USERNAME, DEFAULT_USERNAME,
DOMAIN, DOMAIN,
FRITZ_EXCEPTIONS, FRITZ_EXCEPTIONS,
SERVICE_CLEANUP,
SERVICE_REBOOT,
SERVICE_RECONNECT,
SERVICE_SET_GUEST_WIFI_PW, SERVICE_SET_GUEST_WIFI_PW,
MeshRoles, MeshRoles,
) )
@@ -730,30 +727,6 @@ class FritzBoxTools(DataUpdateCoordinator[UpdateCoordinatorDataType]):
) )
try: try:
if service_call.service == SERVICE_REBOOT:
_LOGGER.warning(
'Service "fritz.reboot" is deprecated, please use the corresponding'
" button entity instead"
)
await self.async_trigger_reboot()
return
if service_call.service == SERVICE_RECONNECT:
_LOGGER.warning(
'Service "fritz.reconnect" is deprecated, please use the'
" corresponding button entity instead"
)
await self.async_trigger_reconnect()
return
if service_call.service == SERVICE_CLEANUP:
_LOGGER.warning(
'Service "fritz.cleanup" is deprecated, please use the'
" corresponding button entity instead"
)
await self.async_trigger_cleanup(config_entry)
return
if service_call.service == SERVICE_SET_GUEST_WIFI_PW: if service_call.service == SERVICE_SET_GUEST_WIFI_PW:
await self.async_trigger_set_guest_password( await self.async_trigger_set_guest_password(
service_call.data.get("password"), service_call.data.get("password"),

View File

@@ -11,14 +11,7 @@ from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import async_extract_config_entry_ids from homeassistant.helpers.service import async_extract_config_entry_ids
from .const import ( from .const import DOMAIN, FRITZ_SERVICES, SERVICE_SET_GUEST_WIFI_PW
DOMAIN,
FRITZ_SERVICES,
SERVICE_CLEANUP,
SERVICE_REBOOT,
SERVICE_RECONNECT,
SERVICE_SET_GUEST_WIFI_PW,
)
from .coordinator import AvmWrapper from .coordinator import AvmWrapper
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -32,9 +25,6 @@ SERVICE_SCHEMA_SET_GUEST_WIFI_PW = vol.Schema(
) )
SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [ SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [
(SERVICE_CLEANUP, None),
(SERVICE_REBOOT, None),
(SERVICE_RECONNECT, None),
(SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW), (SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW),
] ]

View File

@@ -1,31 +1,3 @@
reconnect:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
reboot:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
cleanup:
fields:
device_id:
required: true
selector:
device:
integration: fritz
entity:
device_class: connectivity
set_guest_wifi_password: set_guest_wifi_password:
fields: fields:
device_id: device_id:

View File

@@ -144,42 +144,12 @@
} }
}, },
"services": { "services": {
"reconnect": {
"name": "[%key:component::fritz::entity::button::reconnect::name%]",
"description": "Reconnects your FRITZ!Box internet connection.",
"fields": {
"device_id": {
"name": "Fritz!Box Device",
"description": "Select the Fritz!Box to reconnect."
}
}
},
"reboot": {
"name": "Reboot",
"description": "Reboots your FRITZ!Box.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"description": "Select the Fritz!Box to reboot."
}
}
},
"cleanup": {
"name": "Remove stale device tracker entities",
"description": "Remove FRITZ!Box stale device_tracker entities.",
"fields": {
"device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]",
"description": "Select the Fritz!Box to check."
}
}
},
"set_guest_wifi_password": { "set_guest_wifi_password": {
"name": "Set guest Wi-Fi password", "name": "Set guest Wi-Fi password",
"description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.", "description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.",
"fields": { "fields": {
"device_id": { "device_id": {
"name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]", "name": "Fritz!Box Device",
"description": "Select the Fritz!Box to configure." "description": "Select the Fritz!Box to configure."
}, },
"password": { "password": {

View File

@@ -0,0 +1,46 @@
"""Diagnostics support for Fronius."""
from typing import Any
from homeassistant.components.diagnostics import async_redact_data
from homeassistant.core import HomeAssistant
from . import FroniusConfigEntry
TO_REDACT = {"unique_id", "unique_identifier", "serial"}
async def async_get_config_entry_diagnostics(
hass: HomeAssistant, config_entry: FroniusConfigEntry
) -> dict[str, Any]:
"""Return diagnostics for a config entry."""
diag: dict[str, Any] = {}
solar_net = config_entry.runtime_data
fronius = solar_net.fronius
diag["config_entry"] = config_entry.as_dict()
diag["inverter_info"] = await fronius.inverter_info()
diag["coordinators"] = {"inverters": {}}
for inv in solar_net.inverter_coordinators:
diag["coordinators"]["inverters"] |= inv.data
diag["coordinators"]["logger"] = (
solar_net.logger_coordinator.data if solar_net.logger_coordinator else None
)
diag["coordinators"]["meter"] = (
solar_net.meter_coordinator.data if solar_net.meter_coordinator else None
)
diag["coordinators"]["ohmpilot"] = (
solar_net.ohmpilot_coordinator.data if solar_net.ohmpilot_coordinator else None
)
diag["coordinators"]["power_flow"] = (
solar_net.power_flow_coordinator.data
if solar_net.power_flow_coordinator
else None
)
diag["coordinators"]["storage"] = (
solar_net.storage_coordinator.data if solar_net.storage_coordinator else None
)
return async_redact_data(diag, TO_REDACT)

View File

@@ -181,8 +181,7 @@ async def google_generative_ai_config_option_schema(
schema = { schema = {
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)}, description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(), ): TemplateSelector(),
vol.Optional( vol.Optional(
CONF_LLM_HASS_API, CONF_LLM_HASS_API,

View File

@@ -22,4 +22,4 @@ CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold"
CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold" CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold"
CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold" CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold"
CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold" CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold"
RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE" RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE"

View File

@@ -5,13 +5,14 @@ from __future__ import annotations
from typing import Any, Literal from typing import Any, Literal
import google.ai.generativelanguage as glm import google.ai.generativelanguage as glm
from google.api_core.exceptions import ClientError from google.api_core.exceptions import GoogleAPICallError
import google.generativeai as genai import google.generativeai as genai
import google.generativeai.types as genai_types import google.generativeai.types as genai_types
import voluptuous as vol import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -205,15 +206,6 @@ class GoogleGenerativeAIConversationEntity(
messages = [{}, {}] messages = [{}, {}]
try: try:
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api: if llm_api:
empty_tool_input = llm.ToolInput( empty_tool_input = llm.ToolInput(
tool_name="", tool_name="",
@@ -226,8 +218,23 @@ class GoogleGenerativeAIConversationEntity(
device_id=user_input.device_id, device_id=user_input.device_id,
) )
prompt = ( api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
else:
api_prompt = llm.PROMPT_NO_API_CONFIGURED
prompt = "\n".join(
(
template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
)
) )
except TemplateError as err: except TemplateError as err:
@@ -244,6 +251,9 @@ class GoogleGenerativeAIConversationEntity(
messages[1] = {"role": "model", "parts": "Ok"} messages[1] = {"role": "model", "parts": "Ok"}
LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
chat = model.start_chat(history=messages) chat = model.start_chat(history=messages)
chat_request = user_input.text chat_request = user_input.text
@@ -252,15 +262,25 @@ class GoogleGenerativeAIConversationEntity(
try: try:
chat_response = await chat.send_message_async(chat_request) chat_response = await chat.send_message_async(chat_request)
except ( except (
ClientError, GoogleAPICallError,
ValueError, ValueError,
genai_types.BlockedPromptException, genai_types.BlockedPromptException,
genai_types.StopCandidateException, genai_types.StopCandidateException,
) as err: ) as err:
LOGGER.error("Error sending message: %s", err) LOGGER.error("Error sending message: %s %s", type(err), err)
if isinstance(
err, genai_types.StopCandidateException
) and "finish_reason: SAFETY\n" in str(err):
error = "The message got blocked by your safety settings"
else:
error = (
f"Sorry, I had a problem talking to Google Generative AI: {err}"
)
intent_response.async_set_error( intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN, intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to Google Generative AI: {err}", error,
) )
return conversation.ConversationResult( return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id response=intent_response, conversation_id=conversation_id

View File

@@ -1,6 +1,6 @@
{ {
"domain": "integration", "domain": "integration",
"name": "Integration - Riemann sum integral", "name": "Integral",
"after_dependencies": ["counter"], "after_dependencies": ["counter"],
"codeowners": ["@dgomes"], "codeowners": ["@dgomes"],
"config_flow": true, "config_flow": true,

View File

@@ -1,5 +1,5 @@
{ {
"title": "Integration - Riemann sum integral sensor", "title": "Integral sensor",
"config": { "config": {
"step": { "step": {
"user": { "user": {

View File

@@ -21,7 +21,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
CONF_VALIDATOR = "validator" CONF_VALIDATOR = "validator"
CONF_SECRET = "secret" CONF_SECRET = "secret"
URL = "/api/meraki" URL = "/api/meraki"
VERSION = "2.0" ACCEPTED_VERSIONS = ["2.0", "2.1"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class MerakiView(HomeAssistantView):
if data["secret"] != self.secret: if data["secret"] != self.secret:
_LOGGER.error("Invalid Secret received from Meraki") _LOGGER.error("Invalid Secret received from Meraki")
return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY) return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY)
if data["version"] != VERSION: if data["version"] not in ACCEPTED_VERSIONS:
_LOGGER.error("Invalid API version: %s", data["version"]) _LOGGER.error("Invalid API version: %s", data["version"])
return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY) return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY)
_LOGGER.debug("Valid Secret") _LOGGER.debug("Valid Secret")

View File

@@ -6,6 +6,6 @@
"documentation": "https://www.home-assistant.io/integrations/minecraft_server", "documentation": "https://www.home-assistant.io/integrations/minecraft_server",
"iot_class": "local_polling", "iot_class": "local_polling",
"loggers": ["dnspython", "mcstatus"], "loggers": ["dnspython", "mcstatus"],
"quality_scale": "gold", "quality_scale": "platinum",
"requirements": ["mcstatus==11.1.1"] "requirements": ["mcstatus==11.1.1"]
} }

View File

@@ -39,6 +39,7 @@ from .client import ( # noqa: F401
MQTT, MQTT,
async_publish, async_publish,
async_subscribe, async_subscribe,
async_subscribe_internal,
publish, publish,
subscribe, subscribe,
) )
@@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
def collect_msg(msg: ReceiveMessage) -> None: def collect_msg(msg: ReceiveMessage) -> None:
messages.append((msg.topic, str(msg.payload).replace("\n", ""))) messages.append((msg.topic, str(msg.payload).replace("\n", "")))
unsub = await async_subscribe(hass, call.data["topic"], collect_msg) unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg)
def write_dump() -> None: def write_dump() -> None:
with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp: with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp:
@@ -459,7 +460,7 @@ async def websocket_subscribe(
# Perform UTF-8 decoding directly in callback routine # Perform UTF-8 decoding directly in callback routine
qos: int = msg.get("qos", DEFAULT_QOS) qos: int = msg.get("qos", DEFAULT_QOS)
connection.subscriptions[msg["id"]] = await async_subscribe( connection.subscriptions[msg["id"]] = async_subscribe_internal(
hass, msg["topic"], forward_messages, encoding=None, qos=qos hass, msg["topic"], forward_messages, encoding=None, qos=qos
) )
@@ -522,24 +523,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
mqtt_client = mqtt_data.client mqtt_client = mqtt_data.client
# Unload publish and dump services. # Unload publish and dump services.
hass.services.async_remove( hass.services.async_remove(DOMAIN, SERVICE_PUBLISH)
DOMAIN, hass.services.async_remove(DOMAIN, SERVICE_DUMP)
SERVICE_PUBLISH,
)
hass.services.async_remove(
DOMAIN,
SERVICE_DUMP,
)
# Stop the discovery # Stop the discovery
await discovery.async_stop(hass) await discovery.async_stop(hass)
# Unload the platforms # Unload the platforms
await asyncio.gather( await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded)
*(
hass.config_entries.async_forward_entry_unload(entry, component)
for component in mqtt_data.platforms_loaded
)
)
mqtt_data.platforms_loaded = set() mqtt_data.platforms_loaded = set()
await asyncio.sleep(0) await asyncio.sleep(0)
# Unsubscribe reload dispatchers # Unsubscribe reload dispatchers

View File

@@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_alarm_disarm(self, code: str | None = None) -> None: async def async_alarm_disarm(self, code: str | None = None) -> None:
"""Send disarm command. """Send disarm command.

View File

@@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_: Any) -> None: def _value_is_expired(self, *_: Any) -> None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from base64 import b64decode from base64 import b64decode
from functools import partial
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_QOS, CONF_TOPIC from .const import CONF_QOS, CONF_TOPIC
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ReceiveMessage from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@@ -97,12 +97,8 @@ class MqttCamera(MqttEntity, Camera):
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _image_received(self, msg: ReceiveMessage) -> None:
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload) self._last_image = b64decode(msg.payload)
@@ -111,13 +107,21 @@ class MqttCamera(MqttEntity, Camera):
assert isinstance(msg.payload, bytes) assert isinstance(msg.payload, bytes)
self._last_image = msg.payload self._last_image = msg.payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config[CONF_TOPIC], "topic": self._config[CONF_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._image_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": None, "encoding": None,
} }
@@ -126,7 +130,7 @@ class MqttCamera(MqttEntity, Camera):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_camera_image( async def async_camera_image(
self, width: int | None = None, height: int | None = None self, width: int | None = None, height: int | None = None

View File

@@ -77,7 +77,6 @@ from .const import (
) )
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
AsyncMessageCallbackType,
MessageCallbackType, MessageCallbackType,
MqttData, MqttData,
PublishMessage, PublishMessage,
@@ -184,7 +183,7 @@ async def async_publish(
async def async_subscribe( async def async_subscribe(
hass: HomeAssistant, hass: HomeAssistant,
topic: str, topic: str,
msg_callback: AsyncMessageCallbackType | MessageCallbackType, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS, qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING, encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE: ) -> CALLBACK_TYPE:
@@ -192,13 +191,25 @@ async def async_subscribe(
Call the return value to unsubscribe. Call the return value to unsubscribe.
""" """
if not mqtt_config_entry_enabled(hass): return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe", @callback
translation_domain=DOMAIN, def async_subscribe_internal(
translation_placeholders={"topic": topic}, hass: HomeAssistant,
) topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
This function is internal to the MQTT integration
and may change at any time. It should not be considered
a stable API.
Call the return value to unsubscribe.
"""
try: try:
mqtt_data = hass.data[DATA_MQTT] mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc: except KeyError as exc:
@@ -209,12 +220,15 @@ async def async_subscribe(
translation_domain=DOMAIN, translation_domain=DOMAIN,
translation_placeholders={"topic": topic}, translation_placeholders={"topic": topic},
) from exc ) from exc
return await mqtt_data.client.async_subscribe( client = mqtt_data.client
topic, if not client.connected and not mqtt_config_entry_enabled(hass):
msg_callback, raise HomeAssistantError(
qos, f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
encoding, translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
) )
return client.async_subscribe(topic, msg_callback, qos, encoding)
@bind_hass @bind_hass
@@ -429,10 +443,10 @@ class MQTT:
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict( self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict(
list set
) )
self._wildcard_subscriptions: list[Subscription] = [] self._wildcard_subscriptions: set[Subscription] = set()
# _retained_topics prevents a Subscription from receiving a # _retained_topics prevents a Subscription from receiving a
# retained message more than once per topic. This prevents flooding # retained message more than once per topic. This prevents flooding
# already active subscribers when new subscribers subscribe to a topic # already active subscribers when new subscribers subscribe to a topic
@@ -452,7 +466,7 @@ class MQTT:
self._should_reconnect: bool = True self._should_reconnect: bool = True
self._available_future: asyncio.Future[bool] | None = None self._available_future: asyncio.Future[bool] | None = None
self._max_qos: dict[str, int] = {} # topic, max qos self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._unsubscribe_debouncer = EnsureJobAfterCooldown( self._unsubscribe_debouncer = EnsureJobAfterCooldown(
UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes
@@ -789,9 +803,9 @@ class MQTT:
The caller is responsible clearing the cache of _matching_subscriptions. The caller is responsible clearing the cache of _matching_subscriptions.
""" """
if subscription.is_simple_match: if subscription.is_simple_match:
self._simple_subscriptions[subscription.topic].append(subscription) self._simple_subscriptions[subscription.topic].add(subscription)
else: else:
self._wildcard_subscriptions.append(subscription) self._wildcard_subscriptions.add(subscription)
@callback @callback
def _async_untrack_subscription(self, subscription: Subscription) -> None: def _async_untrack_subscription(self, subscription: Subscription) -> None:
@@ -820,8 +834,8 @@ class MQTT:
"""Queue requested subscriptions.""" """Queue requested subscriptions."""
for subscription in subscriptions: for subscription in subscriptions:
topic, qos = subscription topic, qos = subscription
max_qos = max(qos, self._max_qos.setdefault(topic, qos)) if (max_qos := self._max_qos[topic]) < qos:
self._max_qos[topic] = max_qos self._max_qos[topic] = (max_qos := qos)
self._pending_subscriptions[topic] = max_qos self._pending_subscriptions[topic] = max_qos
# Cancel any pending unsubscribe since we are subscribing now # Cancel any pending unsubscribe since we are subscribing now
if topic in self._pending_unsubscribes: if topic in self._pending_unsubscribes:
@@ -832,26 +846,29 @@ class MQTT:
def _exception_message( def _exception_message(
self, self,
msg_callback: AsyncMessageCallbackType | MessageCallbackType, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
msg: ReceiveMessage, msg: ReceiveMessage,
) -> str: ) -> str:
"""Return a string with the exception message.""" """Return a string with the exception message."""
# if msg_callback is a partial we return the name of the first argument
if isinstance(msg_callback, partial):
call_back_name = getattr(msg_callback.args[0], "__name__") # type: ignore[unreachable]
else:
call_back_name = getattr(msg_callback, "__name__")
return ( return (
f"Exception in {msg_callback.__name__} when handling msg on " f"Exception in {call_back_name} when handling msg on "
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe] f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
) )
async def async_subscribe( @callback
def async_subscribe(
self, self,
topic: str, topic: str,
msg_callback: AsyncMessageCallbackType | MessageCallbackType, msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int, qos: int,
encoding: str | None = None, encoding: str | None = None,
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Set up a subscription to a topic with the provided qos. """Set up a subscription to a topic with the provided qos."""
This method is a coroutine.
"""
if not isinstance(topic, str): if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!") raise HomeAssistantError("Topic needs to be a string!")
@@ -877,8 +894,10 @@ class MQTT:
if self.connected: if self.connected:
self._async_queue_subscriptions(((topic, qos),)) self._async_queue_subscriptions(((topic, qos),))
return partial(self._async_remove, subscription)
@callback @callback
def async_remove() -> None: def _async_remove(self, subscription: Subscription) -> None:
"""Remove subscription.""" """Remove subscription."""
self._async_untrack_subscription(subscription) self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear() self._matching_subscriptions.cache_clear()
@@ -886,9 +905,7 @@ class MQTT:
del self._retained_topics[subscription] del self._retained_topics[subscription]
# Only unsubscribe if currently connected # Only unsubscribe if currently connected
if self.connected: if self.connected:
self._async_unsubscribe(topic) self._async_unsubscribe(subscription.topic)
return async_remove
@callback @callback
def _async_unsubscribe(self, topic: str) -> None: def _async_unsubscribe(self, topic: str) -> None:
@@ -1257,9 +1274,7 @@ class MQTT:
last_discovery = self._mqtt_data.last_discovery last_discovery = self._mqtt_data.last_discovery
last_subscribe = now if self._pending_subscriptions else self._last_subscribe last_subscribe = now if self._pending_subscriptions else self._last_subscribe
wait_until = max( wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
while now < wait_until: while now < wait_until:
await asyncio.sleep(wait_until - now) await asyncio.sleep(wait_until - now)
now = time.monotonic() now = time.monotonic()
@@ -1267,9 +1282,7 @@ class MQTT:
last_subscribe = ( last_subscribe = (
now if self._pending_subscriptions else self._last_subscribe now if self._pending_subscriptions else self._last_subscribe
) )
wait_until = max( wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN
last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN
)
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]: def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:

View File

@@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _publish(self, topic: str, payload: PublishPayloadType) -> None: async def _publish(self, topic: str, payload: PublishPayloadType) -> None:
if self._topic[topic] is not None: if self._topic[topic] is not None:

View File

@@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_cover(self, **kwargs: Any) -> None: async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the cover up. """Move the cover up.

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -32,13 +33,7 @@ from homeassistant.helpers.typing import ConfigType
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC
from .debug_info import log_messages from .mixins import CONF_JSON_ATTRS_TOPIC, MqttEntity, async_setup_entity_entry_helper
from .mixins import (
CONF_JSON_ATTRS_TOPIC,
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_subscribe_topic from .util import valid_subscribe_topic
@@ -119,13 +114,8 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
config.get(CONF_VALUE_TEMPLATE), entity=self config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _tracker_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_location_name"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore if not payload.strip(): # No output from template, ignore
@@ -146,6 +136,9 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
assert isinstance(msg.payload, str) assert isinstance(msg.payload, str)
self._location_name = msg.payload self._location_name = msg.payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
state_topic: str | None = self._config.get(CONF_STATE_TOPIC) state_topic: str | None = self._config.get(CONF_STATE_TOPIC)
if state_topic is None: if state_topic is None:
return return
@@ -155,7 +148,12 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
{ {
"state_topic": { "state_topic": {
"topic": state_topic, "topic": state_topic,
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._tracker_message_received,
{"_location_name"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
} }
}, },
@@ -168,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property @property
def latitude(self) -> float | None: def latitude(self) -> float | None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -31,7 +32,6 @@ from .const import (
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@@ -113,13 +113,8 @@ class MqttEvent(MqttEntity, EventEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
@callback @callback
@log_messages(self.hass, self.entity_id) def _event_received(self, msg: ReceiveMessage) -> None:
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
if msg.retain: if msg.retain:
_LOGGER.debug( _LOGGER.debug(
@@ -161,10 +156,7 @@ class MqttEvent(MqttEntity, EventEntity):
) )
except KeyError: except KeyError:
_LOGGER.warning( _LOGGER.warning(
( ("`event_type` missing in JSON event payload, " " '%s' on topic %s"),
"`event_type` missing in JSON event payload, "
" '%s' on topic %s"
),
payload, payload,
msg.topic, msg.topic,
) )
@@ -194,9 +186,18 @@ class MqttEvent(MqttEntity, EventEntity):
mqtt_data = self.hass.data[DATA_MQTT] mqtt_data = self.hass.data[DATA_MQTT]
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config[CONF_STATE_TOPIC], "topic": self._config[CONF_STATE_TOPIC],
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._event_received,
None,
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -207,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import math import math
from typing import Any from typing import Any
@@ -49,12 +50,7 @@ from .const import (
CONF_STATE_VALUE_TEMPLATE, CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@@ -338,25 +334,8 @@ class MqttFan(MqttEntity, FanEntity):
for key, tpl in value_templates.items() for key, tpl in value_templates.items()
} }
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
"""Add a topic to subscribe to."""
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
return has_topic
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message.""" """Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload) payload = self._value_templates[CONF_STATE](msg.payload)
if not payload: if not payload:
@@ -369,12 +348,8 @@ class MqttFan(MqttEntity, FanEntity):
elif payload == PAYLOAD_NONE: elif payload == PAYLOAD_NONE:
self._attr_is_on = None self._attr_is_on = None
add_subscribe_topic(CONF_STATE_TOPIC, state_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _percentage_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_percentage"})
def percentage_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the percentage.""" """Handle new received MQTT message for the percentage."""
rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE]( rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE](
msg.payload msg.payload
@@ -413,12 +388,8 @@ class MqttFan(MqttEntity, FanEntity):
return return
self._attr_percentage = percentage self._attr_percentage = percentage
add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _preset_mode_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_preset_mode"})
def preset_mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for preset mode.""" """Handle new received MQTT message for preset mode."""
preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload)) preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload))
if preset_mode == self._payload["PRESET_MODE_RESET"]: if preset_mode == self._payload["PRESET_MODE_RESET"]:
@@ -438,12 +409,8 @@ class MqttFan(MqttEntity, FanEntity):
self._attr_preset_mode = preset_mode self._attr_preset_mode = preset_mode
add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _oscillation_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_oscillating"})
def oscillation_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the oscillation.""" """Handle new received MQTT message for the oscillation."""
payload = self._value_templates[ATTR_OSCILLATING](msg.payload) payload = self._value_templates[ATTR_OSCILLATING](msg.payload)
if not payload: if not payload:
@@ -454,13 +421,8 @@ class MqttFan(MqttEntity, FanEntity):
elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]: elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]:
self._attr_oscillating = False self._attr_oscillating = False
if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received):
self._attr_oscillating = False
@callback @callback
@log_messages(self.hass, self.entity_id) def _direction_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_current_direction"})
def direction_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the direction.""" """Handle new received MQTT message for the direction."""
direction = self._value_templates[ATTR_DIRECTION](msg.payload) direction = self._value_templates[ATTR_DIRECTION](msg.payload)
if not direction: if not direction:
@@ -468,7 +430,46 @@ class MqttFan(MqttEntity, FanEntity):
return return
self._attr_current_direction = str(direction) self._attr_current_direction = str(direction)
add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received) def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
) -> bool:
"""Add a topic to subscribe to."""
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
return has_topic
add_subscribe_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
add_subscribe_topic(
CONF_PERCENTAGE_STATE_TOPIC, self._percentage_received, {"_attr_percentage"}
)
add_subscribe_topic(
CONF_PRESET_MODE_STATE_TOPIC,
self._preset_mode_received,
{"_attr_preset_mode"},
)
if add_subscribe_topic(
CONF_OSCILLATION_STATE_TOPIC,
self._oscillation_received,
{"_attr_oscillating"},
):
self._attr_oscillating = False
add_subscribe_topic(
CONF_DIRECTION_STATE_TOPIC,
self._direction_received,
{"_attr_current_direction"},
)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -476,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property @property
def is_on(self) -> bool | None: def is_on(self) -> bool | None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -51,12 +52,7 @@ from .const import (
CONF_STATE_VALUE_TEMPLATE, CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -284,25 +280,23 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
topics: dict[str, dict[str, Any]], topics: dict[str, dict[str, Any]],
topic: str, topic: str,
msg_callback: Callable[[ReceiveMessage], None], msg_callback: Callable[[ReceiveMessage], None],
tracked_attributes: set[str],
) -> None: ) -> None:
"""Add a subscription.""" """Add a subscription."""
qos: int = self._config[CONF_QOS] qos: int = self._config[CONF_QOS]
if topic in self._topic and self._topic[topic] is not None: if topic in self._topic and self._topic[topic] is not None:
topics[topic] = { topics[topic] = {
"topic": self._topic[topic], "topic": self._topic[topic],
"msg_callback": msg_callback, "msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": qos, "qos": qos,
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message.""" """Handle new received MQTT message."""
payload = self._value_templates[CONF_STATE](msg.payload) payload = self._value_templates[CONF_STATE](msg.payload)
if not payload: if not payload:
@@ -315,12 +309,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
elif payload == PAYLOAD_NONE: elif payload == PAYLOAD_NONE:
self._attr_is_on = None self._attr_is_on = None
self.add_subscription(topics, CONF_STATE_TOPIC, state_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _action_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_action"})
def action_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message.""" """Handle new received MQTT message."""
action_payload = self._value_templates[ATTR_ACTION](msg.payload) action_payload = self._value_templates[ATTR_ACTION](msg.payload)
if not action_payload or action_payload == PAYLOAD_NONE: if not action_payload or action_payload == PAYLOAD_NONE:
@@ -337,12 +327,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
) )
return return
self.add_subscription(topics, CONF_ACTION_TOPIC, action_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _current_humidity_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_current_humidity"})
def current_humidity_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the current humidity.""" """Handle new received MQTT message for the current humidity."""
rendered_current_humidity_payload = self._value_templates[ rendered_current_humidity_payload = self._value_templates[
ATTR_CURRENT_HUMIDITY ATTR_CURRENT_HUMIDITY
@@ -373,14 +359,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
return return
self._attr_current_humidity = current_humidity self._attr_current_humidity = current_humidity
self.add_subscription(
topics, CONF_CURRENT_HUMIDITY_TOPIC, current_humidity_received
)
@callback @callback
@log_messages(self.hass, self.entity_id) def _target_humidity_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_target_humidity"})
def target_humidity_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for the target humidity.""" """Handle new received MQTT message for the target humidity."""
rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY]( rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY](
msg.payload msg.payload
@@ -414,14 +394,8 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
return return
self._attr_target_humidity = target_humidity self._attr_target_humidity = target_humidity
self.add_subscription(
topics, CONF_TARGET_HUMIDITY_STATE_TOPIC, target_humidity_received
)
@callback @callback
@log_messages(self.hass, self.entity_id) def _mode_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_mode"})
def mode_received(msg: ReceiveMessage) -> None:
"""Handle new received MQTT message for mode.""" """Handle new received MQTT message for mode."""
mode = str(self._value_templates[ATTR_MODE](msg.payload)) mode = str(self._value_templates[ATTR_MODE](msg.payload))
if mode == self._payload["MODE_RESET"]: if mode == self._payload["MODE_RESET"]:
@@ -441,7 +415,31 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
self._attr_mode = mode self._attr_mode = mode
self.add_subscription(topics, CONF_MODE_STATE_TOPIC, mode_received) def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
self.add_subscription(
topics, CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}
)
self.add_subscription(
topics, CONF_ACTION_TOPIC, self._action_received, {"_attr_action"}
)
self.add_subscription(
topics,
CONF_CURRENT_HUMIDITY_TOPIC,
self._current_humidity_received,
{"_attr_current_humidity"},
)
self.add_subscription(
topics,
CONF_TARGET_HUMIDITY_STATE_TOPIC,
self._target_humidity_received,
{"_attr_target_humidity"},
)
self.add_subscription(
topics, CONF_MODE_STATE_TOPIC, self._mode_received, {"_attr_mode"}
)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -449,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on the entity. """Turn on the entity.

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from base64 import b64decode from base64 import b64decode
import binascii import binascii
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util
from . import subscription from . import subscription
from .config import MQTT_BASE_SCHEMA from .config import MQTT_BASE_SCHEMA
from .const import CONF_ENCODING, CONF_QOS from .const import CONF_ENCODING, CONF_QOS
from .debug_info import log_messages
from .mixins import MqttEntity, async_setup_entity_entry_helper from .mixins import MqttEntity, async_setup_entity_entry_helper
from .models import ( from .models import (
DATA_MQTT, DATA_MQTT,
@@ -143,31 +143,8 @@ class MqttImage(MqttEntity, ImageEntity):
config.get(CONF_URL_TEMPLATE), entity=self config.get(CONF_URL_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
"""Add a topic to subscribe to."""
encoding: str | None
encoding = (
None
if CONF_IMAGE_TOPIC in self._config
else self._config[CONF_ENCODING] or None
)
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"qos": self._config[CONF_QOS],
"encoding": encoding,
}
return has_topic
@callback @callback
@log_messages(self.hass, self.entity_id) def _image_data_received(self, msg: ReceiveMessage) -> None:
def image_data_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
try: try:
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
@@ -186,11 +163,8 @@ class MqttImage(MqttEntity, ImageEntity):
self._attr_image_last_updated = dt_util.utcnow() self._attr_image_last_updated = dt_util.utcnow()
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _image_from_url_request_received(self, msg: ReceiveMessage) -> None:
def image_from_url_request_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
try: try:
url = cv.url(self._url_template(msg.payload)) url = cv.url(self._url_template(msg.payload))
@@ -208,7 +182,31 @@ class MqttImage(MqttEntity, ImageEntity):
self._cached_image = None self._cached_image = None
self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received) def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool:
"""Add a topic to subscribe to."""
encoding: str | None
encoding = (
None
if CONF_IMAGE_TOPIC in self._config
else self._config[CONF_ENCODING] or None
)
if has_topic := self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": partial(self._message_callback, msg_callback, None),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": encoding,
}
return has_topic
add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received)
add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -216,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_image(self) -> bytes | None: async def async_image(self) -> bytes | None:
"""Return bytes of image.""" """Return bytes of image."""

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import contextlib import contextlib
from functools import partial
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -31,12 +32,7 @@ from .const import (
DEFAULT_OPTIMISTIC, DEFAULT_OPTIMISTIC,
DEFAULT_RETAIN, DEFAULT_RETAIN,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -150,13 +146,8 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self
).async_render ).async_render
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_activity"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
payload = str(self._value_template(msg.payload)) payload = str(self._value_template(msg.payload))
if not payload: if not payload:
@@ -181,17 +172,24 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
) )
return return
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None: if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._attr_assumed_state = True self._attr_assumed_state = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
CONF_ACTIVITY_STATE_TOPIC: { CONF_ACTIVITY_STATE_TOPIC: {
"topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC), "topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._message_received,
{"_attr_activity"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -200,7 +198,7 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and ( if self._attr_assumed_state and (
last_state := await self.async_get_last_state() last_state := await self.async_get_last_state()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@@ -53,8 +54,7 @@ from ..const import (
CONF_STATE_VALUE_TEMPLATE, CONF_STATE_VALUE_TEMPLATE,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from ..debug_info import log_messages from ..mixins import MqttEntity
from ..mixins import MqttEntity, write_state_on_attr_change
from ..models import ( from ..models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@@ -378,24 +378,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
attr: bool = getattr(self, f"_optimistic_{attribute}") attr: bool = getattr(self, f"_optimistic_{attribute}")
return attr return attr
def _prepare_subscribe_topics(self) -> None: # noqa: C901
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
def add_topic(topic: str, msg_callback: MessageCallbackType) -> None:
"""Add a topic."""
if self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": msg_callback,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE]( payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.NONE msg.payload, PayloadSentinel.NONE
@@ -411,18 +395,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
elif payload == PAYLOAD_NONE: elif payload == PAYLOAD_NONE:
self._attr_is_on = None self._attr_is_on = None
if self._topic[CONF_STATE_TOPIC] is not None:
topics[CONF_STATE_TOPIC] = {
"topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": state_received,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback @callback
@log_messages(self.hass, self.entity_id) def _brightness_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_brightness"})
def brightness_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for the brightness.""" """Handle new MQTT messages for the brightness."""
payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE]( payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT msg.payload, PayloadSentinel.DEFAULT
@@ -439,23 +413,18 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE] percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE]
self._attr_brightness = min(round(percent_bright * 255), 255) self._attr_brightness = min(round(percent_bright * 255), 255)
add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received)
@callback @callback
def _rgbx_received( def _rgbx_received(
self,
msg: ReceiveMessage, msg: ReceiveMessage,
template: str, template: str,
color_mode: ColorMode, color_mode: ColorMode,
convert_color: Callable[..., tuple[int, ...]], convert_color: Callable[..., tuple[int, ...]],
) -> tuple[int, ...] | None: ) -> tuple[int, ...] | None:
"""Handle new MQTT messages for RGBW and RGBWW.""" """Process MQTT messages for RGBW and RGBWW."""
payload = self._value_templates[template]( payload = self._value_templates[template](msg.payload, PayloadSentinel.DEFAULT)
msg.payload, PayloadSentinel.DEFAULT
)
if payload is PayloadSentinel.DEFAULT or not payload: if payload is PayloadSentinel.DEFAULT or not payload:
_LOGGER.debug( _LOGGER.debug("Ignoring empty %s message from '%s'", color_mode, msg.topic)
"Ignoring empty %s message from '%s'", color_mode, msg.topic
)
return None return None
color = tuple(int(val) for val in str(payload).split(",")) color = tuple(int(val) for val in str(payload).split(","))
if self._optimistic_color_mode: if self._optimistic_color_mode:
@@ -478,29 +447,19 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
return color return color
@callback @callback
@log_messages(self.hass, self.entity_id) def _rgb_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}
)
def rgb_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGB.""" """Handle new MQTT messages for RGB."""
rgb = _rgbx_received( rgb = self._rgbx_received(
msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x
) )
if rgb is None: if rgb is None:
return return
self._attr_rgb_color = cast(tuple[int, int, int], rgb) self._attr_rgb_color = cast(tuple[int, int, int], rgb)
add_topic(CONF_RGB_STATE_TOPIC, rgb_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _rgbw_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}
)
def rgbw_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBW.""" """Handle new MQTT messages for RGBW."""
rgbw = _rgbx_received( rgbw = self._rgbx_received(
msg, msg,
CONF_RGBW_VALUE_TEMPLATE, CONF_RGBW_VALUE_TEMPLATE,
ColorMode.RGBW, ColorMode.RGBW,
@@ -510,31 +469,21 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
return return
self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw) self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw)
add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _rgbww_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"}
)
def rgbww_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for RGBWW.""" """Handle new MQTT messages for RGBWW."""
@callback @callback
def _converter( def _converter(
r: int, g: int, b: int, cw: int, ww: int r: int, g: int, b: int, cw: int, ww: int
) -> tuple[int, int, int]: ) -> tuple[int, int, int]:
min_kelvin = color_util.color_temperature_mired_to_kelvin( min_kelvin = color_util.color_temperature_mired_to_kelvin(self.max_mireds)
self.max_mireds max_kelvin = color_util.color_temperature_mired_to_kelvin(self.min_mireds)
)
max_kelvin = color_util.color_temperature_mired_to_kelvin(
self.min_mireds
)
return color_util.color_rgbww_to_rgb( return color_util.color_rgbww_to_rgb(
r, g, b, cw, ww, min_kelvin, max_kelvin r, g, b, cw, ww, min_kelvin, max_kelvin
) )
rgbww = _rgbx_received( rgbww = self._rgbx_received(
msg, msg,
CONF_RGBWW_VALUE_TEMPLATE, CONF_RGBWW_VALUE_TEMPLATE,
ColorMode.RGBWW, ColorMode.RGBWW,
@@ -544,12 +493,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
return return
self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww) self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww)
add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _color_mode_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_color_mode"})
def color_mode_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color mode.""" """Handle new MQTT messages for color mode."""
payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE]( payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT msg.payload, PayloadSentinel.DEFAULT
@@ -560,12 +505,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self._attr_color_mode = ColorMode(str(payload)) self._attr_color_mode = ColorMode(str(payload))
add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _color_temp_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"})
def color_temp_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for color temperature.""" """Handle new MQTT messages for color temperature."""
payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE]( payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT msg.payload, PayloadSentinel.DEFAULT
@@ -578,12 +519,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self._attr_color_mode = ColorMode.COLOR_TEMP self._attr_color_mode = ColorMode.COLOR_TEMP
self._attr_color_temp = int(payload) self._attr_color_temp = int(payload)
add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _effect_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_effect"})
def effect_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for effect.""" """Handle new MQTT messages for effect."""
payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE]( payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT msg.payload, PayloadSentinel.DEFAULT
@@ -594,12 +531,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self._attr_effect = str(payload) self._attr_effect = str(payload)
add_topic(CONF_EFFECT_STATE_TOPIC, effect_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _hs_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"})
def hs_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for hs color.""" """Handle new MQTT messages for hs color."""
payload = self._value_templates[CONF_HS_VALUE_TEMPLATE]( payload = self._value_templates[CONF_HS_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT msg.payload, PayloadSentinel.DEFAULT
@@ -615,12 +548,8 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
except ValueError: except ValueError:
_LOGGER.warning("Failed to parse hs state update: '%s'", payload) _LOGGER.warning("Failed to parse hs state update: '%s'", payload)
add_topic(CONF_HS_STATE_TOPIC, hs_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _xy_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"})
def xy_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages for xy color.""" """Handle new MQTT messages for xy color."""
payload = self._value_templates[CONF_XY_VALUE_TEMPLATE]( payload = self._value_templates[CONF_XY_VALUE_TEMPLATE](
msg.payload, PayloadSentinel.DEFAULT msg.payload, PayloadSentinel.DEFAULT
@@ -634,7 +563,63 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
self._attr_color_mode = ColorMode.XY self._attr_color_mode = ColorMode.XY
self._attr_xy_color = cast(tuple[float, float], xy_color) self._attr_xy_color = cast(tuple[float, float], xy_color)
add_topic(CONF_XY_STATE_TOPIC, xy_received) def _prepare_subscribe_topics(self) -> None: # noqa: C901
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
def add_topic(
topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str]
) -> None:
"""Add a topic."""
if self._topic[topic] is not None:
topics[topic] = {
"topic": self._topic[topic],
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
add_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"})
add_topic(
CONF_BRIGHTNESS_STATE_TOPIC, self._brightness_received, {"_attr_brightness"}
)
add_topic(
CONF_RGB_STATE_TOPIC,
self._rgb_received,
{"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"},
)
add_topic(
CONF_RGBW_STATE_TOPIC,
self._rgbw_received,
{"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"},
)
add_topic(
CONF_RGBWW_STATE_TOPIC,
self._rgbww_received,
{"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"},
)
add_topic(
CONF_COLOR_MODE_STATE_TOPIC, self._color_mode_received, {"_attr_color_mode"}
)
add_topic(
CONF_COLOR_TEMP_STATE_TOPIC,
self._color_temp_received,
{"_attr_color_mode", "_attr_color_temp"},
)
add_topic(CONF_EFFECT_STATE_TOPIC, self._effect_received, {"_attr_effect"})
add_topic(
CONF_HS_STATE_TOPIC,
self._hs_received,
{"_attr_color_mode", "_attr_hs_color"},
)
add_topic(
CONF_XY_STATE_TOPIC,
self._xy_received,
{"_attr_color_mode", "_attr_xy_color"},
)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -642,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
def restore_state( def restore_state(

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
@@ -66,8 +67,7 @@ from ..const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
DOMAIN as MQTT_DOMAIN, DOMAIN as MQTT_DOMAIN,
) )
from ..debug_info import log_messages from ..mixins import MqttEntity
from ..mixins import MqttEntity, write_state_on_attr_change
from ..models import ReceiveMessage from ..models import ReceiveMessage
from ..schemas import MQTT_ENTITY_COMMON_SCHEMA from ..schemas import MQTT_ENTITY_COMMON_SCHEMA
from ..util import valid_subscribe_topic from ..util import valid_subscribe_topic
@@ -414,27 +414,8 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
self.entity_id, self.entity_id,
) )
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self,
{
"_attr_brightness",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
"_attr_rgb_color",
"_attr_rgbw_color",
"_attr_rgbww_color",
"_attr_xy_color",
"color_mode",
},
)
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
values = json_loads_object(msg.payload) values = json_loads_object(msg.payload)
@@ -509,14 +490,36 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
with suppress(KeyError): with suppress(KeyError):
self._attr_effect = cast(str, values["effect"]) self._attr_effect = cast(str, values["effect"])
if self._topic[CONF_STATE_TOPIC] is not None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
#
if self._topic[CONF_STATE_TOPIC] is None:
return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { CONF_STATE_TOPIC: {
"topic": self._topic[CONF_STATE_TOPIC], "topic": self._topic[CONF_STATE_TOPIC],
"msg_callback": state_received, "msg_callback": partial(
self._message_callback,
self._state_received,
{
"_attr_brightness",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
"_attr_rgb_color",
"_attr_rgbw_color",
"_attr_rgbww_color",
"_attr_xy_color",
"color_mode",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -525,7 +528,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
if self._optimistic and last_state: if self._optimistic and last_state:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -44,8 +45,7 @@ from ..const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from ..debug_info import log_messages from ..mixins import MqttEntity
from ..mixins import MqttEntity, write_state_on_attr_change
from ..models import ( from ..models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -188,23 +188,8 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
# Support for ct + hs, prioritize hs # Support for ct + hs, prioritize hs
self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self,
{
"_attr_brightness",
"_attr_color_mode",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
},
)
def state_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload) state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload)
if state == STATE_ON: if state == STATE_ON:
@@ -229,9 +214,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
) )
except ValueError: except ValueError:
_LOGGER.warning( _LOGGER.warning("Invalid brightness value received from %s", msg.topic)
"Invalid brightness value received from %s", msg.topic
)
if CONF_COLOR_TEMP_TEMPLATE in self._config: if CONF_COLOR_TEMP_TEMPLATE in self._config:
try: try:
@@ -272,14 +255,31 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
else: else:
_LOGGER.warning("Unsupported effect value received") _LOGGER.warning("Unsupported effect value received")
if self._topics[CONF_STATE_TOPIC] is not None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._topics[CONF_STATE_TOPIC] is None:
return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._topics[CONF_STATE_TOPIC], "topic": self._topics[CONF_STATE_TOPIC],
"msg_callback": state_received, "msg_callback": partial(
self._message_callback,
self._state_received,
{
"_attr_brightness",
"_attr_color_mode",
"_attr_color_temp",
"_attr_effect",
"_attr_hs_color",
"_attr_is_on",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -288,7 +288,7 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
if self._optimistic and last_state: if self._optimistic and last_state:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import re import re
from typing import Any from typing import Any
@@ -36,12 +37,7 @@ from .const import (
CONF_STATE_OPENING, CONF_STATE_OPENING,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -186,27 +182,8 @@ class MqttLock(MqttEntity, LockEntity):
self._valid_states = [config[state] for state in STATE_CONFIG_KEYS] self._valid_states = [config[state] for state in STATE_CONFIG_KEYS]
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]] = {}
qos: int = self._config[CONF_QOS]
encoding: str | None = self._config[CONF_ENCODING] or None
@callback @callback
@log_messages(self.hass, self.entity_id) def _message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self,
{
"_attr_is_jammed",
"_attr_is_locked",
"_attr_is_locking",
"_attr_is_open",
"_attr_is_opening",
"_attr_is_unlocking",
},
)
def message_received(msg: ReceiveMessage) -> None:
"""Handle new lock state messages.""" """Handle new lock state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if not payload.strip(): # No output from template, ignore if not payload.strip(): # No output from template, ignore
@@ -227,16 +204,36 @@ class MqttLock(MqttEntity, LockEntity):
self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING] self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING]
self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED] self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED]
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, dict[str, Any]]
qos: int = self._config[CONF_QOS]
encoding: str | None = self._config[CONF_ENCODING] or None
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: return
topics[CONF_STATE_TOPIC] = { topics = {
CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._message_received,
{
"_attr_is_jammed",
"_attr_is_locked",
"_attr_is_locking",
"_attr_is_open",
"_attr_is_opening",
"_attr_is_unlocking",
},
),
"entity_id": self.entity_id,
CONF_QOS: qos, CONF_QOS: qos,
CONF_ENCODING: encoding, CONF_ENCODING: encoding,
} }
}
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
@@ -246,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_lock(self, **kwargs: Any) -> None: async def async_lock(self, **kwargs: Any) -> None:
"""Lock the device. """Lock the device.

View File

@@ -114,7 +114,7 @@ from .models import (
from .subscription import ( from .subscription import (
EntitySubscription, EntitySubscription,
async_prepare_subscribe_topics, async_prepare_subscribe_topics,
async_subscribe_topics, async_subscribe_topics_internal,
async_unsubscribe_topics, async_unsubscribe_topics,
) )
from .util import mqtt_config_entry_enabled from .util import mqtt_config_entry_enabled
@@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity):
"""Subscribe MQTT events.""" """Subscribe MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._attributes_prepare_subscribe_topics() self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics() self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None: def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
@@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity):
async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None: async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
await self._attributes_subscribe_topics() self._attributes_subscribe_topics()
def _attributes_prepare_subscribe_topics(self) -> None: def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity):
}, },
) )
async def _attributes_subscribe_topics(self) -> None: @callback
def _attributes_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state) async_subscribe_topics_internal(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
@@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity):
"""Subscribe MQTT events.""" """Subscribe MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._availability_prepare_subscribe_topics() self._availability_prepare_subscribe_topics()
await self._availability_subscribe_topics() self._availability_subscribe_topics()
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect) async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect)
) )
@@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity):
async def availability_discovery_update(self, config: DiscoveryInfoType) -> None: async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
await self._availability_subscribe_topics() self._availability_subscribe_topics()
def _availability_setup_from_config(self, config: ConfigType) -> None: def _availability_setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup.""" """(Re)Setup."""
@@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity):
self._available[topic] = False self._available[topic] = False
self._available_latest = False self._available_latest = False
async def _availability_subscribe_topics(self) -> None: @callback
def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state) async_subscribe_topics_internal(self.hass, self._availability_sub_state)
@callback @callback
def async_mqtt_connect(self) -> None: def async_mqtt_connect(self) -> None:
@@ -1254,12 +1256,14 @@ class MqttEntity(
def _message_callback( def _message_callback(
self, self,
msg_callback: MessageCallbackType, msg_callback: MessageCallbackType,
attributes: set[str], attributes: set[str] | None,
msg: ReceiveMessage, msg: ReceiveMessage,
) -> None: ) -> None:
"""Process the message callback.""" """Process the message callback."""
if attributes is not None:
attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple(
(attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes (attribute, getattr(self, attribute, UNDEFINED))
for attribute in attributes
) )
mqtt_data = self.hass.data[DATA_MQTT] mqtt_data = self.hass.data[DATA_MQTT]
messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][ messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][
@@ -1274,7 +1278,7 @@ class MqttEntity(
_LOGGER.warning(exc) _LOGGER.warning(exc)
return return
if self._attrs_have_changed(attrs_snapshot): if attributes is not None and self._attrs_have_changed(attrs_snapshot):
mqtt_data.state_write_requests.write_state_request(self) mqtt_data.state_write_requests.write_state_request(self)

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from ast import literal_eval from ast import literal_eval
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Coroutine from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import StrEnum from enum import StrEnum
import logging import logging
@@ -70,7 +70,6 @@ class ReceiveMessage:
timestamp: float timestamp: float
type AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]]
type MessageCallbackType = Callable[[ReceiveMessage], None] type MessageCallbackType = Callable[[ReceiveMessage], None]

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -41,12 +42,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -165,13 +161,8 @@ class MqttNumber(MqttEntity, RestoreNumber):
self._attr_native_step = config[CONF_STEP] self._attr_native_step = config[CONF_STEP]
self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_native_value"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
num_value: int | float | None num_value: int | float | None
payload = str(self._value_template(msg.payload)) payload = str(self._value_template(msg.payload))
@@ -203,17 +194,24 @@ class MqttNumber(MqttEntity, RestoreNumber):
self._attr_native_value = num_value self._attr_native_value = num_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._attr_assumed_state = True self._attr_assumed_state = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._message_received,
{"_attr_native_value"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -222,7 +220,7 @@ class MqttNumber(MqttEntity, RestoreNumber):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and ( if self._attr_assumed_state and (
last_number_data := await self.async_get_last_number_data() last_number_data := await self.async_get_last_number_data()

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import voluptuous as vol import voluptuous as vol
@@ -27,12 +28,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -113,13 +109,8 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
config.get(CONF_VALUE_TEMPLATE), entity=self config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_current_option"})
def message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT messages.""" """Handle new MQTT messages."""
payload = str(self._value_template(msg.payload)) payload = str(self._value_template(msg.payload))
if not payload.strip(): # No output from template, ignore if not payload.strip(): # No output from template, ignore
@@ -143,17 +134,24 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
return return
self._attr_current_option = payload self._attr_current_option = payload
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._attr_assumed_state = True self._attr_assumed_state = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
"state_topic": { "state_topic": {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": message_received, "msg_callback": partial(
self._message_callback,
self._message_received,
{"_attr_current_option"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -162,7 +160,7 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._attr_assumed_state and ( if self._attr_assumed_state and (
last_state := await self.async_get_last_state() last_state := await self.async_get_last_state()

View File

@@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@callback @callback
def _value_is_expired(self, *_: datetime) -> None: def _value_is_expired(self, *_: datetime) -> None:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@@ -48,12 +49,7 @@ from .const import (
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MqttCommandTemplate, MqttCommandTemplate,
MqttValueTemplate, MqttValueTemplate,
@@ -205,13 +201,8 @@ class MqttSiren(MqttEntity, SirenEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_is_on", "_extra_attributes"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages.""" """Handle new MQTT state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON: if not payload or payload == PAYLOAD_EMPTY_JSON:
@@ -271,17 +262,24 @@ class MqttSiren(MqttEntity, SirenEntity):
self._extra_attributes = dict(self._extra_attributes) self._extra_attributes = dict(self._extra_attributes)
self._update(process_turn_on_params(self, params)) self._update(process_turn_on_params(self, params))
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
CONF_STATE_TOPIC: { CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_is_on", "_extra_attributes"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -290,7 +288,7 @@ class MqttSiren(MqttEntity, SirenEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
@property @property
def extra_state_attributes(self) -> dict[str, Any] | None: def extra_state_attributes(self) -> dict[str, Any] | None:

View File

@@ -2,14 +2,15 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from .. import mqtt
from . import debug_info from . import debug_info
from .client import async_subscribe_internal
from .const import DEFAULT_QOS from .const import DEFAULT_QOS
from .models import MessageCallbackType from .models import MessageCallbackType
@@ -21,7 +22,7 @@ class EntitySubscription:
hass: HomeAssistant hass: HomeAssistant
topic: str | None topic: str | None
message_callback: MessageCallbackType message_callback: MessageCallbackType
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None should_subscribe: bool | None
unsubscribe_callback: Callable[[], None] | None unsubscribe_callback: Callable[[], None] | None
qos: int = 0 qos: int = 0
encoding: str = "utf-8" encoding: str = "utf-8"
@@ -53,15 +54,16 @@ class EntitySubscription:
self.hass, self.message_callback, self.topic, self.entity_id self.hass, self.message_callback, self.topic, self.entity_id
) )
self.subscribe_task = mqtt.async_subscribe( self.should_subscribe = True
hass, self.topic, self.message_callback, self.qos, self.encoding
)
async def subscribe(self) -> None: @callback
def subscribe(self) -> None:
"""Subscribe to a topic.""" """Subscribe to a topic."""
if not self.subscribe_task: if not self.should_subscribe or not self.topic:
return return
self.unsubscribe_callback = await self.subscribe_task self.unsubscribe_callback = async_subscribe_internal(
self.hass, self.topic, self.message_callback, self.qos, self.encoding
)
def _should_resubscribe(self, other: EntitySubscription | None) -> bool: def _should_resubscribe(self, other: EntitySubscription | None) -> bool:
"""Check if we should re-subscribe to the topic using the old state.""" """Check if we should re-subscribe to the topic using the old state."""
@@ -79,6 +81,7 @@ class EntitySubscription:
) )
@callback
def async_prepare_subscribe_topics( def async_prepare_subscribe_topics(
hass: HomeAssistant, hass: HomeAssistant,
new_state: dict[str, EntitySubscription] | None, new_state: dict[str, EntitySubscription] | None,
@@ -107,7 +110,7 @@ def async_prepare_subscribe_topics(
qos=value.get("qos", DEFAULT_QOS), qos=value.get("qos", DEFAULT_QOS),
encoding=value.get("encoding", "utf-8"), encoding=value.get("encoding", "utf-8"),
hass=hass, hass=hass,
subscribe_task=None, should_subscribe=None,
entity_id=value.get("entity_id", None), entity_id=value.get("entity_id", None),
) )
# Get the current subscription state # Get the current subscription state
@@ -135,12 +138,29 @@ async def async_subscribe_topics(
sub_state: dict[str, EntitySubscription], sub_state: dict[str, EntitySubscription],
) -> None: ) -> None:
"""(Re)Subscribe to a set of MQTT topics.""" """(Re)Subscribe to a set of MQTT topics."""
async_subscribe_topics_internal(hass, sub_state)
@callback
def async_subscribe_topics_internal(
hass: HomeAssistant,
sub_state: dict[str, EntitySubscription],
) -> None:
"""(Re)Subscribe to a set of MQTT topics.
This function is internal to the MQTT integration and should not be called
from outside the integration.
"""
for sub in sub_state.values(): for sub in sub_state.values():
await sub.subscribe() sub.subscribe()
def async_unsubscribe_topics( if TYPE_CHECKING:
def async_unsubscribe_topics(
hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None
) -> dict[str, EntitySubscription]: ) -> dict[str, EntitySubscription]:
"""Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" """Unsubscribe from all MQTT topics managed by async_subscribe_topics."""
return async_prepare_subscribe_topics(hass, sub_state, {})
async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={})

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
@@ -36,12 +37,7 @@ from .const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttValueTemplate, ReceiveMessage from .models import MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
@@ -118,13 +114,8 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_is_on"})
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages.""" """Handle new MQTT state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if payload == self._state_on: if payload == self._state_on:
@@ -134,17 +125,24 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
elif payload == PAYLOAD_NONE: elif payload == PAYLOAD_NONE:
self._attr_is_on = None self._attr_is_on = None
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
# Force into optimistic mode. # Force into optimistic mode.
self._optimistic = True self._optimistic = True
else: return
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self.hass,
self._sub_state, self._sub_state,
{ {
CONF_STATE_TOPIC: { CONF_STATE_TOPIC: {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_is_on"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -153,7 +151,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
if self._optimistic and (last_state := await self.async_get_last_state()): if self._optimistic and (last_state := await self.async_get_last_state()):
self._attr_is_on = last_state.state == STATE_ON self._attr_is_on = last_state.state == STATE_ON

View File

@@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin):
} }
}, },
) )
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_tear_down(self) -> None: async def async_tear_down(self) -> None:
"""Cleanup tag scanner.""" """Cleanup tag scanner."""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
import re import re
from typing import Any from typing import Any
@@ -34,12 +35,7 @@ from .const import (
CONF_RETAIN, CONF_RETAIN,
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ( from .models import (
MessageCallbackType, MessageCallbackType,
MqttCommandTemplate, MqttCommandTemplate,
@@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity):
self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None
self._attr_assumed_state = bool(self._optimistic) self._attr_assumed_state = bool(self._optimistic)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
) -> None:
if self._config.get(topic) is not None:
topics[topic] = {
"topic": self._config[topic],
"msg_callback": msg_callback,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback @callback
@log_messages(self.hass, self.entity_id) def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_native_value"})
def handle_state_message_received(msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT.""" """Handle receiving state message via MQTT."""
payload = str(self._value_template(msg.payload)) payload = str(self._value_template(msg.payload))
if check_state_too_long(_LOGGER, payload, self.entity_id, msg): if check_state_too_long(_LOGGER, payload, self.entity_id, msg):
return return
self._attr_native_value = payload self._attr_native_value = payload
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received) def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscription(
topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None:
if self._config.get(topic) is not None:
topics[topic] = {
"topic": self._config[topic],
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
add_subscription(
topics,
CONF_STATE_TOPIC,
self._handle_state_message_received,
{"_attr_native_value"},
)
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
self.hass, self._sub_state, topics self.hass, self._sub_state, topics
@@ -193,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_set_value(self, value: str) -> None: async def async_set_value(self, value: str) -> None:
"""Change the text.""" """Change the text."""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from functools import partial
import logging import logging
from typing import Any, TypedDict, cast from typing import Any, TypedDict, cast
@@ -32,12 +33,7 @@ from .const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
PAYLOAD_EMPTY_JSON, PAYLOAD_EMPTY_JSON,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic, valid_subscribe_topic from .util import valid_publish_topic, valid_subscribe_topic
@@ -141,35 +137,8 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
).async_render_with_possible_json_value, ).async_render_with_possible_json_value,
} }
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscription(
topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType
) -> None:
if self._config.get(topic) is not None:
topics[topic] = {
"topic": self._config[topic],
"msg_callback": msg_callback,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
@callback @callback
@log_messages(self.hass, self.entity_id) def _handle_state_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self,
{
"_attr_installed_version",
"_attr_latest_version",
"_attr_title",
"_attr_release_summary",
"_attr_release_url",
"_entity_picture",
},
)
def handle_state_message_received(msg: ReceiveMessage) -> None:
"""Handle receiving state message via MQTT.""" """Handle receiving state message via MQTT."""
payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload)
@@ -233,20 +202,53 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
if "entity_picture" in json_payload: if "entity_picture" in json_payload:
self._entity_picture = json_payload["entity_picture"] self._entity_picture = json_payload["entity_picture"]
add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received)
@callback @callback
@log_messages(self.hass, self.entity_id) def _handle_latest_version_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(self, {"_attr_latest_version"})
def handle_latest_version_received(msg: ReceiveMessage) -> None:
"""Handle receiving latest version via MQTT.""" """Handle receiving latest version via MQTT."""
latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload) latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload)
if isinstance(latest_version, str) and latest_version != "": if isinstance(latest_version, str) and latest_version != "":
self._attr_latest_version = latest_version self._attr_latest_version = latest_version
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
def add_subscription(
topics: dict[str, Any],
topic: str,
msg_callback: MessageCallbackType,
tracked_attributes: set[str],
) -> None:
if self._config.get(topic) is not None:
topics[topic] = {
"topic": self._config[topic],
"msg_callback": partial(
self._message_callback, msg_callback, tracked_attributes
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None,
}
add_subscription( add_subscription(
topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received topics,
CONF_STATE_TOPIC,
self._handle_state_message_received,
{
"_attr_installed_version",
"_attr_latest_version",
"_attr_title",
"_attr_release_summary",
"_attr_release_url",
"_entity_picture",
},
)
add_subscription(
topics,
CONF_LATEST_VERSION_TOPIC,
self._handle_latest_version_received,
{"_attr_latest_version"},
) )
self._sub_state = subscription.async_prepare_subscribe_topics( self._sub_state = subscription.async_prepare_subscribe_topics(
@@ -255,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_install( async def async_install(
self, version: str | None, backup: bool, **kwargs: Any self, version: str | None, backup: bool, **kwargs: Any

View File

@@ -8,6 +8,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from functools import partial
import logging import logging
from typing import Any, cast from typing import Any, cast
@@ -49,12 +50,7 @@ from .const import (
CONF_STATE_TOPIC, CONF_STATE_TOPIC,
DOMAIN, DOMAIN,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import ReceiveMessage from .models import ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic from .util import valid_publish_topic
@@ -322,16 +318,8 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0) self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0)
self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0))) self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0)))
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self, {"_attr_battery_level", "_attr_fan_speed", "_attr_state"}
)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle state MQTT message.""" """Handle state MQTT message."""
payload = json_loads_object(msg.payload) payload = json_loads_object(msg.payload)
if STATE in payload and ( if STATE in payload and (
@@ -343,10 +331,19 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
del payload[STATE] del payload[STATE]
self._update_state_attributes(payload) self._update_state_attributes(payload)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics: dict[str, Any] = {}
if state_topic := self._config.get(CONF_STATE_TOPIC): if state_topic := self._config.get(CONF_STATE_TOPIC):
topics["state_position_topic"] = { topics["state_position_topic"] = {
"topic": state_topic, "topic": state_topic,
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{"_attr_battery_level", "_attr_fan_speed", "_attr_state"},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -356,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def _async_publish_command(self, feature: VacuumEntityFeature) -> None: async def _async_publish_command(self, feature: VacuumEntityFeature) -> None:
"""Publish a command.""" """Publish a command."""

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import suppress from contextlib import suppress
from functools import partial
import logging import logging
from typing import Any from typing import Any
@@ -61,12 +62,7 @@ from .const import (
DEFAULT_RETAIN, DEFAULT_RETAIN,
PAYLOAD_NONE, PAYLOAD_NONE,
) )
from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper
from .mixins import (
MqttEntity,
async_setup_entity_entry_helper,
write_state_on_attr_change,
)
from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage
from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .schemas import MQTT_ENTITY_COMMON_SCHEMA
from .util import valid_publish_topic, valid_subscribe_topic from .util import valid_publish_topic, valid_subscribe_topic
@@ -302,22 +298,8 @@ class MqttValve(MqttEntity, ValveEntity):
return return
self._update_state(state) self._update_state(state)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics = {}
@callback @callback
@log_messages(self.hass, self.entity_id) def _state_message_received(self, msg: ReceiveMessage) -> None:
@write_state_on_attr_change(
self,
{
"_attr_current_valve_position",
"_attr_is_closed",
"_attr_is_closing",
"_attr_is_opening",
},
)
def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages.""" """Handle new MQTT state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
payload_dict: Any = None payload_dict: Any = None
@@ -351,16 +333,28 @@ class MqttValve(MqttEntity, ValveEntity):
state_payload = payload_dict.get("state") state_payload = payload_dict.get("state")
if self._config[CONF_REPORTS_POSITION]: if self._config[CONF_REPORTS_POSITION]:
self._process_position_valve_update( self._process_position_valve_update(msg, position_payload, state_payload)
msg, position_payload, state_payload
)
else: else:
self._process_binary_valve_update(msg, state_payload) self._process_binary_valve_update(msg, state_payload)
def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics."""
topics = {}
if self._config.get(CONF_STATE_TOPIC): if self._config.get(CONF_STATE_TOPIC):
topics["state_topic"] = { topics["state_topic"] = {
"topic": self._config.get(CONF_STATE_TOPIC), "topic": self._config.get(CONF_STATE_TOPIC),
"msg_callback": state_message_received, "msg_callback": partial(
self._message_callback,
self._state_message_received,
{
"_attr_current_valve_position",
"_attr_is_closed",
"_attr_is_closing",
"_attr_is_opening",
},
),
"entity_id": self.entity_id,
"qos": self._config[CONF_QOS], "qos": self._config[CONF_QOS],
"encoding": self._config[CONF_ENCODING] or None, "encoding": self._config[CONF_ENCODING] or None,
} }
@@ -371,7 +365,7 @@ class MqttValve(MqttEntity, ValveEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) subscription.async_subscribe_topics_internal(self.hass, self._sub_state)
async def async_open_valve(self) -> None: async def async_open_valve(self) -> None:
"""Move the valve up. """Move the valve up.

View File

@@ -9,6 +9,7 @@ from typing import Literal
import ollama import ollama
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
@@ -138,6 +139,11 @@ class OllamaConversationEntity(
ollama.Message(role=MessageRole.USER.value, content=user_input.text) ollama.Message(role=MessageRole.USER.value, content=user_input.text)
) )
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": message_history.messages},
)
# Get response # Get response
try: try:
response = await client.chat( response = await client.chat(

View File

@@ -31,14 +31,15 @@ from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
CONF_PROMPT, CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT, DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
} }
) )
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_PROMPT: DEFAULT_PROMPT,
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect. """Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry( return self.async_create_entry(
title="ChatGPT", title="ChatGPT",
data=user_input, data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, options=RECOMMENDED_OPTIONS,
) )
return self.async_show_form( return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
def __init__(self, config_entry: ConfigEntry) -> None: def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow.""" """Initialize options flow."""
self.config_entry = config_entry self.config_entry = config_entry
self.last_rendered_recommended = config_entry.options.get(
CONF_RECOMMENDED, False
)
async def async_step_init( async def async_step_init(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Manage the options.""" """Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
if user_input is not None: if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none": if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API) user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input) return self.async_create_entry(title="", data=user_input)
schema = openai_config_option_schema(self.hass, self.config_entry.options)
# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = {
CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
CONF_PROMPT: user_input[CONF_PROMPT],
CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
}
schema = openai_config_option_schema(self.hass, options)
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="init",
data_schema=vol.Schema(schema), data_schema=vol.Schema(schema),
@@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow):
def openai_config_option_schema( def openai_config_option_schema(
hass: HomeAssistant, hass: HomeAssistant,
options: MappingProxyType[str, Any], options: dict[str, Any] | MappingProxyType[str, Any],
) -> dict: ) -> dict:
"""Return a schema for OpenAI completion options.""" """Return a schema for OpenAI completion options."""
apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
SelectOptionDict( SelectOptionDict(
label="No control", label="No control",
value="none", value="none",
) )
] ]
apis.extend( hass_apis.extend(
SelectOptionDict( SelectOptionDict(
label=api.name, label=api.name,
value=api.id, value=api.id,
@@ -144,38 +167,46 @@ def openai_config_option_schema(
for api in llm.async_get_apis(hass) for api in llm.async_get_apis(hass)
) )
return { schema = {
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)}, description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(), ): TemplateSelector(),
vol.Optional( vol.Optional(
CONF_LLM_HASS_API, CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)}, description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none", default="none",
): SelectSelector(SelectSelectorConfig(options=apis)), ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
if options.get(CONF_RECOMMENDED):
return schema
schema.update(
{
vol.Optional( vol.Optional(
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
description={ description={"suggested_value": options.get(CONF_CHAT_MODEL)},
# New key in HA 2023.4 default=RECOMMENDED_CHAT_MODEL,
"suggested_value": options.get(CONF_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str, ): str,
vol.Optional( vol.Optional(
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)}, description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS, default=RECOMMENDED_MAX_TOKENS,
): int, ): int,
vol.Optional( vol.Optional(
CONF_TOP_P, CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)}, description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P, default=RECOMMENDED_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional( vol.Optional(
CONF_TEMPERATURE, CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)}, description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE, default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
} }
)
return schema

View File

@@ -4,13 +4,15 @@ import logging
DOMAIN = "openai_conversation" DOMAIN = "openai_conversation"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
CONF_RECOMMENDED = "recommended"
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point.""" DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-4o" RECOMMENDED_CHAT_MODEL = "gpt-4o"
CONF_MAX_TOKENS = "max_tokens" CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150 RECOMMENDED_MAX_TOKENS = 150
CONF_TOP_P = "top_p" CONF_TOP_P = "top_p"
DEFAULT_TOP_P = 1.0 RECOMMENDED_TOP_P = 1.0
CONF_TEMPERATURE = "temperature" CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 1.0 RECOMMENDED_TEMPERATURE = 1.0

View File

@@ -8,6 +8,7 @@ import voluptuous as vol
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -22,13 +23,13 @@ from .const import (
CONF_PROMPT, CONF_PROMPT,
CONF_TEMPERATURE, CONF_TEMPERATURE,
CONF_TOP_P, CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT, DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
) )
# Max number of back and forth with the LLM to generate a response # Max number of back and forth with the LLM to generate a response
@@ -97,15 +98,14 @@ class OpenAIConversationEntity(
self, user_input: conversation.ConversationInput self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Process a sentence.""" """Process a sentence."""
options = self.entry.options
intent_response = intent.IntentResponse(language=user_input.language) intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None tools: list[dict[str, Any]] | None = None
if self.entry.options.get(CONF_LLM_HASS_API): if options.get(CONF_LLM_HASS_API):
try: try:
llm_api = llm.async_get_api( llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError as err: except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err) LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error( intent_response.async_set_error(
@@ -117,26 +117,12 @@ class OpenAIConversationEntity(
) )
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
if user_input.conversation_id in self.history: if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id conversation_id = user_input.conversation_id
messages = self.history[conversation_id] messages = self.history[conversation_id]
else: else:
conversation_id = ulid.ulid_now() conversation_id = ulid.ulid_now()
try: try:
prompt = template.Template(
self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
)
if llm_api: if llm_api:
empty_tool_input = llm.ToolInput( empty_tool_input = llm.ToolInput(
tool_name="", tool_name="",
@@ -149,10 +135,23 @@ class OpenAIConversationEntity(
device_id=user_input.device_id, device_id=user_input.device_id,
) )
prompt = ( api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
await llm_api.async_get_api_prompt(empty_tool_input)
+ "\n" else:
+ prompt api_prompt = llm.PROMPT_NO_API_CONFIGURED
prompt = "\n".join(
(
template.Template(
options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
).async_render(
{
"ha_name": self.hass.config.location_name,
},
parse_result=False,
),
api_prompt,
)
) )
except TemplateError as err: except TemplateError as err:
@@ -170,7 +169,10 @@ class OpenAIConversationEntity(
messages.append({"role": "user", "content": user_input.text}) messages.append({"role": "user", "content": user_input.text})
LOGGER.debug("Prompt for %s: %s", model, messages) LOGGER.debug("Prompt: %s", messages)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages}
)
client = self.hass.data[DOMAIN][self.entry.entry_id] client = self.hass.data[DOMAIN][self.entry.entry_id]
@@ -178,12 +180,12 @@ class OpenAIConversationEntity(
for _iteration in range(MAX_TOOL_ITERATIONS): for _iteration in range(MAX_TOOL_ITERATIONS):
try: try:
result = await client.chat.completions.create( result = await client.chat.completions.create(
model=model, model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
messages=messages, messages=messages,
tools=tools, tools=tools,
max_tokens=max_tokens, max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=top_p, top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=temperature, temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=conversation_id, user=conversation_id,
) )
except openai.OpenAIError as err: except openai.OpenAIError as err:

View File

@@ -22,7 +22,8 @@
"max_tokens": "Maximum tokens to return in response", "max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature", "temperature": "Temperature",
"top_p": "Top P", "top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
}, },
"data_description": { "data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template." "prompt": "Instruct how the LLM should respond. This can be a template."

View File

@@ -30,7 +30,6 @@ from .util import (
PLATFORMS = [Platform.MEDIA_PLAYER] PLATFORMS = [Platform.MEDIA_PLAYER]
__all__ = [ __all__ = [
"async_browse_media", "async_browse_media",
"DOMAIN", "DOMAIN",
@@ -50,7 +49,10 @@ class HomeAssistantSpotifyData:
session: OAuth2Session session: OAuth2Session
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: type SpotifyConfigEntry = ConfigEntry[HomeAssistantSpotifyData]
async def async_setup_entry(hass: HomeAssistant, entry: SpotifyConfigEntry) -> bool:
"""Set up Spotify from a config entry.""" """Set up Spotify from a config entry."""
implementation = await async_get_config_entry_implementation(hass, entry) implementation = await async_get_config_entry_implementation(hass, entry)
session = OAuth2Session(hass, entry, implementation) session = OAuth2Session(hass, entry, implementation)
@@ -100,8 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
await device_coordinator.async_config_entry_first_refresh() await device_coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {}) entry.runtime_data = HomeAssistantSpotifyData(
hass.data[DOMAIN][entry.entry_id] = HomeAssistantSpotifyData(
client=spotify, client=spotify,
current_user=current_user, current_user=current_user,
devices=device_coordinator, devices=device_coordinator,
@@ -117,6 +118,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Spotify config entry.""" """Unload Spotify config entry."""
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
del hass.data[DOMAIN][entry.entry_id]
return unload_ok

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from enum import StrEnum from enum import StrEnum
from functools import partial from functools import partial
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
from spotipy import Spotify from spotipy import Spotify
import yarl import yarl
@@ -22,6 +22,9 @@ from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES
from .util import fetch_image_url from .util import fetch_image_url
if TYPE_CHECKING:
from . import HomeAssistantSpotifyData
BROWSE_LIMIT = 48 BROWSE_LIMIT = 48
@@ -140,21 +143,21 @@ async def async_browse_media(
# Check if caller is requesting the root nodes # Check if caller is requesting the root nodes
if media_content_type is None and media_content_id is None: if media_content_type is None and media_content_id is None:
children = [] config_entries = hass.config_entries.async_entries(
for config_entry_id in hass.data[DOMAIN]: DOMAIN, include_disabled=False, include_ignore=False
config_entry = hass.config_entries.async_get_entry(config_entry_id) )
assert config_entry is not None children = [
children.append(
BrowseMedia( BrowseMedia(
title=config_entry.title, title=config_entry.title,
media_class=MediaClass.APP, media_class=MediaClass.APP,
media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry_id}", media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry.entry_id}",
media_content_type=f"{MEDIA_PLAYER_PREFIX}library", media_content_type=f"{MEDIA_PLAYER_PREFIX}library",
thumbnail="https://brands.home-assistant.io/_/spotify/logo.png", thumbnail="https://brands.home-assistant.io/_/spotify/logo.png",
can_play=False, can_play=False,
can_expand=True, can_expand=True,
) )
) for config_entry in config_entries
]
return BrowseMedia( return BrowseMedia(
title="Spotify", title="Spotify",
media_class=MediaClass.APP, media_class=MediaClass.APP,
@@ -171,9 +174,15 @@ async def async_browse_media(
# Check for config entry specifier, and extract Spotify URI # Check for config entry specifier, and extract Spotify URI
parsed_url = yarl.URL(media_content_id) parsed_url = yarl.URL(media_content_id)
if (info := hass.data[DOMAIN].get(parsed_url.host)) is None:
if (
parsed_url.host is None
or (entry := hass.config_entries.async_get_entry(parsed_url.host)) is None
or not isinstance(entry.runtime_data, HomeAssistantSpotifyData)
):
raise BrowseError("Invalid Spotify account specified") raise BrowseError("Invalid Spotify account specified")
media_content_id = parsed_url.name media_content_id = parsed_url.name
info = entry.runtime_data
result = await async_browse_media_internal( result = await async_browse_media_internal(
hass, hass,

View File

@@ -22,7 +22,6 @@ from homeassistant.components.media_player import (
MediaType, MediaType,
RepeatMode, RepeatMode,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ID from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -30,7 +29,7 @@ from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from . import HomeAssistantSpotifyData from . import HomeAssistantSpotifyData, SpotifyConfigEntry
from .browse_media import async_browse_media_internal from .browse_media import async_browse_media_internal
from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES
from .util import fetch_image_url from .util import fetch_image_url
@@ -70,12 +69,12 @@ SPOTIFY_DJ_PLAYLIST = {"uri": "spotify:playlist:37i9dQZF1EYkqdzj48dyYq", "name":
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
entry: ConfigEntry, entry: SpotifyConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Spotify based on a config entry.""" """Set up Spotify based on a config entry."""
spotify = SpotifyMediaPlayer( spotify = SpotifyMediaPlayer(
hass.data[DOMAIN][entry.entry_id], entry.runtime_data,
entry.data[CONF_ID], entry.data[CONF_ID],
entry.title, entry.title,
) )

View File

@@ -4,15 +4,14 @@ from __future__ import annotations
import logging import logging
from aioswitcher.bridge import SwitcherBridge
from aioswitcher.device import SwitcherBase from aioswitcher.device import SwitcherBase
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from .const import DATA_DEVICE, DOMAIN
from .coordinator import SwitcherDataUpdateCoordinator from .coordinator import SwitcherDataUpdateCoordinator
from .utils import async_start_bridge, async_stop_bridge
PLATFORMS = [ PLATFORMS = [
Platform.BUTTON, Platform.BUTTON,
@@ -25,20 +24,20 @@ PLATFORMS = [
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: type SwitcherConfigEntry = ConfigEntry[dict[str, SwitcherDataUpdateCoordinator]]
async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
"""Set up Switcher from a config entry.""" """Set up Switcher from a config entry."""
hass.data.setdefault(DOMAIN, {})
hass.data[DOMAIN][DATA_DEVICE] = {}
@callback @callback
def on_device_data_callback(device: SwitcherBase) -> None: def on_device_data_callback(device: SwitcherBase) -> None:
"""Use as a callback for device data.""" """Use as a callback for device data."""
coordinators = entry.runtime_data
# Existing device update device data # Existing device update device data
if device.device_id in hass.data[DOMAIN][DATA_DEVICE]: if coordinator := coordinators.get(device.device_id):
coordinator: SwitcherDataUpdateCoordinator = hass.data[DOMAIN][DATA_DEVICE][
device.device_id
]
coordinator.async_set_updated_data(device) coordinator.async_set_updated_data(device)
return return
@@ -52,18 +51,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
device.device_type.hex_rep, device.device_type.hex_rep,
) )
coordinator = hass.data[DOMAIN][DATA_DEVICE][device.device_id] = ( coordinator = SwitcherDataUpdateCoordinator(hass, entry, device)
SwitcherDataUpdateCoordinator(hass, entry, device)
)
coordinator.async_setup() coordinator.async_setup()
coordinators[device.device_id] = coordinator
# Must be ready before dispatcher is called # Must be ready before dispatcher is called
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
await async_start_bridge(hass, on_device_data_callback) entry.runtime_data = {}
bridge = SwitcherBridge(on_device_data_callback)
await bridge.start()
async def stop_bridge(event: Event) -> None: async def stop_bridge(event: Event | None = None) -> None:
await async_stop_bridge(hass) await bridge.stop()
entry.async_on_unload(stop_bridge)
entry.async_on_unload( entry.async_on_unload(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge)
@@ -72,12 +74,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
await async_stop_bridge(hass) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(DATA_DEVICE)
return unload_ok

View File

@@ -15,7 +15,6 @@ from aioswitcher.api.remotes import SwitcherBreezeRemote
from aioswitcher.device import DeviceCategory from aioswitcher.device import DeviceCategory
from homeassistant.components.button import ButtonEntity, ButtonEntityDescription from homeassistant.components.button import ButtonEntity, ButtonEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -25,6 +24,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import SwitcherConfigEntry
from .const import SIGNAL_DEVICE_ADD from .const import SIGNAL_DEVICE_ADD
from .coordinator import SwitcherDataUpdateCoordinator from .coordinator import SwitcherDataUpdateCoordinator
from .utils import get_breeze_remote_manager from .utils import get_breeze_remote_manager
@@ -78,7 +78,7 @@ THERMOSTAT_BUTTONS = [
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: SwitcherConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Switcher button from config entry.""" """Set up Switcher button from config entry."""

View File

@@ -25,7 +25,6 @@ from homeassistant.components.climate import (
ClimateEntityFeature, ClimateEntityFeature,
HVACMode, HVACMode,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -35,6 +34,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import SwitcherConfigEntry
from .const import SIGNAL_DEVICE_ADD from .const import SIGNAL_DEVICE_ADD
from .coordinator import SwitcherDataUpdateCoordinator from .coordinator import SwitcherDataUpdateCoordinator
from .utils import get_breeze_remote_manager from .utils import get_breeze_remote_manager
@@ -61,7 +61,7 @@ HA_TO_DEVICE_FAN = {value: key for key, value in DEVICE_FAN_TO_HA.items()}
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: ConfigEntry, config_entry: SwitcherConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Switcher climate from config entry.""" """Set up Switcher climate from config entry."""

View File

@@ -2,9 +2,6 @@
DOMAIN = "switcher_kis" DOMAIN = "switcher_kis"
DATA_BRIDGE = "bridge"
DATA_DEVICE = "device"
DISCOVERY_TIME_SEC = 12 DISCOVERY_TIME_SEC = 12
SIGNAL_DEVICE_ADD = "switcher_device_add" SIGNAL_DEVICE_ADD = "switcher_device_add"

View File

@@ -6,24 +6,23 @@ from dataclasses import asdict
from typing import Any from typing import Any
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .const import DATA_DEVICE, DOMAIN from . import SwitcherConfigEntry
TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"} TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"}
async def async_get_config_entry_diagnostics( async def async_get_config_entry_diagnostics(
hass: HomeAssistant, entry: ConfigEntry hass: HomeAssistant, entry: SwitcherConfigEntry
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
devices = hass.data[DOMAIN][DATA_DEVICE] coordinators = entry.runtime_data
return async_redact_data( return async_redact_data(
{ {
"entry": entry.as_dict(), "entry": entry.as_dict(),
"devices": [asdict(devices[d].data) for d in devices], "devices": [asdict(coordinators[d].data) for d in coordinators],
}, },
TO_REDACT, TO_REDACT,
) )

View File

@@ -3,9 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable
import logging import logging
from typing import Any
from aioswitcher.api.remotes import SwitcherBreezeRemoteManager from aioswitcher.api.remotes import SwitcherBreezeRemoteManager
from aioswitcher.bridge import SwitcherBase, SwitcherBridge from aioswitcher.bridge import SwitcherBase, SwitcherBridge
@@ -13,29 +11,11 @@ from aioswitcher.bridge import SwitcherBase, SwitcherBridge
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import singleton from homeassistant.helpers import singleton
from .const import DATA_BRIDGE, DISCOVERY_TIME_SEC, DOMAIN from .const import DISCOVERY_TIME_SEC
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_start_bridge(
hass: HomeAssistant, on_device_callback: Callable[[SwitcherBase], Any]
) -> None:
"""Start switcher UDP bridge."""
bridge = hass.data[DOMAIN][DATA_BRIDGE] = SwitcherBridge(on_device_callback)
_LOGGER.debug("Starting Switcher bridge")
await bridge.start()
async def async_stop_bridge(hass: HomeAssistant) -> None:
"""Stop switcher UDP bridge."""
bridge: SwitcherBridge = hass.data[DOMAIN].get(DATA_BRIDGE)
if bridge is not None:
_LOGGER.debug("Stopping Switcher bridge")
await bridge.stop()
hass.data[DOMAIN].pop(DATA_BRIDGE)
async def async_has_devices(hass: HomeAssistant) -> bool: async def async_has_devices(hass: HomeAssistant) -> bool:
"""Discover Switcher devices.""" """Discover Switcher devices."""
_LOGGER.debug("Starting discovery") _LOGGER.debug("Starting discovery")

View File

@@ -30,6 +30,7 @@ PLATFORMS: Final = [
Platform.BINARY_SENSOR, Platform.BINARY_SENSOR,
Platform.CLIMATE, Platform.CLIMATE,
Platform.COVER, Platform.COVER,
Platform.DEVICE_TRACKER,
Platform.LOCK, Platform.LOCK,
Platform.SELECT, Platform.SELECT,
Platform.SENSOR, Platform.SENSOR,

View File

@@ -0,0 +1,85 @@
"""Device tracker platform for Teslemetry integration."""
from __future__ import annotations
from homeassistant.components.device_tracker import SourceType
from homeassistant.components.device_tracker.config_entry import TrackerEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .entity import TeslemetryVehicleEntity
from .models import TeslemetryVehicleData
async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
) -> None:
"""Set up the Teslemetry device tracker platform from a config entry."""
async_add_entities(
klass(vehicle)
for klass in (
TeslemetryDeviceTrackerLocationEntity,
TeslemetryDeviceTrackerRouteEntity,
)
for vehicle in entry.runtime_data.vehicles
)
class TeslemetryDeviceTrackerEntity(TeslemetryVehicleEntity, TrackerEntity):
"""Base class for Teslemetry tracker entities."""
lat_key: str
lon_key: str
def __init__(
self,
vehicle: TeslemetryVehicleData,
) -> None:
"""Initialize the device tracker."""
super().__init__(vehicle, self.key)
def _async_update_attrs(self) -> None:
"""Update the attributes of the device tracker."""
self._attr_available = (
self.get(self.lat_key, False) is not None
and self.get(self.lon_key, False) is not None
)
@property
def latitude(self) -> float | None:
"""Return latitude value of the device."""
return self.get(self.lat_key)
@property
def longitude(self) -> float | None:
"""Return longitude value of the device."""
return self.get(self.lon_key)
@property
def source_type(self) -> SourceType:
"""Return the source type of the device tracker."""
return SourceType.GPS
class TeslemetryDeviceTrackerLocationEntity(TeslemetryDeviceTrackerEntity):
"""Vehicle location device tracker class."""
key = "location"
lat_key = "drive_state_latitude"
lon_key = "drive_state_longitude"
class TeslemetryDeviceTrackerRouteEntity(TeslemetryDeviceTrackerEntity):
"""Vehicle navigation device tracker class."""
key = "route"
lat_key = "drive_state_active_route_latitude"
lon_key = "drive_state_active_route_longitude"
@property
def location_name(self) -> str | None:
"""Return a location name for the current location of the device."""
return self.get("drive_state_active_route_destination")

View File

@@ -109,6 +109,7 @@
"off": "mdi:car-seat" "off": "mdi:car-seat"
} }
}, },
"components_customer_preferred_export_rule": { "components_customer_preferred_export_rule": {
"default": "mdi:transmission-tower", "default": "mdi:transmission-tower",
"state": { "state": {
@@ -126,6 +127,14 @@
} }
} }
}, },
"device_tracker": {
"location": {
"default": "mdi:map-marker"
},
"route": {
"default": "mdi:routes"
}
},
"cover": { "cover": {
"charge_state_charge_port_door_open": { "charge_state_charge_port_door_open": {
"default": "mdi:ev-plug-ccs2" "default": "mdi:ev-plug-ccs2"

View File

@@ -111,6 +111,14 @@
} }
} }
}, },
"device_tracker": {
"location": {
"name": "Location"
},
"route": {
"name": "Route"
}
},
"lock": { "lock": {
"charge_state_charge_port_latch": { "charge_state_charge_port_latch": {
"name": "Charge cable lock" "name": "Charge cable lock"

View File

@@ -13,7 +13,7 @@
"velbus-packet", "velbus-packet",
"velbus-protocol" "velbus-protocol"
], ],
"requirements": ["velbus-aio==2024.4.1"], "requirements": ["velbus-aio==2024.5.1"],
"usb": [ "usb": [
{ {
"vid": "10CF", "vid": "10CF",

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable
import logging import logging
from typing import Any from typing import Any
@@ -157,16 +156,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Unload Withings config entry.""" """Unload vera config entry."""
controller_data: ControllerData = get_controller_data(hass, config_entry) controller_data: ControllerData = get_controller_data(hass, config_entry)
await asyncio.gather(
tasks: list[Awaitable] = [ *(
hass.config_entries.async_forward_entry_unload(config_entry, platform) hass.config_entries.async_unload_platforms(
for platform in get_configured_platforms(controller_data) config_entry, get_configured_platforms(controller_data)
] ),
tasks.append(hass.async_add_executor_job(controller_data.controller.stop)) hass.async_add_executor_job(controller_data.controller.stop),
await asyncio.gather(*tasks) )
)
return True return True

View File

@@ -1,6 +1,5 @@
"""Support for Zigbee Home Automation devices.""" """Support for Zigbee Home Automation devices."""
import asyncio
import contextlib import contextlib
import copy import copy
import logging import logging
@@ -238,12 +237,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
websocket_api.async_unload_api(hass) websocket_api.async_unload_api(hass)
# our components don't have unload methods so no need to look at return values # our components don't have unload methods so no need to look at return values
await asyncio.gather( await hass.config_entries.async_unload_platforms(config_entry, PLATFORMS)
*(
hass.config_entries.async_forward_entry_unload(config_entry, platform)
for platform in PLATFORMS
)
)
return True return True

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from collections.abc import Coroutine
from contextlib import suppress from contextlib import suppress
import logging import logging
from typing import Any from typing import Any
@@ -958,14 +957,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
client: ZwaveClient = entry.runtime_data[DATA_CLIENT] client: ZwaveClient = entry.runtime_data[DATA_CLIENT]
driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS] driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS]
platforms = [
tasks: list[Coroutine] = [ platform
hass.config_entries.async_forward_entry_unload(entry, platform)
for platform, task in driver_events.platform_setup_tasks.items() for platform, task in driver_events.platform_setup_tasks.items()
if not task.cancel() if not task.cancel()
] ]
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
unload_ok = all(await asyncio.gather(*tasks)) if tasks else True
if client.connected and client.driver: if client.connected and client.driver:
await async_disable_server_logging_if_needed(hass, entry, client.driver) await async_disable_server_logging_if_needed(hass, entry, client.driver)

View File

@@ -3,12 +3,16 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import asdict, dataclass
from typing import Any from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE
from homeassistant.components.conversation.trace import (
ConversationTraceEventType,
async_conversation_trace_append,
)
from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.components.weather.intent import INTENT_GET_WEATHER
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -116,6 +120,10 @@ class API(ABC):
async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType:
"""Call a LLM tool, validate args and return the response.""" """Call a LLM tool, validate args and return the response."""
async_conversation_trace_append(
ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input)
)
for tool in self.async_get_tools(): for tool in self.async_get_tools():
if tool.name == tool_input.tool_name: if tool.name == tool_input.tool_name:
break break
@@ -191,7 +199,10 @@ class AssistAPI(API):
async def async_get_api_prompt(self, tool_input: ToolInput) -> str: async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
"""Return the prompt for the API.""" """Return the prompt for the API."""
prompt = "Call the intent tools to control Home Assistant. Just pass the name to the intent." prompt = (
"Call the intent tools to control Home Assistant. "
"Just pass the name to the intent."
)
if tool_input.device_id: if tool_input.device_id:
device_reg = device_registry.async_get(self.hass) device_reg = device_registry.async_get(self.hass)
device = device_reg.async_get(tool_input.device_id) device = device_reg.async_get(tool_input.device_id)

View File

@@ -1821,7 +1821,7 @@ pyegps==0.2.5
pyenphase==1.20.3 pyenphase==1.20.3
# homeassistant.components.envisalink # homeassistant.components.envisalink
pyenvisalink==4.6 pyenvisalink==4.7
# homeassistant.components.ephember # homeassistant.components.ephember
pyephember==0.3.1 pyephember==0.3.1
@@ -2817,7 +2817,7 @@ vallox-websocket-api==5.1.1
vehicle==2.2.1 vehicle==2.2.1
# homeassistant.components.velbus # homeassistant.components.velbus
velbus-aio==2024.4.1 velbus-aio==2024.5.1
# homeassistant.components.venstar # homeassistant.components.venstar
venstarcolortouch==0.19 venstarcolortouch==0.19

View File

@@ -2185,7 +2185,7 @@ vallox-websocket-api==5.1.1
vehicle==2.2.1 vehicle==2.2.1
# homeassistant.components.velbus # homeassistant.components.velbus
velbus-aio==2024.4.1 velbus-aio==2024.5.1
# homeassistant.components.venstar # homeassistant.components.venstar
venstarcolortouch==0.19 venstarcolortouch==0.19

View File

@@ -117,7 +117,6 @@ NO_IOT_CLASS = [
# https://github.com/home-assistant/developers.home-assistant/pull/1512 # https://github.com/home-assistant/developers.home-assistant/pull/1512
NO_DIAGNOSTICS = [ NO_DIAGNOSTICS = [
"dlna_dms", "dlna_dms",
"fronius",
"gdacs", "gdacs",
"geonetnz_quakes", "geonetnz_quakes",
"google_assistant_sdk", "google_assistant_sdk",

View File

@@ -2,7 +2,9 @@
from unittest.mock import patch from unittest.mock import patch
from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant, State from homeassistant.core import Context, HomeAssistant, State
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None:
) as mock_process, ) as mock_process,
patch("homeassistant.util.dt.utcnow", return_value=now), patch("homeassistant.util.dt.utcnow", return_value=now),
): ):
intent_response = intent.IntentResponse(language="en")
intent_response.async_set_speech("response text")
mock_process.return_value = conversation.ConversationResult(
response=intent_response,
)
await hass.services.async_call( await hass.services.async_call(
"conversation", "conversation",
"process", "process",

View File

@@ -0,0 +1,80 @@
"""Test for the conversation traces."""
from unittest.mock import patch
import pytest
from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
@pytest.fixture
async def init_components(hass: HomeAssistant):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
async def test_converation_trace(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
)
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert trace_event.get("data")
assert trace_event["data"].get("text") == "add apples to my shopping list"
assert last_trace.get("result")
assert (
last_trace["result"]
.get("response", {})
.get("speech", {})
.get("plain", {})
.get("speech")
== "Added apples"
)
async def test_converation_trace_error(
hass: HomeAssistant,
init_components: None,
sl_setup: None,
) -> None:
"""Test tracing a conversation."""
with (
patch(
"homeassistant.components.conversation.default_agent.DefaultAgent.async_process",
side_effect=HomeAssistantError("Failed to talk to agent"),
),
pytest.raises(HomeAssistantError),
):
await conversation.async_converse(
hass, "add apples to my shopping list", None, Context()
)
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
assert last_trace.get("events")
assert len(last_trace.get("events")) == 1
trace_event = last_trace["events"][0]
assert (
trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS
)
assert last_trace.get("error") == "Failed to talk to agent"

View File

@@ -25,6 +25,7 @@ async def setup_fronius_integration(
"""Create the Fronius integration.""" """Create the Fronius integration."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
entry_id="f1e2b9837e8adaed6fa682acaa216fd8",
unique_id=unique_id, # has to match mocked logger unique_id unique_id=unique_id, # has to match mocked logger unique_id
data={ data={
CONF_HOST: MOCK_HOST, CONF_HOST: MOCK_HOST,

View File

@@ -0,0 +1,370 @@
# serializer version: 1
# name: test_diagnostics
dict({
'config_entry': dict({
'data': dict({
'host': 'http://fronius',
'is_logger': True,
}),
'disabled_by': None,
'domain': 'fronius',
'entry_id': 'f1e2b9837e8adaed6fa682acaa216fd8',
'minor_version': 1,
'options': dict({
}),
'pref_disable_new_entities': False,
'pref_disable_polling': False,
'source': 'user',
'title': 'Mock Title',
'unique_id': '**REDACTED**',
'version': 1,
}),
'coordinators': dict({
'inverters': dict({
'1': dict({
'current_ac': dict({
'unit': 'A',
'value': 5.19,
}),
'current_dc': dict({
'unit': 'A',
'value': 2.19,
}),
'energy_day': dict({
'unit': 'Wh',
'value': 1113,
}),
'energy_total': dict({
'unit': 'Wh',
'value': 44188000,
}),
'energy_year': dict({
'unit': 'Wh',
'value': 25508798,
}),
'error_code': dict({
'value': 0,
}),
'frequency_ac': dict({
'unit': 'Hz',
'value': 49.94,
}),
'led_color': dict({
'value': 2,
}),
'led_state': dict({
'value': 0,
}),
'power_ac': dict({
'unit': 'W',
'value': 1190,
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'status_code': dict({
'value': 7,
}),
'timestamp': dict({
'value': '2021-10-07T10:01:17+02:00',
}),
'voltage_ac': dict({
'unit': 'V',
'value': 227.9,
}),
'voltage_dc': dict({
'unit': 'V',
'value': 518,
}),
}),
}),
'logger': dict({
'system': dict({
'cash_factor': dict({
'unit': 'EUR/kWh',
'value': 0.07800000160932541,
}),
'co2_factor': dict({
'unit': 'kg/kWh',
'value': 0.5299999713897705,
}),
'delivery_factor': dict({
'unit': 'EUR/kWh',
'value': 0.15000000596046448,
}),
'hardware_platform': dict({
'value': 'wilma',
}),
'hardware_version': dict({
'value': '2.4E',
}),
'product_type': dict({
'value': 'fronius-datamanager-card',
}),
'software_version': dict({
'value': '3.18.7-1',
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'time_zone': dict({
'value': 'CEST',
}),
'time_zone_location': dict({
'value': 'Vienna',
}),
'timestamp': dict({
'value': '2021-10-06T23:56:32+02:00',
}),
'unique_identifier': '**REDACTED**',
'utc_offset': dict({
'value': 7200,
}),
}),
}),
'meter': dict({
'0': dict({
'current_ac_phase_1': dict({
'unit': 'A',
'value': 7.755,
}),
'current_ac_phase_2': dict({
'unit': 'A',
'value': 6.68,
}),
'current_ac_phase_3': dict({
'unit': 'A',
'value': 10.102,
}),
'enable': dict({
'value': 1,
}),
'energy_reactive_ac_consumed': dict({
'unit': 'VArh',
'value': 59960790,
}),
'energy_reactive_ac_produced': dict({
'unit': 'VArh',
'value': 723160,
}),
'energy_real_ac_minus': dict({
'unit': 'Wh',
'value': 35623065,
}),
'energy_real_ac_plus': dict({
'unit': 'Wh',
'value': 15303334,
}),
'energy_real_consumed': dict({
'unit': 'Wh',
'value': 15303334,
}),
'energy_real_produced': dict({
'unit': 'Wh',
'value': 35623065,
}),
'frequency_phase_average': dict({
'unit': 'Hz',
'value': 50,
}),
'manufacturer': dict({
'value': 'Fronius',
}),
'meter_location': dict({
'value': 0,
}),
'model': dict({
'value': 'Smart Meter 63A',
}),
'power_apparent': dict({
'unit': 'VA',
'value': 5592.57,
}),
'power_apparent_phase_1': dict({
'unit': 'VA',
'value': 1772.793,
}),
'power_apparent_phase_2': dict({
'unit': 'VA',
'value': 1527.048,
}),
'power_apparent_phase_3': dict({
'unit': 'VA',
'value': 2333.562,
}),
'power_factor': dict({
'value': 1,
}),
'power_factor_phase_1': dict({
'value': -0.99,
}),
'power_factor_phase_2': dict({
'value': -0.99,
}),
'power_factor_phase_3': dict({
'value': 0.99,
}),
'power_reactive': dict({
'unit': 'VAr',
'value': 2.87,
}),
'power_reactive_phase_1': dict({
'unit': 'VAr',
'value': 51.48,
}),
'power_reactive_phase_2': dict({
'unit': 'VAr',
'value': 115.63,
}),
'power_reactive_phase_3': dict({
'unit': 'VAr',
'value': -164.24,
}),
'power_real': dict({
'unit': 'W',
'value': 5592.57,
}),
'power_real_phase_1': dict({
'unit': 'W',
'value': 1765.55,
}),
'power_real_phase_2': dict({
'unit': 'W',
'value': 1515.8,
}),
'power_real_phase_3': dict({
'unit': 'W',
'value': 2311.22,
}),
'serial': '**REDACTED**',
'visible': dict({
'value': 1,
}),
'voltage_ac_phase_1': dict({
'unit': 'V',
'value': 228.6,
}),
'voltage_ac_phase_2': dict({
'unit': 'V',
'value': 228.6,
}),
'voltage_ac_phase_3': dict({
'unit': 'V',
'value': 231,
}),
'voltage_ac_phase_to_phase_12': dict({
'unit': 'V',
'value': 395.9,
}),
'voltage_ac_phase_to_phase_23': dict({
'unit': 'V',
'value': 398,
}),
'voltage_ac_phase_to_phase_31': dict({
'unit': 'V',
'value': 398,
}),
}),
}),
'ohmpilot': None,
'power_flow': dict({
'power_flow': dict({
'energy_day': dict({
'unit': 'Wh',
'value': 1101.7000732421875,
}),
'energy_total': dict({
'unit': 'Wh',
'value': 44188000,
}),
'energy_year': dict({
'unit': 'Wh',
'value': 25508788,
}),
'meter_location': dict({
'value': 'grid',
}),
'meter_mode': dict({
'value': 'meter',
}),
'power_battery': dict({
'unit': 'W',
'value': None,
}),
'power_grid': dict({
'unit': 'W',
'value': 1703.74,
}),
'power_load': dict({
'unit': 'W',
'value': -2814.74,
}),
'power_photovoltaics': dict({
'unit': 'W',
'value': 1111,
}),
'relative_autonomy': dict({
'unit': '%',
'value': 39.4707859340472,
}),
'relative_self_consumption': dict({
'unit': '%',
'value': 100,
}),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'timestamp': dict({
'value': '2021-10-07T10:00:43+02:00',
}),
}),
}),
'storage': None,
}),
'inverter_info': dict({
'inverters': list([
dict({
'custom_name': dict({
'value': 'Symo 20',
}),
'device_id': dict({
'value': '1',
}),
'device_type': dict({
'manufacturer': 'Fronius',
'model': 'Symo 20.0-3-M',
'value': 121,
}),
'error_code': dict({
'value': 0,
}),
'pv_power': dict({
'unit': 'W',
'value': 23100,
}),
'show': dict({
'value': 1,
}),
'status_code': dict({
'value': 7,
}),
'unique_id': '**REDACTED**',
}),
]),
'status': dict({
'Code': 0,
'Reason': '',
'UserMessage': '',
}),
'timestamp': dict({
'value': '2021-10-07T13:41:00+02:00',
}),
}),
})
# ---

View File

@@ -0,0 +1,31 @@
"""Tests for the diagnostics data provided by the KNX integration."""
from syrupy import SnapshotAssertion
from homeassistant.core import HomeAssistant
from . import mock_responses, setup_fronius_integration
from tests.components.diagnostics import get_diagnostics_for_config_entry
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
async def test_diagnostics(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
snapshot: SnapshotAssertion,
) -> None:
"""Test diagnostics."""
mock_responses(aioclient_mock)
entry = await setup_fronius_integration(hass)
assert (
await get_diagnostics_for_config_entry(
hass,
hass_client,
entry,
)
== snapshot
)

View File

@@ -1,4 +1,114 @@
# serializer version: 1 # serializer version: 1
# name: test_chat_history
list([
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'1st user request',
),
dict({
}),
),
tuple(
'',
tuple(
),
dict({
'generation_config': dict({
'max_output_tokens': 150,
'temperature': 1.0,
'top_k': 64,
'top_p': 0.95,
}),
'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({
'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}),
'tools': None,
}),
),
tuple(
'().start_chat',
tuple(
),
dict({
'history': list([
dict({
'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user',
}),
dict({
'parts': 'Ok',
'role': 'model',
}),
dict({
'parts': '1st user request',
'role': 'user',
}),
dict({
'parts': '1st model response',
'role': 'model',
}),
]),
}),
),
tuple(
'().start_chat().send_message_async',
tuple(
'2nd user request',
),
dict({
}),
),
])
# ---
# name: test_default_prompt[config_entry_options0-None] # name: test_default_prompt[config_entry_options0-None]
list([ list([
tuple( tuple(
@@ -14,10 +124,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -29,7 +139,10 @@
dict({ dict({
'history': list([ 'history': list([
dict({ dict({
'parts': 'Answer in plain text. Keep it simple and to the point.', 'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user', 'role': 'user',
}), }),
dict({ dict({
@@ -64,10 +177,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -79,7 +192,10 @@
dict({ dict({
'history': list([ 'history': list([
dict({ dict({
'parts': 'Answer in plain text. Keep it simple and to the point.', 'parts': '''
Answer in plain text. Keep it simple and to the point.
Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant.
''',
'role': 'user', 'role': 'user',
}), }),
dict({ dict({
@@ -114,10 +230,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -130,8 +246,8 @@
'history': list([ 'history': list([
dict({ dict({
'parts': ''' 'parts': '''
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
''', ''',
'role': 'user', 'role': 'user',
}), }),
@@ -167,10 +283,10 @@
}), }),
'model_name': 'models/gemini-1.5-flash-latest', 'model_name': 'models/gemini-1.5-flash-latest',
'safety_settings': dict({ 'safety_settings': dict({
'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE',
'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE',
'HATE': 'BLOCK_LOW_AND_ABOVE', 'HATE': 'BLOCK_MEDIUM_AND_ABOVE',
'SEXUAL': 'BLOCK_LOW_AND_ABOVE', 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE',
}), }),
'tools': None, 'tools': None,
}), }),
@@ -183,8 +299,8 @@
'history': list([ 'history': list([
dict({ dict({
'parts': ''' 'parts': '''
Call the intent tools to control Home Assistant. Just pass the name to the intent.
Answer in plain text. Keep it simple and to the point. Answer in plain text. Keep it simple and to the point.
Call the intent tools to control Home Assistant. Just pass the name to the intent.
''', ''',
'role': 'user', 'role': 'user',
}), }),

View File

@@ -2,12 +2,14 @@
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from google.api_core.exceptions import ClientError from google.api_core.exceptions import GoogleAPICallError
import google.generativeai.types as genai_types
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -150,6 +152,57 @@ async def test_default_prompt(
assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options)
async def test_chat_history(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
snapshot: SnapshotAssertion,
) -> None:
"""Test that the agent keeps track of the chat history."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
chat_response = MagicMock()
mock_chat.send_message_async.return_value = chat_response
mock_part = MagicMock()
mock_part.function_call = None
chat_response.parts = [mock_part]
chat_response.text = "1st model response"
mock_chat.history = [
{"role": "user", "parts": "prompt"},
{"role": "model", "parts": "Ok"},
{"role": "user", "parts": "1st user request"},
{"role": "model", "parts": "1st model response"},
]
result = await conversation.async_converse(
hass,
"1st user request",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.as_dict()["speech"]["plain"]["speech"]
== "1st model response"
)
chat_response.text = "2nd model response"
result = await conversation.async_converse(
hass,
"2nd user request",
result.conversation_id,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert (
result.response.as_dict()["speech"]["plain"]["speech"]
== "2nd model response"
)
assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot
@patch( @patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
) )
@@ -233,6 +286,20 @@ async def test_function_call(
), ),
) )
# Test conversating tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"]
@patch( @patch(
"homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools"
@@ -325,7 +392,7 @@ async def test_error_handling(
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = ClientError("some error") mock_chat.send_message_async.side_effect = GoogleAPICallError("some error")
result = await conversation.async_converse( result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
) )
@@ -340,7 +407,28 @@ async def test_error_handling(
async def test_blocked_response( async def test_blocked_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None: ) -> None:
"""Test response was blocked.""" """Test blocked response."""
with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat
mock_chat.send_message_async.side_effect = genai_types.StopCandidateException(
"finish_reason: SAFETY\n"
)
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"The message got blocked by your safety settings"
)
async def test_empty_response(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test empty response."""
with patch("google.generativeai.GenerativeModel") as mock_model: with patch("google.generativeai.GenerativeModel") as mock_model:
mock_chat = AsyncMock() mock_chat = AsyncMock()
mock_model.return_value.start_chat.return_value = mock_chat mock_model.return_value.start_chat.return_value = mock_chat
@@ -358,6 +446,32 @@ async def test_blocked_response(
) )
async def test_invalid_llm_api(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test handling of invalid llm api."""
hass.config_entries.async_update_entry(
mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"},
)
result = await conversation.async_converse(
hass,
"hello",
None,
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
assert result.response.as_dict()["speech"]["plain"]["speech"] == (
"Error preparing LLM API: API invalid_llm_api not found"
)
async def test_template_error( async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:

View File

@@ -529,16 +529,16 @@ async def test_non_unique_triggers(
async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 2 assert len(calls) == 2
assert calls[0].data["some"] == "press1" all_calls = {calls[0].data["some"], calls[1].data["some"]}
assert calls[1].data["some"] == "press2" assert all_calls == {"press1", "press2"}
# Trigger second config references to same trigger # Trigger second config references to same trigger
# and triggers both attached instances. # and triggers both attached instances.
async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press") async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press")
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(calls) == 2 assert len(calls) == 2
assert calls[0].data["some"] == "press1" all_calls = {calls[0].data["some"], calls[1].data["some"]}
assert calls[1].data["some"] == "press2" assert all_calls == {"press1", "press2"}
# Removing the first trigger will clean up # Removing the first trigger will clean up
calls.clear() calls.clear()

View File

@@ -4,6 +4,7 @@ import asyncio
from collections.abc import Generator from collections.abc import Generator
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
import json import json
import logging import logging
import socket import socket
@@ -1050,6 +1051,27 @@ async def test_subscribe_topic_not_initialize(
await mqtt.async_subscribe(hass, "test-topic", record_calls) await mqtt.async_subscribe(hass, "test-topic", record_calls)
async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient
) -> None:
"""Test the subscription of a topic when MQTT config entry is disabled."""
mqtt_mock.connected = True
mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0]
assert mqtt_config_entry.state is ConfigEntryState.LOADED
assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id)
assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED
await hass.config_entries.async_set_disabled_by(
mqtt_config_entry.entry_id, ConfigEntryDisabler.USER
)
mqtt_mock.connected = False
with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"):
await mqtt.async_subscribe(hass, "test-topic", record_calls)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2)
async def test_subscribe_and_resubscribe( async def test_subscribe_and_resubscribe(
@@ -2912,8 +2934,8 @@ async def test_message_callback_exception_gets_logged(
await mqtt_mock_entry() await mqtt_mock_entry()
@callback @callback
def bad_handler(*args) -> None: def bad_handler(msg: ReceiveMessage) -> None:
"""Record calls.""" """Handle callback."""
raise ValueError("This is a bad message callback") raise ValueError("This is a bad message callback")
await mqtt.async_subscribe(hass, "test-topic", bad_handler) await mqtt.async_subscribe(hass, "test-topic", bad_handler)
@@ -2926,6 +2948,40 @@ async def test_message_callback_exception_gets_logged(
) )
@pytest.mark.no_fail_on_log_exception
async def test_message_partial_callback_exception_gets_logged(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
mqtt_mock_entry: MqttMockHAClientGenerator,
) -> None:
"""Test exception raised by message handler."""
await mqtt_mock_entry()
@callback
def bad_handler(msg: ReceiveMessage) -> None:
"""Handle callback."""
raise ValueError("This is a bad message callback")
def parial_handler(
msg_callback: MessageCallbackType,
attributes: set[str],
msg: ReceiveMessage,
) -> None:
"""Partial callback handler."""
msg_callback(msg)
await mqtt.async_subscribe(
hass, "test-topic", partial(parial_handler, bad_handler, {"some_attr"})
)
async_fire_mqtt_message(hass, "test-topic", "test")
await hass.async_block_till_done()
assert (
"Exception in bad_handler when handling msg on 'test-topic':"
" 'test'" in caplog.text
)
async def test_mqtt_ws_subscription( async def test_mqtt_ws_subscription(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
@@ -3787,7 +3843,7 @@ async def test_unload_config_entry(
async def test_publish_or_subscribe_without_valid_config_entry( async def test_publish_or_subscribe_without_valid_config_entry(
hass: HomeAssistant, record_calls: MessageCallbackType hass: HomeAssistant, record_calls: MessageCallbackType
) -> None: ) -> None:
"""Test internal publish function with bas use cases.""" """Test internal publish function with bad use cases."""
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await mqtt.async_publish( await mqtt.async_publish(
hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None

View File

@@ -11,8 +11,12 @@ from .const import DEFAULT_FORECAST, DEFAULT_OBSERVATION
@pytest.fixture @pytest.fixture
def mock_simple_nws(): def mock_simple_nws():
"""Mock pynws SimpleNWS with default values.""" """Mock pynws SimpleNWS with default values."""
# set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws: with (
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_nws.return_value instance = mock_nws.return_value
instance.set_station = AsyncMock(return_value=None) instance.set_station = AsyncMock(return_value=None)
instance.update_observation = AsyncMock(return_value=None) instance.update_observation = AsyncMock(return_value=None)
@@ -29,7 +33,12 @@ def mock_simple_nws():
@pytest.fixture @pytest.fixture
def mock_simple_nws_times_out(): def mock_simple_nws_times_out():
"""Mock pynws SimpleNWS that times out.""" """Mock pynws SimpleNWS that times out."""
with patch("homeassistant.components.nws.SimpleNWS") as mock_nws: # set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests
with (
patch("homeassistant.components.nws.SimpleNWS") as mock_nws,
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_nws.return_value instance = mock_nws.return_value
instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError) instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError)
instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError) instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError)

View File

@@ -1,7 +1,6 @@
"""Tests for the NWS weather component.""" """Tests for the NWS weather component."""
from datetime import timedelta from datetime import timedelta
from unittest.mock import patch
import aiohttp import aiohttp
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@@ -24,7 +23,6 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util
from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM
from .const import ( from .const import (
@@ -127,10 +125,6 @@ async def test_data_caching_error_observation(
caplog, caplog,
) -> None: ) -> None:
"""Test caching of data with errors.""" """Test caching of data with errors."""
with (
patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0),
patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0),
):
instance = mock_simple_nws.return_value instance = mock_simple_nws.return_value
entry = MockConfigEntry( entry = MockConfigEntry(
@@ -302,9 +296,6 @@ async def test_error_observation(
hass: HomeAssistant, mock_simple_nws, no_sensor hass: HomeAssistant, mock_simple_nws, no_sensor
) -> None: ) -> None:
"""Test error during update observation.""" """Test error during update observation."""
utc_time = dt_util.utcnow()
with patch("homeassistant.components.nws.coordinator.utcnow") as mock_utc:
mock_utc.return_value = utc_time
instance = mock_simple_nws.return_value instance = mock_simple_nws.return_value
# first update fails # first update fails
instance.update_observation.side_effect = aiohttp.ClientError instance.update_observation.side_effect = aiohttp.ClientError

View File

@@ -6,6 +6,7 @@ from ollama import Message, ResponseError
import pytest import pytest
from homeassistant.components import conversation, ollama from homeassistant.components import conversation, ollama
from homeassistant.components.conversation import trace
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
@@ -110,6 +111,19 @@ async def test_chat(
), result ), result
assert result.response.speech["plain"]["speech"] == "test response" assert result.response.speech["plain"]["speech"] == "test response"
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "The current time is" in detail_event["data"]["messages"][0]["content"]
async def test_message_history_trimming( async def test_message_history_trimming(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component

View File

@@ -9,9 +9,17 @@ import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.openai_conversation.const import ( from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
DEFAULT_CHAT_MODEL, CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TOP_P,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_P,
) )
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@@ -75,7 +83,7 @@ async def test_options(
assert options["type"] is FlowResultType.CREATE_ENTRY assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"]["prompt"] == "Speak like a pirate" assert options["data"]["prompt"] == "Speak like a pirate"
assert options["data"]["max_tokens"] == 200 assert options["data"]["max_tokens"] == 200
assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
assert result2["type"] is FlowResultType.FORM assert result2["type"] is FlowResultType.FORM
assert result2["errors"] == {"base": error} assert result2["errors"] == {"base": error}
@pytest.mark.parametrize(
("current_options", "new_options", "expected_options"),
[
(
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "none",
CONF_PROMPT: "bla",
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
},
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
},
),
(
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
CONF_TOP_P: RECOMMENDED_TOP_P,
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
{
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: "assist",
CONF_PROMPT: "",
},
),
],
)
async def test_options_switching(
hass: HomeAssistant,
mock_config_entry,
mock_init_component,
current_options,
new_options,
expected_options,
) -> None:
"""Test the options form."""
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
)
if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
{
**current_options,
CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
},
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
new_options,
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options

View File

@@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.components.conversation import trace
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import Context, HomeAssistant from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -200,6 +201,20 @@ async def test_function_call(
), ),
) )
# Test Conversation tracing
traces = trace.async_get_traces()
assert traces
last_trace = traces[-1].as_dict()
trace_events = last_trace.get("events", [])
assert [event["event_type"] for event in trace_events] == [
trace.ConversationTraceEventType.ASYNC_PROCESS,
trace.ConversationTraceEventType.AGENT_DETAIL,
trace.ConversationTraceEventType.LLM_TOOL_CALL,
]
# AGENT_DETAIL event contains the raw prompt passed to the model
detail_event = trace_events[1]
assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"]
@patch( @patch(
"homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools"

View File

@@ -2,6 +2,7 @@
from unittest.mock import patch from unittest.mock import patch
from freezegun import freeze_time
from pyplaato.models.airlock import PlaatoAirlock from pyplaato.models.airlock import PlaatoAirlock
from pyplaato.models.device import PlaatoDeviceType from pyplaato.models.device import PlaatoDeviceType
from pyplaato.models.keg import PlaatoKeg from pyplaato.models.keg import PlaatoKeg
@@ -23,6 +24,7 @@ AIRLOCK_DATA = {}
KEG_DATA = {} KEG_DATA = {}
@freeze_time("2024-05-24 12:00:00", tz_offset=0)
async def init_integration( async def init_integration(
hass: HomeAssistant, device_type: PlaatoDeviceType hass: HomeAssistant, device_type: PlaatoDeviceType
) -> MockConfigEntry: ) -> MockConfigEntry:

View File

@@ -492,7 +492,6 @@ async def test_block_set_mode_auth_error(
{ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT}, {ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED

View File

@@ -227,7 +227,6 @@ async def test_block_set_value_auth_error(
{ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30}, {ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED

View File

@@ -618,7 +618,6 @@ async def test_rpc_sleeping_update_entity_service(
service_data={ATTR_ENTITY_ID: entity_id}, service_data={ATTR_ENTITY_ID: entity_id},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
# Entity should be available after update_entity service call # Entity should be available after update_entity service call
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
@@ -667,7 +666,6 @@ async def test_block_sleeping_update_entity_service(
service_data={ATTR_ENTITY_ID: entity_id}, service_data={ATTR_ENTITY_ID: entity_id},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
# Entity should be available after update_entity service call # Entity should be available after update_entity service call
state = hass.states.get(entity_id) state = hass.states.get(entity_id)

View File

@@ -230,7 +230,6 @@ async def test_block_set_state_auth_error(
{ATTR_ENTITY_ID: "switch.test_name_channel_1"}, {ATTR_ENTITY_ID: "switch.test_name_channel_1"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
@@ -374,7 +373,6 @@ async def test_rpc_auth_error(
{ATTR_ENTITY_ID: "switch.test_switch_0"}, {ATTR_ENTITY_ID: "switch.test_switch_0"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED

View File

@@ -207,7 +207,6 @@ async def test_block_update_auth_error(
{ATTR_ENTITY_ID: "update.test_name_firmware_update"}, {ATTR_ENTITY_ID: "update.test_name_firmware_update"},
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
@@ -669,7 +668,6 @@ async def test_rpc_update_auth_error(
blocking=True, blocking=True,
) )
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress() flows = hass.config_entries.flow.async_progress()

View File

@@ -18,9 +18,15 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@pytest.fixture @pytest.fixture
def mock_bridge(request): def mock_bridge(request):
"""Return a mocked SwitcherBridge.""" """Return a mocked SwitcherBridge."""
with patch( with (
"homeassistant.components.switcher_kis.utils.SwitcherBridge", autospec=True patch(
) as bridge_mock: "homeassistant.components.switcher_kis.SwitcherBridge", autospec=True
) as bridge_mock,
patch(
"homeassistant.components.switcher_kis.utils.SwitcherBridge",
new=bridge_mock,
),
):
bridge = bridge_mock.return_value bridge = bridge_mock.return_value
bridge.devices = [] bridge.devices = []

View File

@@ -4,11 +4,7 @@ from datetime import timedelta
import pytest import pytest
from homeassistant.components.switcher_kis.const import ( from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC
DATA_DEVICE,
DOMAIN,
MAX_UPDATE_INTERVAL_SEC,
)
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import STATE_UNAVAILABLE from homeassistant.const import STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -24,15 +20,14 @@ async def test_update_fail(
hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test entities state unavailable when updates fail..""" """Test entities state unavailable when updates fail.."""
await init_integration(hass) entry = await init_integration(hass)
assert mock_bridge assert mock_bridge
mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES) mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mock_bridge.is_running is True assert mock_bridge.is_running is True
assert len(hass.data[DOMAIN]) == 2 assert len(entry.runtime_data) == 2
assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2
async_fire_time_changed( async_fire_time_changed(
hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1) hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1)
@@ -77,11 +72,9 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None:
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
assert mock_bridge.is_running is True assert mock_bridge.is_running is True
assert len(hass.data[DOMAIN]) == 2
await hass.config_entries.async_unload(entry.entry_id) await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
assert entry.state is ConfigEntryState.NOT_LOADED assert entry.state is ConfigEntryState.NOT_LOADED
assert mock_bridge.is_running is False assert mock_bridge.is_running is False
assert len(hass.data[DOMAIN]) == 0

Some files were not shown because too many files have changed in this diff Show More