Mirror of https://github.com/home-assistant/core.git (synced 2025-09-10 07:11:37 +02:00)
Add fuzzy matching to default agent (#150595)
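In short (as the diff below shows), this change lets the conversation default agent fall back to fuzzy n-gram intent matching, using hassil's FuzzyNgramMatcher plus the fuzzy models and shared config shipped with home_assistant_intents. A new FUZZY stage runs after exposed-entity matching and before the unexposed-entity and unknown-name fallbacks, fuzzy results are tagged with METADATA_FUZZY_MATCH, and the debug WebSocket API gains a fuzzy_match flag. The behaviour is controlled by a new DefaultAgent.fuzzy_matching attribute that defaults to True; the updated test fixtures switch it off. A minimal fragment doing the same, assuming hass is a set-up HomeAssistant instance with the conversation integration loaded (it mirrors the fixtures further down and is not standalone code):

from homeassistant.components import conversation

# Opt the default agent out of the new fuzzy matching stage,
# exactly as the updated test fixtures in this commit do.
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
agent.fuzzy_matching = False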
@@ -117,7 +117,7 @@ CONFIG_SCHEMA = vol.Schema(
{cv.string: vol.All(cv.ensure_list, [cv.string])}
)
}
)
),
},
extra=vol.ALLOW_EXTRA,
)
@@ -268,8 +268,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)
hass.data[DATA_COMPONENT] = entity_component

agent_config = config.get(DOMAIN, {})
await async_setup_default_agent(
hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
hass, entity_component, config_intents=agent_config.get("intents", {})
)

async def handle_process(service: ServiceCall) -> ServiceResponse:
@@ -15,13 +15,18 @@ import time
from typing import IO, Any, cast

from hassil.expression import Expression, Group, ListReference, TextChunk
from hassil.fuzzy import FuzzyNgramMatcher, SlotCombinationInfo
from hassil.intents import (
Intent,
IntentData,
Intents,
SlotList,
TextSlotList,
TextSlotValue,
WildcardSlotList,
)
from hassil.models import MatchEntity
from hassil.ngram import Sqlite3NgramModel
from hassil.recognize import (
MISSING_ENTITY,
RecognizeResult,
@@ -31,7 +36,15 @@ from hassil.recognize import (
from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity
from hassil.trie import Trie
from hassil.util import merge_dict
from home_assistant_intents import ErrorKey, get_intents, get_languages
from home_assistant_intents import (
ErrorKey,
FuzzyConfig,
FuzzyLanguageResponses,
get_fuzzy_config,
get_fuzzy_language,
get_intents,
get_languages,
)
import yaml

from homeassistant import core
@@ -76,6 +89,7 @@ TRIGGER_CALLBACK_TYPE = Callable[
]
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
METADATA_CUSTOM_FILE = "hass_custom_file"
METADATA_FUZZY_MATCH = "hass_fuzzy_match"

ERROR_SENTINEL = object()

@@ -94,6 +108,8 @@ class LanguageIntents:
intent_responses: dict[str, Any]
error_responses: dict[str, Any]
language_variant: str | None
fuzzy_matcher: FuzzyNgramMatcher | None = None
fuzzy_responses: FuzzyLanguageResponses | None = None


@dataclass(slots=True)
@@ -119,10 +135,13 @@ class IntentMatchingStage(Enum):
EXPOSED_ENTITIES_ONLY = auto()
"""Match against exposed entities only."""

FUZZY = auto()
"""Use fuzzy matching to guess intent."""

UNEXPOSED_ENTITIES = auto()
"""Match against unexposed entities in Home Assistant."""

FUZZY = auto()
UNKNOWN_NAMES = auto()
"""Capture names that are not known to Home Assistant."""


@@ -241,6 +260,10 @@ class DefaultAgent(ConversationEntity):
# LRU cache to avoid unnecessary intent matching
self._intent_cache = IntentCache(capacity=128)

# Shared configuration for fuzzy matching
self.fuzzy_matching = True
self._fuzzy_config: FuzzyConfig | None = None

@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
@@ -299,7 +322,7 @@ class DefaultAgent(ConversationEntity):
_LOGGER.warning("No intents were loaded for language: %s", language)
return None

slot_lists = self._make_slot_lists()
slot_lists = await self._make_slot_lists()
intent_context = self._make_intent_context(user_input)

if self._exposed_names_trie is not None:
@@ -556,6 +579,36 @@ class DefaultAgent(ConversationEntity):
# Don't try matching against all entities or doing a fuzzy match
return None

# Use fuzzy matching
skip_fuzzy_match = False
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.FUZZY
):
_LOGGER.debug("Got cached result for fuzzy match")
return cache_value.result

# Continue with matching, but we know we won't succeed for fuzzy
# match.
skip_fuzzy_match = True

if (not skip_fuzzy_match) and self.fuzzy_matching:
start_time = time.monotonic()
fuzzy_result = self._recognize_fuzzy(lang_intents, user_input)

# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(result=fuzzy_result, stage=IntentMatchingStage.FUZZY),
)

_LOGGER.debug(
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
)

if fuzzy_result is not None:
return fuzzy_result

# Try again with all entities (including unexposed)
skip_unexposed_entities_match = False
if cache_value is not None:
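The block above only attempts a fuzzy match when the per-sentence cache has no entry for the FUZZY stage, and it caches misses so the stage is not retried. A self-contained toy sketch of that stage-plus-cache pattern (illustrative names and matchers, not the Home Assistant implementation):

from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum, auto


class Stage(Enum):
    # Cheapest stage first, mirroring the new IntentMatchingStage order.
    EXPOSED_ENTITIES_ONLY = auto()
    FUZZY = auto()
    UNEXPOSED_ENTITIES = auto()
    UNKNOWN_NAMES = auto()


@dataclass
class CacheValue:
    # Last stage attempted for a sentence and its (possibly None) result.
    result: str | None
    stage: Stage


def recognize(
    text: str,
    stages: list[tuple[Stage, Callable[[str], str | None]]],
    cache: dict[str, CacheValue],
) -> str | None:
    """Run stages in order, reusing cached hits and skipping cached misses."""
    cached = cache.get(text)
    for stage, matcher in stages:
        if cached is not None and cached.stage.value >= stage.value:
            if cached.result is not None and cached.stage is stage:
                return cached.result  # cached hit at exactly this stage
            continue  # this stage was already tried earlier; skip it
        result = matcher(text)
        cache[text] = CacheValue(result=result, stage=stage)
        if result is not None:
            return result
    return None


# Toy matchers: an exact lookup, then a forgiving substring "fuzzy" pass.
COMMANDS = {"turn on the office light": "HassTurnOn"}


def strict(text: str) -> str | None:
    return COMMANDS.get(text)


def fuzzy(text: str) -> str | None:
    return next((intent for sentence, intent in COMMANDS.items() if text in sentence), None)


stages = [(Stage.EXPOSED_ENTITIES_ONLY, strict), (Stage.FUZZY, fuzzy)]
cache: dict[str, CacheValue] = {}
print(recognize("office light", stages, cache))  # HassTurnOn (fuzzy stage)
print(recognize("office light", stages, cache))  # HassTurnOn (served from cache)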
@@ -601,102 +654,160 @@ class DefaultAgent(ConversationEntity):
# This should fail the intent handling phase (async_match_targets).
return strict_result

# Try again with missing entities enabled
skip_fuzzy_match = False
# Check unknown names
skip_unknown_names = False
if cache_value is not None:
if (cache_value.result is not None) and (
cache_value.stage == IntentMatchingStage.FUZZY
cache_value.stage == IntentMatchingStage.UNKNOWN_NAMES
):
_LOGGER.debug("Got cached result for fuzzy match")
_LOGGER.debug("Got cached result for unknown names")
return cache_value.result

# We know we won't succeed for fuzzy matching.
skip_fuzzy_match = True
skip_unknown_names = True

maybe_result: RecognizeResult | None = None
if not skip_fuzzy_match:
if not skip_unknown_names:
start_time = time.monotonic()
best_num_matched_entities = 0
best_num_unmatched_entities = 0
best_num_unmatched_ranges = 0
for result in recognize_all(
user_input.text,
lang_intents.intents,
slot_lists=slot_lists,
intent_context=intent_context,
allow_unmatched_entities=True,
):
if result.text_chunks_matched < 1:
# Skip results that don't match any literal text
continue

# Don't count missing entities that couldn't be filled from context
num_matched_entities = 0
for matched_entity in result.entities_list:
if matched_entity.name not in result.unmatched_entities:
num_matched_entities += 1

num_unmatched_entities = 0
num_unmatched_ranges = 0
for unmatched_entity in result.unmatched_entities_list:
if isinstance(unmatched_entity, UnmatchedTextEntity):
if unmatched_entity.text != MISSING_ENTITY:
num_unmatched_entities += 1
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
num_unmatched_ranges += 1
num_unmatched_entities += 1
else:
num_unmatched_entities += 1

if (
(maybe_result is None) # first result
or (
# More literal text matched
result.text_chunks_matched > maybe_result.text_chunks_matched
)
or (
# More entities matched
num_matched_entities > best_num_matched_entities
)
or (
# Fewer unmatched entities
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities < best_num_unmatched_entities)
)
or (
# Prefer unmatched ranges
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges > best_num_unmatched_ranges)
)
or (
# Prefer match failures with entities
(result.text_chunks_matched == maybe_result.text_chunks_matched)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges)
and (
("name" in result.entities)
or ("name" in result.unmatched_entities)
)
)
):
maybe_result = result
best_num_matched_entities = num_matched_entities
best_num_unmatched_entities = num_unmatched_entities
best_num_unmatched_ranges = num_unmatched_ranges
maybe_result = self._recognize_unknown_names(
lang_intents, user_input, slot_lists, intent_context
)

# Update cache
self._intent_cache.put(
cache_key,
IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY),
IntentCacheValue(
result=maybe_result, stage=IntentMatchingStage.UNKNOWN_NAMES
),
)

_LOGGER.debug(
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
"Did unknown names match in %s second(s)", time.monotonic() - start_time
)

return maybe_result

def _recognize_fuzzy(
self, lang_intents: LanguageIntents, user_input: ConversationInput
) -> RecognizeResult | None:
"""Return fuzzy recognition from hassil."""
if lang_intents.fuzzy_matcher is None:
return None

fuzzy_result = lang_intents.fuzzy_matcher.match(user_input.text)
if fuzzy_result is None:
return None

response = "default"
if lang_intents.fuzzy_responses:
domain = "" # no domain
if "name" in fuzzy_result.slots:
domain = fuzzy_result.name_domain
elif "domain" in fuzzy_result.slots:
domain = fuzzy_result.slots["domain"].value

slot_combo = tuple(sorted(fuzzy_result.slots))
if (
intent_responses := lang_intents.fuzzy_responses.get(
fuzzy_result.intent_name
)
) and (combo_responses := intent_responses.get(slot_combo)):
response = combo_responses.get(domain, response)

entities = [
MatchEntity(name=slot_name, value=slot_value.value, text=slot_value.text)
for slot_name, slot_value in fuzzy_result.slots.items()
]

return RecognizeResult(
intent=Intent(name=fuzzy_result.intent_name),
intent_data=IntentData(sentence_texts=[]),
intent_metadata={METADATA_FUZZY_MATCH: True},
entities={entity.name: entity for entity in entities},
entities_list=entities,
response=response,
)

def _recognize_unknown_names(
self,
lang_intents: LanguageIntents,
user_input: ConversationInput,
slot_lists: dict[str, SlotList],
intent_context: dict[str, Any] | None,
) -> RecognizeResult | None:
"""Return result with unknown names for an error message."""
maybe_result: RecognizeResult | None = None

best_num_matched_entities = 0
best_num_unmatched_entities = 0
best_num_unmatched_ranges = 0
for result in recognize_all(
user_input.text,
lang_intents.intents,
slot_lists=slot_lists,
intent_context=intent_context,
allow_unmatched_entities=True,
):
if result.text_chunks_matched < 1:
# Skip results that don't match any literal text
continue

# Don't count missing entities that couldn't be filled from context
num_matched_entities = 0
for matched_entity in result.entities_list:
if matched_entity.name not in result.unmatched_entities:
num_matched_entities += 1

num_unmatched_entities = 0
num_unmatched_ranges = 0
for unmatched_entity in result.unmatched_entities_list:
if isinstance(unmatched_entity, UnmatchedTextEntity):
if unmatched_entity.text != MISSING_ENTITY:
num_unmatched_entities += 1
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
num_unmatched_ranges += 1
num_unmatched_entities += 1
else:
num_unmatched_entities += 1

if (
(maybe_result is None) # first result
or (
# More literal text matched
result.text_chunks_matched > maybe_result.text_chunks_matched
)
or (
# More entities matched
num_matched_entities > best_num_matched_entities
)
or (
# Fewer unmatched entities
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities < best_num_unmatched_entities)
)
or (
# Prefer unmatched ranges
(num_matched_entities == best_num_matched_entities)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges > best_num_unmatched_ranges)
)
or (
# Prefer match failures with entities
(result.text_chunks_matched == maybe_result.text_chunks_matched)
and (num_unmatched_entities == best_num_unmatched_entities)
and (num_unmatched_ranges == best_num_unmatched_ranges)
and (
("name" in result.entities)
or ("name" in result.unmatched_entities)
)
)
):
maybe_result = result
best_num_matched_entities = num_matched_entities
best_num_unmatched_entities = num_unmatched_entities
best_num_unmatched_ranges = num_unmatched_ranges

return maybe_result

def _get_unexposed_entity_names(self, text: str) -> TextSlotList:
"""Get filtered slot list with unexposed entity names in Home Assistant."""
if self._unexposed_names_trie is None:
@@ -851,7 +962,7 @@ class DefaultAgent(ConversationEntity):
if lang_intents is None:
return

self._make_slot_lists()
await self._make_slot_lists()

async def async_get_or_load_intents(self, language: str) -> LanguageIntents | None:
"""Load all intents of a language with lock."""
@@ -1002,12 +1113,85 @@ class DefaultAgent(ConversationEntity):
intent_responses = responses_dict.get("intents", {})
error_responses = responses_dict.get("errors", {})

if not self.fuzzy_matching:
_LOGGER.debug("Fuzzy matching is disabled")
return LanguageIntents(
intents,
intents_dict,
intent_responses,
error_responses,
language_variant,
)

# Load fuzzy
fuzzy_info = get_fuzzy_language(language_variant, json_load=json_load)
if fuzzy_info is None:
_LOGGER.debug(
"Fuzzy matching not available for language: %s", language_variant
)
return LanguageIntents(
intents,
intents_dict,
intent_responses,
error_responses,
language_variant,
)

if self._fuzzy_config is None:
# Load shared config
self._fuzzy_config = get_fuzzy_config(json_load=json_load)
_LOGGER.debug("Loaded shared fuzzy matching config")

assert self._fuzzy_config is not None

fuzzy_matcher: FuzzyNgramMatcher | None = None
fuzzy_responses: FuzzyLanguageResponses | None = None

start_time = time.monotonic()
fuzzy_responses = fuzzy_info.responses
fuzzy_matcher = FuzzyNgramMatcher(
intents=intents,
intent_models={
intent_name: Sqlite3NgramModel(
order=fuzzy_model.order,
words={
word: str(word_id)
for word, word_id in fuzzy_model.words.items()
},
database_path=fuzzy_model.database_path,
)
for intent_name, fuzzy_model in fuzzy_info.ngram_models.items()
},
intent_slot_list_names=self._fuzzy_config.slot_list_names,
slot_combinations={
intent_name: {
combo_key: [
SlotCombinationInfo(
name_domains=(set(name_domains) if name_domains else None)
)
]
for combo_key, name_domains in intent_combos.items()
}
for intent_name, intent_combos in self._fuzzy_config.slot_combinations.items()
},
domain_keywords=fuzzy_info.domain_keywords,
stop_words=fuzzy_info.stop_words,
)
_LOGGER.debug(
"Loaded fuzzy matcher in %s second(s): language=%s, intents=%s",
time.monotonic() - start_time,
language_variant,
sorted(fuzzy_matcher.intent_models.keys()),
)

return LanguageIntents(
intents,
intents_dict,
intent_responses,
error_responses,
language_variant,
fuzzy_matcher=fuzzy_matcher,
fuzzy_responses=fuzzy_responses,
)

@core.callback
@@ -1027,8 +1211,7 @@ class DefaultAgent(ConversationEntity):
# Slot lists have changed, so we must clear the cache
self._intent_cache.clear()

@core.callback
def _make_slot_lists(self) -> dict[str, SlotList]:
async def _make_slot_lists(self) -> dict[str, SlotList]:
"""Create slot lists with areas and entity names/aliases."""
if self._slot_lists is not None:
return self._slot_lists
@@ -1089,6 +1272,10 @@ class DefaultAgent(ConversationEntity):
"floor": TextSlotList.from_tuples(floor_names, allow_template=False),
}

# Reload fuzzy matchers with new slot lists
if self.fuzzy_matching:
await self.hass.async_add_executor_job(self._load_fuzzy_matchers)

self._listen_clear_slot_list()

_LOGGER.debug(
@@ -1098,6 +1285,25 @@ class DefaultAgent(ConversationEntity):

return self._slot_lists

def _load_fuzzy_matchers(self) -> None:
"""Reload fuzzy matchers for all loaded languages."""
for lang_intents in self._lang_intents.values():
if (not isinstance(lang_intents, LanguageIntents)) or (
lang_intents.fuzzy_matcher is None
):
continue

lang_matcher = lang_intents.fuzzy_matcher
lang_intents.fuzzy_matcher = FuzzyNgramMatcher(
intents=lang_matcher.intents,
intent_models=lang_matcher.intent_models,
intent_slot_list_names=lang_matcher.intent_slot_list_names,
slot_combinations=lang_matcher.slot_combinations,
domain_keywords=lang_matcher.domain_keywords,
stop_words=lang_matcher.stop_words,
slot_lists=self._slot_lists,
)

def _make_intent_context(
self, user_input: ConversationInput
) -> dict[str, Any] | None:
@@ -1521,10 +1727,8 @@ def _get_match_error_response(
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
"""Collect list reference names recursively."""
if isinstance(expression, Group):
grp: Group = expression
for item in grp.items:
for item in expression.items:
_collect_list_references(item, list_names)
elif isinstance(expression, ListReference):
# {list}
list_ref: ListReference = expression
list_names.add(list_ref.slot_name)
list_names.add(expression.slot_name)
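Before the HTTP/WebSocket changes below, one detail of _recognize_fuzzy above is worth spelling out: lang_intents.fuzzy_responses maps intent name to slot combination to name domain to a response key, and the code falls back to "default" whenever a more specific key is missing. A self-contained toy version of that lookup (the FUZZY_RESPONSES data is invented for illustration and is not the shipped format):

FUZZY_RESPONSES = {
    "HassTurnOn": {("name",): {"light": "name_light", "": "default"}},
}


def pick_response(intent_name: str, slots: dict[str, str], name_domain: str = "") -> str:
    """Pick the most specific response key, falling back to "default"."""
    response = "default"
    domain = ""  # no domain
    if "name" in slots:
        domain = name_domain
    elif "domain" in slots:
        domain = slots["domain"]

    slot_combo = tuple(sorted(slots))
    if (intent_responses := FUZZY_RESPONSES.get(intent_name)) and (
        combo_responses := intent_responses.get(slot_combo)
    ):
        response = combo_responses.get(domain, response)
    return response


print(pick_response("HassTurnOn", {"name": "office light"}, name_domain="light"))
# -> "name_light"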
|
||||
|
@@ -26,7 +26,11 @@ from .agent_manager import (
|
||||
get_agent_manager,
|
||||
)
|
||||
from .const import DATA_COMPONENT, DATA_DEFAULT_ENTITY
|
||||
from .default_agent import METADATA_CUSTOM_FILE, METADATA_CUSTOM_SENTENCE
|
||||
from .default_agent import (
|
||||
METADATA_CUSTOM_FILE,
|
||||
METADATA_CUSTOM_SENTENCE,
|
||||
METADATA_FUZZY_MATCH,
|
||||
)
|
||||
from .entity import ConversationEntity
|
||||
from .models import ConversationInput
|
||||
|
||||
@@ -240,6 +244,8 @@ async def websocket_hass_agent_debug(
|
||||
"sentence_template": "",
|
||||
# When match is incomplete, this will contain the best slot guesses
|
||||
"unmatched_slots": _get_unmatched_slots(intent_result),
|
||||
# True if match was not exact
|
||||
"fuzzy_match": False,
|
||||
}
|
||||
|
||||
if successful_match:
|
||||
@@ -251,16 +257,19 @@ async def websocket_hass_agent_debug(
|
||||
if intent_result.intent_sentence is not None:
|
||||
result_dict["sentence_template"] = intent_result.intent_sentence.text
|
||||
|
||||
# Inspect metadata to determine if this matched a custom sentence
|
||||
if intent_result.intent_metadata and intent_result.intent_metadata.get(
|
||||
METADATA_CUSTOM_SENTENCE
|
||||
):
|
||||
result_dict["source"] = "custom"
|
||||
result_dict["file"] = intent_result.intent_metadata.get(
|
||||
METADATA_CUSTOM_FILE
|
||||
if intent_result.intent_metadata:
|
||||
# Inspect metadata to determine if this matched a custom sentence
|
||||
if intent_result.intent_metadata.get(METADATA_CUSTOM_SENTENCE):
|
||||
result_dict["source"] = "custom"
|
||||
result_dict["file"] = intent_result.intent_metadata.get(
|
||||
METADATA_CUSTOM_FILE
|
||||
)
|
||||
else:
|
||||
result_dict["source"] = "builtin"
|
||||
|
||||
result_dict["fuzzy_match"] = intent_result.intent_metadata.get(
|
||||
METADATA_FUZZY_MATCH, False
|
||||
)
|
||||
else:
|
||||
result_dict["source"] = "builtin"
|
||||
|
||||
result_dicts.append(result_dict)
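The handler above now also reports whether a result came from the fuzzy stage. A hypothetical, abridged entry from its results list (keys taken from the handler; the values are illustrative, not captured output):

result_dict = {
    "intent": {"name": "HassTurnOn"},
    "slots": {"name": "office light"},
    "source": "builtin",  # fuzzy results still report the built-in source
    "sentence_template": "",  # fuzzy results carry no sentence template
    "unmatched_slots": {},
    "fuzzy_match": True,  # True when METADATA_FUZZY_MATCH was set on the match
}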

@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, patch

import pytest

from homeassistant.components import stt, tts, wake_word
from homeassistant.components import conversation, stt, tts, wake_word
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
from homeassistant.components.assist_pipeline.const import (
BYTES_PER_CHUNK,
@@ -295,6 +295,11 @@ async def init_supporting_components(
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "conversation", {"conversation": {}})

# Disable fuzzy matching by default for tests
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
agent.fuzzy_matching = False

config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)

@@ -73,4 +73,8 @@ async def sl_setup(hass: HomeAssistant):
async def init_components(hass: HomeAssistant):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "conversation", {conversation.DOMAIN: {}})

# Disable fuzzy matching by default for tests
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
agent.fuzzy_matching = False

@@ -464,6 +464,7 @@
'value': 'my cool light',
}),
}),
'fuzzy_match': False,
'intent': dict({
'name': 'HassTurnOn',
}),
@@ -472,7 +473,6 @@
'slots': dict({
'name': 'my cool light',
}),
'source': 'builtin',
'targets': dict({
'light.kitchen': dict({
'matched': True,
@@ -489,6 +489,7 @@
'value': 'my cool light',
}),
}),
'fuzzy_match': False,
'intent': dict({
'name': 'HassTurnOff',
}),
@@ -497,7 +498,6 @@
'slots': dict({
'name': 'my cool light',
}),
'source': 'builtin',
'targets': dict({
'light.kitchen': dict({
'matched': True,
@@ -519,6 +519,7 @@
'value': 'light',
}),
}),
'fuzzy_match': False,
'intent': dict({
'name': 'HassTurnOn',
}),
@@ -528,7 +529,6 @@
'area': 'kitchen',
'domain': 'light',
}),
'source': 'builtin',
'targets': dict({
'light.kitchen': dict({
'matched': True,
@@ -555,6 +555,7 @@
'value': 'on',
}),
}),
'fuzzy_match': False,
'intent': dict({
'name': 'HassGetState',
}),
@@ -565,7 +566,6 @@
'domain': 'lights',
'state': 'on',
}),
'source': 'builtin',
'targets': dict({
'light.kitchen': dict({
'matched': False,
@@ -590,6 +590,7 @@
}),
}),
'file': 'en/beer.yaml',
'fuzzy_match': False,
'intent': dict({
'name': 'OrderBeer',
}),
@@ -630,6 +631,7 @@
'value': 'test light',
}),
}),
'fuzzy_match': False,
'intent': dict({
'name': 'HassLightSet',
}),
@@ -639,7 +641,6 @@
'brightness': '100',
'name': 'test light',
}),
'source': 'builtin',
'targets': dict({
'light.demo_1234': dict({
'matched': True,
@@ -662,6 +663,7 @@
'value': 'test light',
}),
}),
'fuzzy_match': False,
'intent': dict({
'name': 'HassLightSet',
}),
@@ -670,7 +672,6 @@
'slots': dict({
'name': 'test light',
}),
'source': 'builtin',
'targets': dict({
}),
'unmatched_slots': dict({

@@ -25,7 +25,12 @@ from homeassistant.components.intent import (
TimerInfo,
async_register_timer_handler,
)
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.components.light import (
ATTR_SUPPORTED_COLOR_MODES,
DOMAIN as LIGHT_DOMAIN,
ColorMode,
intent as light_intent,
)
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_FRIENDLY_NAME,
@@ -81,6 +86,10 @@ async def init_components(hass: HomeAssistant) -> None:
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})

# Disable fuzzy matching by default for tests
agent = hass.data[DATA_DEFAULT_ENTITY]
agent.fuzzy_matching = False


@pytest.mark.parametrize(
"er_kwargs",
@@ -3287,3 +3296,97 @@ async def test_language_with_alternative_code(
assert call.domain == LIGHT_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": [entity_id]}


@pytest.mark.parametrize("fuzzy_matching", [True, False])
@pytest.mark.parametrize(
("sentence", "intent_type", "slots"),
[
("time", "HassGetCurrentTime", {}),
("how about my timers", "HassTimerStatus", {}),
(
"the office needs more blue",
"HassLightSet",
{"area": "office", "color": "blue"},
),
(
"50% office light",
"HassLightSet",
{"name": "office light", "brightness": "50%"},
),
],
)
async def test_fuzzy_matching(
hass: HomeAssistant,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
fuzzy_matching: bool,
sentence: str,
intent_type: str,
slots: dict[str, Any],
) -> None:
"""Test fuzzy vs. non-fuzzy matching on some English sentences."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "conversation", {})
assert await async_setup_component(hass, "intent", {})
await light_intent.async_setup_intents(hass)

agent = hass.data[DATA_DEFAULT_ENTITY]
agent.fuzzy_matching = fuzzy_matching

area_office = area_registry.async_get_or_create("office_id")
area_office = area_registry.async_update(area_office.id, name="office")

entry = MockConfigEntry()
entry.add_to_hass(hass)
office_satellite = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "id-1234")},
)
device_registry.async_update_device(office_satellite.id, area_id=area_office.id)

office_light = entity_registry.async_get_or_create("light", "demo", "1234")
office_light = entity_registry.async_update_entity(
office_light.entity_id, area_id=area_office.id
)
hass.states.async_set(
office_light.entity_id,
"on",
attributes={
ATTR_FRIENDLY_NAME: "office light",
ATTR_SUPPORTED_COLOR_MODES: [ColorMode.BRIGHTNESS, ColorMode.RGB],
},
)
_on_calls = async_mock_service(hass, LIGHT_DOMAIN, "turn_on")

result = await conversation.async_converse(
hass,
sentence,
None,
Context(),
language="en",
device_id=office_satellite.id,
)
response = result.response

if not fuzzy_matching:
# Should not match
assert response.response_type == intent.IntentResponseType.ERROR
return

assert response.response_type in (
intent.IntentResponseType.ACTION_DONE,
intent.IntentResponseType.QUERY_ANSWER,
)
assert response.intent is not None
assert response.intent.intent_type == intent_type

# Verify slot texts match
actual_slots = {
slot_name: slot_value["text"]
for slot_name, slot_value in response.intent.slots.items()
if slot_name != "preferred_area_id" # context area
}
assert actual_slots == slots