diff --git a/homeassistant/components/conversation/http.py b/homeassistant/components/conversation/http.py index 290e3aab955..f3370185a7a 100644 --- a/homeassistant/components/conversation/http.py +++ b/homeassistant/components/conversation/http.py @@ -2,15 +2,19 @@ from __future__ import annotations +from collections import defaultdict from collections.abc import Iterable from dataclasses import asdict +from pathlib import Path from typing import Any from aiohttp import web from hassil.recognize import MISSING_ENTITY, RecognizeResult from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity +from hassil.util import merge_dict from home_assistant_intents import get_language_scores import voluptuous as vol +from yaml import safe_load from homeassistant.components import http, websocket_api from homeassistant.components.http.data_validator import RequestDataValidator @@ -45,6 +49,7 @@ def async_setup(hass: HomeAssistant) -> None: websocket_api.async_register_command(hass, websocket_list_sentences) websocket_api.async_register_command(hass, websocket_hass_agent_debug) websocket_api.async_register_command(hass, websocket_hass_agent_language_scores) + websocket_api.async_register_command(hass, websocket_hass_agent_custom_sentences) @websocket_api.websocket_command( @@ -378,6 +383,46 @@ async def websocket_hass_agent_language_scores( connection.send_result(msg["id"], result) +@websocket_api.websocket_command( + { + vol.Required("type"): "conversation/agent/homeassistant/custom_sentences", + vol.Optional("language"): str, + vol.Optional("country"): str, + } +) +@websocket_api.async_response +async def websocket_hass_agent_custom_sentences( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Get user-defined custom sentences.""" + custom_sentences_dir = Path(hass.config.path("custom_sentences")) + language = msg.get("language") + country = msg.get("country") + + def load_custom_sentences(): + lang_dirs = [d for d in custom_sentences_dir.iterdir() if d.is_dir()] + custom_langs = [d.name for d in lang_dirs] + + if language: + lang_dirs = [ + custom_sentences_dir / lang_match + for lang_match in language_util.matches(language, custom_langs, country) + ] + + lang_intents = defaultdict(dict) + for lang_dir in lang_dirs: + for sentence_path in lang_dir.glob("*.yaml"): + with open(sentence_path, encoding="utf-8") as sentence_file: + merge_dict(lang_intents[lang_dir.name], safe_load(sentence_file)) + + return lang_intents + + result = await hass.async_add_executor_job(load_custom_sentences) + connection.send_result(msg["id"], result) + + class ConversationProcessView(http.HomeAssistantView): """View to process text.""" diff --git a/tests/components/conversation/snapshots/test_http.ambr b/tests/components/conversation/snapshots/test_http.ambr index 8b8ed6fa71c..597d348a899 100644 --- a/tests/components/conversation/snapshots/test_http.ambr +++ b/tests/components/conversation/snapshots/test_http.ambr @@ -453,6 +453,194 @@ }), }) # --- +# name: test_ws_hass_agent_custom_sentences + dict({ + 'en': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + "[I'd like to ]order a {beer_style} [please]", + ]), + }), + ]), + }), + 'OrderFood': dict({ + 'data': list([ + dict({ + 'sentences': list([ + "[I'd like to ]order {food_name:name} [please]", + ]), + }), + ]), + }), + }), + 'language': 'en', + 'lists': dict({ + 'beer_style': dict({ + 'values': list([ + 'stout', + 'lager', + ]), + }), + 'food_name': dict({ + 'wildcard': True, + }), + }), + }), + 'en-GB': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + 'lager please', + ]), + }), + ]), + }), + }), + 'language': 'en', + }), + 'nl': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + 'biertje', + ]), + }), + ]), + }), + }), + 'language': 'nl', + }), + }) +# --- +# name: test_ws_hass_agent_custom_sentences.1 + dict({ + 'nl': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + 'biertje', + ]), + }), + ]), + }), + }), + 'language': 'nl', + }), + }) +# --- +# name: test_ws_hass_agent_custom_sentences.2 + dict({ + 'en': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + "[I'd like to ]order a {beer_style} [please]", + ]), + }), + ]), + }), + 'OrderFood': dict({ + 'data': list([ + dict({ + 'sentences': list([ + "[I'd like to ]order {food_name:name} [please]", + ]), + }), + ]), + }), + }), + 'language': 'en', + 'lists': dict({ + 'beer_style': dict({ + 'values': list([ + 'stout', + 'lager', + ]), + }), + 'food_name': dict({ + 'wildcard': True, + }), + }), + }), + 'en-GB': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + 'lager please', + ]), + }), + ]), + }), + }), + 'language': 'en', + }), + }) +# --- +# name: test_ws_hass_agent_custom_sentences.3 + dict({ + 'en': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + "[I'd like to ]order a {beer_style} [please]", + ]), + }), + ]), + }), + 'OrderFood': dict({ + 'data': list([ + dict({ + 'sentences': list([ + "[I'd like to ]order {food_name:name} [please]", + ]), + }), + ]), + }), + }), + 'language': 'en', + 'lists': dict({ + 'beer_style': dict({ + 'values': list([ + 'stout', + 'lager', + ]), + }), + 'food_name': dict({ + 'wildcard': True, + }), + }), + }), + 'en-GB': dict({ + 'intents': dict({ + 'OrderBeer': dict({ + 'data': list([ + dict({ + 'sentences': list([ + 'lager please', + ]), + }), + ]), + }), + }), + 'language': 'en', + }), + }) +# --- # name: test_ws_hass_agent_debug dict({ 'results': list([ diff --git a/tests/components/conversation/test_http.py b/tests/components/conversation/test_http.py index 29cd567e904..fac58f07bd2 100644 --- a/tests/components/conversation/test_http.py +++ b/tests/components/conversation/test_http.py @@ -1,11 +1,15 @@ """The tests for the HTTP API of the Conversation component.""" +from collections import defaultdict from http import HTTPStatus +from pathlib import Path from typing import Any from unittest.mock import patch +from hassil.util import merge_dict import pytest from syrupy.assertion import SnapshotAssertion +from yaml import safe_load from homeassistant.components.conversation import default_agent from homeassistant.components.conversation.const import ( @@ -594,3 +598,94 @@ async def test_ws_hass_language_scores_with_filter( # GB English should be preferred result = msg["result"] assert result["preferred_language"] == "en-GB" + + +async def test_ws_hass_agent_custom_sentences( + hass: HomeAssistant, + init_components, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test homeassistant agent websocket command to get custom sentences.""" + + # Expecting in testing_config/custom_sentences: + # - /en/beer.yaml + # - /en-GB/beer.yaml + # - /nl/beer.yaml + expected_intents = await hass.async_add_executor_job(_load_custom_sentences, hass) + + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "conversation/agent/homeassistant/custom_sentences", + } + ) + msg = await client.receive_json() + + assert msg["success"] + assert msg["result"] == snapshot + + # All languages should be loaded + custom_sentences = msg["result"] + assert custom_sentences.keys() == {"en", "en-GB", "nl"} + + # Each language contains the merged YAML as a dict + for lang, actual_intents in custom_sentences.items(): + assert lang in expected_intents + assert actual_intents == expected_intents[lang] + + # Only Dutch + await client.send_json_auto_id( + {"type": "conversation/agent/homeassistant/custom_sentences", "language": "nl"} + ) + msg = await client.receive_json() + + assert msg["success"] + assert msg["result"] == snapshot + custom_sentences = msg["result"] + assert custom_sentences.keys() == {"nl"} + + # British English is first + await client.send_json_auto_id( + { + "type": "conversation/agent/homeassistant/custom_sentences", + "language": "en", + "country": "GB", + } + ) + msg = await client.receive_json() + + assert msg["success"] + assert msg["result"] == snapshot + custom_sentences = msg["result"] + assert list(custom_sentences.keys()) == ["en-GB", "en"] + + # General English is first + await client.send_json_auto_id( + { + "type": "conversation/agent/homeassistant/custom_sentences", + "language": "en", + } + ) + msg = await client.receive_json() + + assert msg["success"] + assert msg["result"] == snapshot + custom_sentences = msg["result"] + assert list(custom_sentences.keys()) == ["en", "en-GB"] + + +def _load_custom_sentences(hass: HomeAssistant) -> dict[str, dict[str, Any]]: + """Loads custom sentences from testing_config/custom_sentences.""" + custom_sentences = defaultdict(dict) + custom_sentences_dir = Path(hass.config.path("custom_sentences")) + for lang_dir in custom_sentences_dir.iterdir(): + if not lang_dir.is_dir(): + continue + + for yaml_path in lang_dir.glob("*.yaml"): + with open(yaml_path, encoding="utf-8") as yaml_file: + merge_dict(custom_sentences[lang_dir.name], safe_load(yaml_file)) + + return custom_sentences