mirror of
https://github.com/home-assistant/core.git
synced 2025-09-10 07:11:37 +02:00
Add fuzzy matching to default agent (#150595)
This commit is contained in:
@@ -117,7 +117,7 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
{cv.string: vol.All(cv.ensure_list, [cv.string])}
|
{cv.string: vol.All(cv.ensure_list, [cv.string])}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
),
|
||||||
},
|
},
|
||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
)
|
)
|
||||||
@@ -268,8 +268,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)
|
entity_component = EntityComponent[ConversationEntity](_LOGGER, DOMAIN, hass)
|
||||||
hass.data[DATA_COMPONENT] = entity_component
|
hass.data[DATA_COMPONENT] = entity_component
|
||||||
|
|
||||||
|
agent_config = config.get(DOMAIN, {})
|
||||||
await async_setup_default_agent(
|
await async_setup_default_agent(
|
||||||
hass, entity_component, config.get(DOMAIN, {}).get("intents", {})
|
hass, entity_component, config_intents=agent_config.get("intents", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_process(service: ServiceCall) -> ServiceResponse:
|
async def handle_process(service: ServiceCall) -> ServiceResponse:
|
||||||
|
@@ -15,13 +15,18 @@ import time
|
|||||||
from typing import IO, Any, cast
|
from typing import IO, Any, cast
|
||||||
|
|
||||||
from hassil.expression import Expression, Group, ListReference, TextChunk
|
from hassil.expression import Expression, Group, ListReference, TextChunk
|
||||||
|
from hassil.fuzzy import FuzzyNgramMatcher, SlotCombinationInfo
|
||||||
from hassil.intents import (
|
from hassil.intents import (
|
||||||
|
Intent,
|
||||||
|
IntentData,
|
||||||
Intents,
|
Intents,
|
||||||
SlotList,
|
SlotList,
|
||||||
TextSlotList,
|
TextSlotList,
|
||||||
TextSlotValue,
|
TextSlotValue,
|
||||||
WildcardSlotList,
|
WildcardSlotList,
|
||||||
)
|
)
|
||||||
|
from hassil.models import MatchEntity
|
||||||
|
from hassil.ngram import Sqlite3NgramModel
|
||||||
from hassil.recognize import (
|
from hassil.recognize import (
|
||||||
MISSING_ENTITY,
|
MISSING_ENTITY,
|
||||||
RecognizeResult,
|
RecognizeResult,
|
||||||
@@ -31,7 +36,15 @@ from hassil.recognize import (
|
|||||||
from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity
|
from hassil.string_matcher import UnmatchedRangeEntity, UnmatchedTextEntity
|
||||||
from hassil.trie import Trie
|
from hassil.trie import Trie
|
||||||
from hassil.util import merge_dict
|
from hassil.util import merge_dict
|
||||||
from home_assistant_intents import ErrorKey, get_intents, get_languages
|
from home_assistant_intents import (
|
||||||
|
ErrorKey,
|
||||||
|
FuzzyConfig,
|
||||||
|
FuzzyLanguageResponses,
|
||||||
|
get_fuzzy_config,
|
||||||
|
get_fuzzy_language,
|
||||||
|
get_intents,
|
||||||
|
get_languages,
|
||||||
|
)
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from homeassistant import core
|
from homeassistant import core
|
||||||
@@ -76,6 +89,7 @@ TRIGGER_CALLBACK_TYPE = Callable[
|
|||||||
]
|
]
|
||||||
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
|
METADATA_CUSTOM_SENTENCE = "hass_custom_sentence"
|
||||||
METADATA_CUSTOM_FILE = "hass_custom_file"
|
METADATA_CUSTOM_FILE = "hass_custom_file"
|
||||||
|
METADATA_FUZZY_MATCH = "hass_fuzzy_match"
|
||||||
|
|
||||||
ERROR_SENTINEL = object()
|
ERROR_SENTINEL = object()
|
||||||
|
|
||||||
@@ -94,6 +108,8 @@ class LanguageIntents:
|
|||||||
intent_responses: dict[str, Any]
|
intent_responses: dict[str, Any]
|
||||||
error_responses: dict[str, Any]
|
error_responses: dict[str, Any]
|
||||||
language_variant: str | None
|
language_variant: str | None
|
||||||
|
fuzzy_matcher: FuzzyNgramMatcher | None = None
|
||||||
|
fuzzy_responses: FuzzyLanguageResponses | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -119,10 +135,13 @@ class IntentMatchingStage(Enum):
|
|||||||
EXPOSED_ENTITIES_ONLY = auto()
|
EXPOSED_ENTITIES_ONLY = auto()
|
||||||
"""Match against exposed entities only."""
|
"""Match against exposed entities only."""
|
||||||
|
|
||||||
|
FUZZY = auto()
|
||||||
|
"""Use fuzzy matching to guess intent."""
|
||||||
|
|
||||||
UNEXPOSED_ENTITIES = auto()
|
UNEXPOSED_ENTITIES = auto()
|
||||||
"""Match against unexposed entities in Home Assistant."""
|
"""Match against unexposed entities in Home Assistant."""
|
||||||
|
|
||||||
FUZZY = auto()
|
UNKNOWN_NAMES = auto()
|
||||||
"""Capture names that are not known to Home Assistant."""
|
"""Capture names that are not known to Home Assistant."""
|
||||||
|
|
||||||
|
|
||||||
@@ -241,6 +260,10 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# LRU cache to avoid unnecessary intent matching
|
# LRU cache to avoid unnecessary intent matching
|
||||||
self._intent_cache = IntentCache(capacity=128)
|
self._intent_cache = IntentCache(capacity=128)
|
||||||
|
|
||||||
|
# Shared configuration for fuzzy matching
|
||||||
|
self.fuzzy_matching = True
|
||||||
|
self._fuzzy_config: FuzzyConfig | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
@@ -299,7 +322,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
_LOGGER.warning("No intents were loaded for language: %s", language)
|
_LOGGER.warning("No intents were loaded for language: %s", language)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
slot_lists = self._make_slot_lists()
|
slot_lists = await self._make_slot_lists()
|
||||||
intent_context = self._make_intent_context(user_input)
|
intent_context = self._make_intent_context(user_input)
|
||||||
|
|
||||||
if self._exposed_names_trie is not None:
|
if self._exposed_names_trie is not None:
|
||||||
@@ -556,6 +579,36 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# Don't try matching against all entities or doing a fuzzy match
|
# Don't try matching against all entities or doing a fuzzy match
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Use fuzzy matching
|
||||||
|
skip_fuzzy_match = False
|
||||||
|
if cache_value is not None:
|
||||||
|
if (cache_value.result is not None) and (
|
||||||
|
cache_value.stage == IntentMatchingStage.FUZZY
|
||||||
|
):
|
||||||
|
_LOGGER.debug("Got cached result for fuzzy match")
|
||||||
|
return cache_value.result
|
||||||
|
|
||||||
|
# Continue with matching, but we know we won't succeed for fuzzy
|
||||||
|
# match.
|
||||||
|
skip_fuzzy_match = True
|
||||||
|
|
||||||
|
if (not skip_fuzzy_match) and self.fuzzy_matching:
|
||||||
|
start_time = time.monotonic()
|
||||||
|
fuzzy_result = self._recognize_fuzzy(lang_intents, user_input)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._intent_cache.put(
|
||||||
|
cache_key,
|
||||||
|
IntentCacheValue(result=fuzzy_result, stage=IntentMatchingStage.FUZZY),
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
if fuzzy_result is not None:
|
||||||
|
return fuzzy_result
|
||||||
|
|
||||||
# Try again with all entities (including unexposed)
|
# Try again with all entities (including unexposed)
|
||||||
skip_unexposed_entities_match = False
|
skip_unexposed_entities_match = False
|
||||||
if cache_value is not None:
|
if cache_value is not None:
|
||||||
@@ -601,102 +654,160 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# This should fail the intent handling phase (async_match_targets).
|
# This should fail the intent handling phase (async_match_targets).
|
||||||
return strict_result
|
return strict_result
|
||||||
|
|
||||||
# Try again with missing entities enabled
|
# Check unknown names
|
||||||
skip_fuzzy_match = False
|
skip_unknown_names = False
|
||||||
if cache_value is not None:
|
if cache_value is not None:
|
||||||
if (cache_value.result is not None) and (
|
if (cache_value.result is not None) and (
|
||||||
cache_value.stage == IntentMatchingStage.FUZZY
|
cache_value.stage == IntentMatchingStage.UNKNOWN_NAMES
|
||||||
):
|
):
|
||||||
_LOGGER.debug("Got cached result for fuzzy match")
|
_LOGGER.debug("Got cached result for unknown names")
|
||||||
return cache_value.result
|
return cache_value.result
|
||||||
|
|
||||||
# We know we won't succeed for fuzzy matching.
|
skip_unknown_names = True
|
||||||
skip_fuzzy_match = True
|
|
||||||
|
|
||||||
maybe_result: RecognizeResult | None = None
|
maybe_result: RecognizeResult | None = None
|
||||||
if not skip_fuzzy_match:
|
if not skip_unknown_names:
|
||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
best_num_matched_entities = 0
|
maybe_result = self._recognize_unknown_names(
|
||||||
best_num_unmatched_entities = 0
|
lang_intents, user_input, slot_lists, intent_context
|
||||||
best_num_unmatched_ranges = 0
|
)
|
||||||
for result in recognize_all(
|
|
||||||
user_input.text,
|
|
||||||
lang_intents.intents,
|
|
||||||
slot_lists=slot_lists,
|
|
||||||
intent_context=intent_context,
|
|
||||||
allow_unmatched_entities=True,
|
|
||||||
):
|
|
||||||
if result.text_chunks_matched < 1:
|
|
||||||
# Skip results that don't match any literal text
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Don't count missing entities that couldn't be filled from context
|
|
||||||
num_matched_entities = 0
|
|
||||||
for matched_entity in result.entities_list:
|
|
||||||
if matched_entity.name not in result.unmatched_entities:
|
|
||||||
num_matched_entities += 1
|
|
||||||
|
|
||||||
num_unmatched_entities = 0
|
|
||||||
num_unmatched_ranges = 0
|
|
||||||
for unmatched_entity in result.unmatched_entities_list:
|
|
||||||
if isinstance(unmatched_entity, UnmatchedTextEntity):
|
|
||||||
if unmatched_entity.text != MISSING_ENTITY:
|
|
||||||
num_unmatched_entities += 1
|
|
||||||
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
|
|
||||||
num_unmatched_ranges += 1
|
|
||||||
num_unmatched_entities += 1
|
|
||||||
else:
|
|
||||||
num_unmatched_entities += 1
|
|
||||||
|
|
||||||
if (
|
|
||||||
(maybe_result is None) # first result
|
|
||||||
or (
|
|
||||||
# More literal text matched
|
|
||||||
result.text_chunks_matched > maybe_result.text_chunks_matched
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# More entities matched
|
|
||||||
num_matched_entities > best_num_matched_entities
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# Fewer unmatched entities
|
|
||||||
(num_matched_entities == best_num_matched_entities)
|
|
||||||
and (num_unmatched_entities < best_num_unmatched_entities)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# Prefer unmatched ranges
|
|
||||||
(num_matched_entities == best_num_matched_entities)
|
|
||||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
|
||||||
and (num_unmatched_ranges > best_num_unmatched_ranges)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
# Prefer match failures with entities
|
|
||||||
(result.text_chunks_matched == maybe_result.text_chunks_matched)
|
|
||||||
and (num_unmatched_entities == best_num_unmatched_entities)
|
|
||||||
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
|
||||||
and (
|
|
||||||
("name" in result.entities)
|
|
||||||
or ("name" in result.unmatched_entities)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
maybe_result = result
|
|
||||||
best_num_matched_entities = num_matched_entities
|
|
||||||
best_num_unmatched_entities = num_unmatched_entities
|
|
||||||
best_num_unmatched_ranges = num_unmatched_ranges
|
|
||||||
|
|
||||||
# Update cache
|
# Update cache
|
||||||
self._intent_cache.put(
|
self._intent_cache.put(
|
||||||
cache_key,
|
cache_key,
|
||||||
IntentCacheValue(result=maybe_result, stage=IntentMatchingStage.FUZZY),
|
IntentCacheValue(
|
||||||
|
result=maybe_result, stage=IntentMatchingStage.UNKNOWN_NAMES
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Did fuzzy match in %s second(s)", time.monotonic() - start_time
|
"Did unknown names match in %s second(s)", time.monotonic() - start_time
|
||||||
)
|
)
|
||||||
|
|
||||||
return maybe_result
|
return maybe_result
|
||||||
|
|
||||||
|
def _recognize_fuzzy(
|
||||||
|
self, lang_intents: LanguageIntents, user_input: ConversationInput
|
||||||
|
) -> RecognizeResult | None:
|
||||||
|
"""Return fuzzy recognition from hassil."""
|
||||||
|
if lang_intents.fuzzy_matcher is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fuzzy_result = lang_intents.fuzzy_matcher.match(user_input.text)
|
||||||
|
if fuzzy_result is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
response = "default"
|
||||||
|
if lang_intents.fuzzy_responses:
|
||||||
|
domain = "" # no domain
|
||||||
|
if "name" in fuzzy_result.slots:
|
||||||
|
domain = fuzzy_result.name_domain
|
||||||
|
elif "domain" in fuzzy_result.slots:
|
||||||
|
domain = fuzzy_result.slots["domain"].value
|
||||||
|
|
||||||
|
slot_combo = tuple(sorted(fuzzy_result.slots))
|
||||||
|
if (
|
||||||
|
intent_responses := lang_intents.fuzzy_responses.get(
|
||||||
|
fuzzy_result.intent_name
|
||||||
|
)
|
||||||
|
) and (combo_responses := intent_responses.get(slot_combo)):
|
||||||
|
response = combo_responses.get(domain, response)
|
||||||
|
|
||||||
|
entities = [
|
||||||
|
MatchEntity(name=slot_name, value=slot_value.value, text=slot_value.text)
|
||||||
|
for slot_name, slot_value in fuzzy_result.slots.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
return RecognizeResult(
|
||||||
|
intent=Intent(name=fuzzy_result.intent_name),
|
||||||
|
intent_data=IntentData(sentence_texts=[]),
|
||||||
|
intent_metadata={METADATA_FUZZY_MATCH: True},
|
||||||
|
entities={entity.name: entity for entity in entities},
|
||||||
|
entities_list=entities,
|
||||||
|
response=response,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _recognize_unknown_names(
|
||||||
|
self,
|
||||||
|
lang_intents: LanguageIntents,
|
||||||
|
user_input: ConversationInput,
|
||||||
|
slot_lists: dict[str, SlotList],
|
||||||
|
intent_context: dict[str, Any] | None,
|
||||||
|
) -> RecognizeResult | None:
|
||||||
|
"""Return result with unknown names for an error message."""
|
||||||
|
maybe_result: RecognizeResult | None = None
|
||||||
|
|
||||||
|
best_num_matched_entities = 0
|
||||||
|
best_num_unmatched_entities = 0
|
||||||
|
best_num_unmatched_ranges = 0
|
||||||
|
for result in recognize_all(
|
||||||
|
user_input.text,
|
||||||
|
lang_intents.intents,
|
||||||
|
slot_lists=slot_lists,
|
||||||
|
intent_context=intent_context,
|
||||||
|
allow_unmatched_entities=True,
|
||||||
|
):
|
||||||
|
if result.text_chunks_matched < 1:
|
||||||
|
# Skip results that don't match any literal text
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Don't count missing entities that couldn't be filled from context
|
||||||
|
num_matched_entities = 0
|
||||||
|
for matched_entity in result.entities_list:
|
||||||
|
if matched_entity.name not in result.unmatched_entities:
|
||||||
|
num_matched_entities += 1
|
||||||
|
|
||||||
|
num_unmatched_entities = 0
|
||||||
|
num_unmatched_ranges = 0
|
||||||
|
for unmatched_entity in result.unmatched_entities_list:
|
||||||
|
if isinstance(unmatched_entity, UnmatchedTextEntity):
|
||||||
|
if unmatched_entity.text != MISSING_ENTITY:
|
||||||
|
num_unmatched_entities += 1
|
||||||
|
elif isinstance(unmatched_entity, UnmatchedRangeEntity):
|
||||||
|
num_unmatched_ranges += 1
|
||||||
|
num_unmatched_entities += 1
|
||||||
|
else:
|
||||||
|
num_unmatched_entities += 1
|
||||||
|
|
||||||
|
if (
|
||||||
|
(maybe_result is None) # first result
|
||||||
|
or (
|
||||||
|
# More literal text matched
|
||||||
|
result.text_chunks_matched > maybe_result.text_chunks_matched
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# More entities matched
|
||||||
|
num_matched_entities > best_num_matched_entities
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# Fewer unmatched entities
|
||||||
|
(num_matched_entities == best_num_matched_entities)
|
||||||
|
and (num_unmatched_entities < best_num_unmatched_entities)
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# Prefer unmatched ranges
|
||||||
|
(num_matched_entities == best_num_matched_entities)
|
||||||
|
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||||
|
and (num_unmatched_ranges > best_num_unmatched_ranges)
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
# Prefer match failures with entities
|
||||||
|
(result.text_chunks_matched == maybe_result.text_chunks_matched)
|
||||||
|
and (num_unmatched_entities == best_num_unmatched_entities)
|
||||||
|
and (num_unmatched_ranges == best_num_unmatched_ranges)
|
||||||
|
and (
|
||||||
|
("name" in result.entities)
|
||||||
|
or ("name" in result.unmatched_entities)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
maybe_result = result
|
||||||
|
best_num_matched_entities = num_matched_entities
|
||||||
|
best_num_unmatched_entities = num_unmatched_entities
|
||||||
|
best_num_unmatched_ranges = num_unmatched_ranges
|
||||||
|
|
||||||
|
return maybe_result
|
||||||
|
|
||||||
def _get_unexposed_entity_names(self, text: str) -> TextSlotList:
|
def _get_unexposed_entity_names(self, text: str) -> TextSlotList:
|
||||||
"""Get filtered slot list with unexposed entity names in Home Assistant."""
|
"""Get filtered slot list with unexposed entity names in Home Assistant."""
|
||||||
if self._unexposed_names_trie is None:
|
if self._unexposed_names_trie is None:
|
||||||
@@ -851,7 +962,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
if lang_intents is None:
|
if lang_intents is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._make_slot_lists()
|
await self._make_slot_lists()
|
||||||
|
|
||||||
async def async_get_or_load_intents(self, language: str) -> LanguageIntents | None:
|
async def async_get_or_load_intents(self, language: str) -> LanguageIntents | None:
|
||||||
"""Load all intents of a language with lock."""
|
"""Load all intents of a language with lock."""
|
||||||
@@ -1002,12 +1113,85 @@ class DefaultAgent(ConversationEntity):
|
|||||||
intent_responses = responses_dict.get("intents", {})
|
intent_responses = responses_dict.get("intents", {})
|
||||||
error_responses = responses_dict.get("errors", {})
|
error_responses = responses_dict.get("errors", {})
|
||||||
|
|
||||||
|
if not self.fuzzy_matching:
|
||||||
|
_LOGGER.debug("Fuzzy matching is disabled")
|
||||||
|
return LanguageIntents(
|
||||||
|
intents,
|
||||||
|
intents_dict,
|
||||||
|
intent_responses,
|
||||||
|
error_responses,
|
||||||
|
language_variant,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load fuzzy
|
||||||
|
fuzzy_info = get_fuzzy_language(language_variant, json_load=json_load)
|
||||||
|
if fuzzy_info is None:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Fuzzy matching not available for language: %s", language_variant
|
||||||
|
)
|
||||||
|
return LanguageIntents(
|
||||||
|
intents,
|
||||||
|
intents_dict,
|
||||||
|
intent_responses,
|
||||||
|
error_responses,
|
||||||
|
language_variant,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._fuzzy_config is None:
|
||||||
|
# Load shared config
|
||||||
|
self._fuzzy_config = get_fuzzy_config(json_load=json_load)
|
||||||
|
_LOGGER.debug("Loaded shared fuzzy matching config")
|
||||||
|
|
||||||
|
assert self._fuzzy_config is not None
|
||||||
|
|
||||||
|
fuzzy_matcher: FuzzyNgramMatcher | None = None
|
||||||
|
fuzzy_responses: FuzzyLanguageResponses | None = None
|
||||||
|
|
||||||
|
start_time = time.monotonic()
|
||||||
|
fuzzy_responses = fuzzy_info.responses
|
||||||
|
fuzzy_matcher = FuzzyNgramMatcher(
|
||||||
|
intents=intents,
|
||||||
|
intent_models={
|
||||||
|
intent_name: Sqlite3NgramModel(
|
||||||
|
order=fuzzy_model.order,
|
||||||
|
words={
|
||||||
|
word: str(word_id)
|
||||||
|
for word, word_id in fuzzy_model.words.items()
|
||||||
|
},
|
||||||
|
database_path=fuzzy_model.database_path,
|
||||||
|
)
|
||||||
|
for intent_name, fuzzy_model in fuzzy_info.ngram_models.items()
|
||||||
|
},
|
||||||
|
intent_slot_list_names=self._fuzzy_config.slot_list_names,
|
||||||
|
slot_combinations={
|
||||||
|
intent_name: {
|
||||||
|
combo_key: [
|
||||||
|
SlotCombinationInfo(
|
||||||
|
name_domains=(set(name_domains) if name_domains else None)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
for combo_key, name_domains in intent_combos.items()
|
||||||
|
}
|
||||||
|
for intent_name, intent_combos in self._fuzzy_config.slot_combinations.items()
|
||||||
|
},
|
||||||
|
domain_keywords=fuzzy_info.domain_keywords,
|
||||||
|
stop_words=fuzzy_info.stop_words,
|
||||||
|
)
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Loaded fuzzy matcher in %s second(s): language=%s, intents=%s",
|
||||||
|
time.monotonic() - start_time,
|
||||||
|
language_variant,
|
||||||
|
sorted(fuzzy_matcher.intent_models.keys()),
|
||||||
|
)
|
||||||
|
|
||||||
return LanguageIntents(
|
return LanguageIntents(
|
||||||
intents,
|
intents,
|
||||||
intents_dict,
|
intents_dict,
|
||||||
intent_responses,
|
intent_responses,
|
||||||
error_responses,
|
error_responses,
|
||||||
language_variant,
|
language_variant,
|
||||||
|
fuzzy_matcher=fuzzy_matcher,
|
||||||
|
fuzzy_responses=fuzzy_responses,
|
||||||
)
|
)
|
||||||
|
|
||||||
@core.callback
|
@core.callback
|
||||||
@@ -1027,8 +1211,7 @@ class DefaultAgent(ConversationEntity):
|
|||||||
# Slot lists have changed, so we must clear the cache
|
# Slot lists have changed, so we must clear the cache
|
||||||
self._intent_cache.clear()
|
self._intent_cache.clear()
|
||||||
|
|
||||||
@core.callback
|
async def _make_slot_lists(self) -> dict[str, SlotList]:
|
||||||
def _make_slot_lists(self) -> dict[str, SlotList]:
|
|
||||||
"""Create slot lists with areas and entity names/aliases."""
|
"""Create slot lists with areas and entity names/aliases."""
|
||||||
if self._slot_lists is not None:
|
if self._slot_lists is not None:
|
||||||
return self._slot_lists
|
return self._slot_lists
|
||||||
@@ -1089,6 +1272,10 @@ class DefaultAgent(ConversationEntity):
|
|||||||
"floor": TextSlotList.from_tuples(floor_names, allow_template=False),
|
"floor": TextSlotList.from_tuples(floor_names, allow_template=False),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Reload fuzzy matchers with new slot lists
|
||||||
|
if self.fuzzy_matching:
|
||||||
|
await self.hass.async_add_executor_job(self._load_fuzzy_matchers)
|
||||||
|
|
||||||
self._listen_clear_slot_list()
|
self._listen_clear_slot_list()
|
||||||
|
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
@@ -1098,6 +1285,25 @@ class DefaultAgent(ConversationEntity):
|
|||||||
|
|
||||||
return self._slot_lists
|
return self._slot_lists
|
||||||
|
|
||||||
|
def _load_fuzzy_matchers(self) -> None:
|
||||||
|
"""Reload fuzzy matchers for all loaded languages."""
|
||||||
|
for lang_intents in self._lang_intents.values():
|
||||||
|
if (not isinstance(lang_intents, LanguageIntents)) or (
|
||||||
|
lang_intents.fuzzy_matcher is None
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
lang_matcher = lang_intents.fuzzy_matcher
|
||||||
|
lang_intents.fuzzy_matcher = FuzzyNgramMatcher(
|
||||||
|
intents=lang_matcher.intents,
|
||||||
|
intent_models=lang_matcher.intent_models,
|
||||||
|
intent_slot_list_names=lang_matcher.intent_slot_list_names,
|
||||||
|
slot_combinations=lang_matcher.slot_combinations,
|
||||||
|
domain_keywords=lang_matcher.domain_keywords,
|
||||||
|
stop_words=lang_matcher.stop_words,
|
||||||
|
slot_lists=self._slot_lists,
|
||||||
|
)
|
||||||
|
|
||||||
def _make_intent_context(
|
def _make_intent_context(
|
||||||
self, user_input: ConversationInput
|
self, user_input: ConversationInput
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
@@ -1521,10 +1727,8 @@ def _get_match_error_response(
|
|||||||
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
|
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
|
||||||
"""Collect list reference names recursively."""
|
"""Collect list reference names recursively."""
|
||||||
if isinstance(expression, Group):
|
if isinstance(expression, Group):
|
||||||
grp: Group = expression
|
for item in expression.items:
|
||||||
for item in grp.items:
|
|
||||||
_collect_list_references(item, list_names)
|
_collect_list_references(item, list_names)
|
||||||
elif isinstance(expression, ListReference):
|
elif isinstance(expression, ListReference):
|
||||||
# {list}
|
# {list}
|
||||||
list_ref: ListReference = expression
|
list_names.add(expression.slot_name)
|
||||||
list_names.add(list_ref.slot_name)
|
|
||||||
|
@@ -26,7 +26,11 @@ from .agent_manager import (
|
|||||||
get_agent_manager,
|
get_agent_manager,
|
||||||
)
|
)
|
||||||
from .const import DATA_COMPONENT, DATA_DEFAULT_ENTITY
|
from .const import DATA_COMPONENT, DATA_DEFAULT_ENTITY
|
||||||
from .default_agent import METADATA_CUSTOM_FILE, METADATA_CUSTOM_SENTENCE
|
from .default_agent import (
|
||||||
|
METADATA_CUSTOM_FILE,
|
||||||
|
METADATA_CUSTOM_SENTENCE,
|
||||||
|
METADATA_FUZZY_MATCH,
|
||||||
|
)
|
||||||
from .entity import ConversationEntity
|
from .entity import ConversationEntity
|
||||||
from .models import ConversationInput
|
from .models import ConversationInput
|
||||||
|
|
||||||
@@ -240,6 +244,8 @@ async def websocket_hass_agent_debug(
|
|||||||
"sentence_template": "",
|
"sentence_template": "",
|
||||||
# When match is incomplete, this will contain the best slot guesses
|
# When match is incomplete, this will contain the best slot guesses
|
||||||
"unmatched_slots": _get_unmatched_slots(intent_result),
|
"unmatched_slots": _get_unmatched_slots(intent_result),
|
||||||
|
# True if match was not exact
|
||||||
|
"fuzzy_match": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if successful_match:
|
if successful_match:
|
||||||
@@ -251,16 +257,19 @@ async def websocket_hass_agent_debug(
|
|||||||
if intent_result.intent_sentence is not None:
|
if intent_result.intent_sentence is not None:
|
||||||
result_dict["sentence_template"] = intent_result.intent_sentence.text
|
result_dict["sentence_template"] = intent_result.intent_sentence.text
|
||||||
|
|
||||||
# Inspect metadata to determine if this matched a custom sentence
|
if intent_result.intent_metadata:
|
||||||
if intent_result.intent_metadata and intent_result.intent_metadata.get(
|
# Inspect metadata to determine if this matched a custom sentence
|
||||||
METADATA_CUSTOM_SENTENCE
|
if intent_result.intent_metadata.get(METADATA_CUSTOM_SENTENCE):
|
||||||
):
|
result_dict["source"] = "custom"
|
||||||
result_dict["source"] = "custom"
|
result_dict["file"] = intent_result.intent_metadata.get(
|
||||||
result_dict["file"] = intent_result.intent_metadata.get(
|
METADATA_CUSTOM_FILE
|
||||||
METADATA_CUSTOM_FILE
|
)
|
||||||
|
else:
|
||||||
|
result_dict["source"] = "builtin"
|
||||||
|
|
||||||
|
result_dict["fuzzy_match"] = intent_result.intent_metadata.get(
|
||||||
|
METADATA_FUZZY_MATCH, False
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
result_dict["source"] = "builtin"
|
|
||||||
|
|
||||||
result_dicts.append(result_dict)
|
result_dicts.append(result_dict)
|
||||||
|
|
||||||
|
@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import stt, tts, wake_word
|
from homeassistant.components import conversation, stt, tts, wake_word
|
||||||
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
|
from homeassistant.components.assist_pipeline import DOMAIN, select as assist_select
|
||||||
from homeassistant.components.assist_pipeline.const import (
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
BYTES_PER_CHUNK,
|
BYTES_PER_CHUNK,
|
||||||
@@ -295,6 +295,11 @@ async def init_supporting_components(
|
|||||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||||
assert await async_setup_component(hass, "media_source", {})
|
assert await async_setup_component(hass, "media_source", {})
|
||||||
|
assert await async_setup_component(hass, "conversation", {"conversation": {}})
|
||||||
|
|
||||||
|
# Disable fuzzy matching by default for tests
|
||||||
|
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
|
||||||
|
agent.fuzzy_matching = False
|
||||||
|
|
||||||
config_entry = MockConfigEntry(domain="test")
|
config_entry = MockConfigEntry(domain="test")
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
|
@@ -73,4 +73,8 @@ async def sl_setup(hass: HomeAssistant):
|
|||||||
async def init_components(hass: HomeAssistant):
|
async def init_components(hass: HomeAssistant):
|
||||||
"""Initialize relevant components with empty configs."""
|
"""Initialize relevant components with empty configs."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {conversation.DOMAIN: {}})
|
||||||
|
|
||||||
|
# Disable fuzzy matching by default for tests
|
||||||
|
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
|
||||||
|
agent.fuzzy_matching = False
|
||||||
|
@@ -464,6 +464,7 @@
|
|||||||
'value': 'my cool light',
|
'value': 'my cool light',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'HassTurnOn',
|
'name': 'HassTurnOn',
|
||||||
}),
|
}),
|
||||||
@@ -472,7 +473,6 @@
|
|||||||
'slots': dict({
|
'slots': dict({
|
||||||
'name': 'my cool light',
|
'name': 'my cool light',
|
||||||
}),
|
}),
|
||||||
'source': 'builtin',
|
|
||||||
'targets': dict({
|
'targets': dict({
|
||||||
'light.kitchen': dict({
|
'light.kitchen': dict({
|
||||||
'matched': True,
|
'matched': True,
|
||||||
@@ -489,6 +489,7 @@
|
|||||||
'value': 'my cool light',
|
'value': 'my cool light',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'HassTurnOff',
|
'name': 'HassTurnOff',
|
||||||
}),
|
}),
|
||||||
@@ -497,7 +498,6 @@
|
|||||||
'slots': dict({
|
'slots': dict({
|
||||||
'name': 'my cool light',
|
'name': 'my cool light',
|
||||||
}),
|
}),
|
||||||
'source': 'builtin',
|
|
||||||
'targets': dict({
|
'targets': dict({
|
||||||
'light.kitchen': dict({
|
'light.kitchen': dict({
|
||||||
'matched': True,
|
'matched': True,
|
||||||
@@ -519,6 +519,7 @@
|
|||||||
'value': 'light',
|
'value': 'light',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'HassTurnOn',
|
'name': 'HassTurnOn',
|
||||||
}),
|
}),
|
||||||
@@ -528,7 +529,6 @@
|
|||||||
'area': 'kitchen',
|
'area': 'kitchen',
|
||||||
'domain': 'light',
|
'domain': 'light',
|
||||||
}),
|
}),
|
||||||
'source': 'builtin',
|
|
||||||
'targets': dict({
|
'targets': dict({
|
||||||
'light.kitchen': dict({
|
'light.kitchen': dict({
|
||||||
'matched': True,
|
'matched': True,
|
||||||
@@ -555,6 +555,7 @@
|
|||||||
'value': 'on',
|
'value': 'on',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'HassGetState',
|
'name': 'HassGetState',
|
||||||
}),
|
}),
|
||||||
@@ -565,7 +566,6 @@
|
|||||||
'domain': 'lights',
|
'domain': 'lights',
|
||||||
'state': 'on',
|
'state': 'on',
|
||||||
}),
|
}),
|
||||||
'source': 'builtin',
|
|
||||||
'targets': dict({
|
'targets': dict({
|
||||||
'light.kitchen': dict({
|
'light.kitchen': dict({
|
||||||
'matched': False,
|
'matched': False,
|
||||||
@@ -590,6 +590,7 @@
|
|||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
'file': 'en/beer.yaml',
|
'file': 'en/beer.yaml',
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'OrderBeer',
|
'name': 'OrderBeer',
|
||||||
}),
|
}),
|
||||||
@@ -630,6 +631,7 @@
|
|||||||
'value': 'test light',
|
'value': 'test light',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'HassLightSet',
|
'name': 'HassLightSet',
|
||||||
}),
|
}),
|
||||||
@@ -639,7 +641,6 @@
|
|||||||
'brightness': '100',
|
'brightness': '100',
|
||||||
'name': 'test light',
|
'name': 'test light',
|
||||||
}),
|
}),
|
||||||
'source': 'builtin',
|
|
||||||
'targets': dict({
|
'targets': dict({
|
||||||
'light.demo_1234': dict({
|
'light.demo_1234': dict({
|
||||||
'matched': True,
|
'matched': True,
|
||||||
@@ -662,6 +663,7 @@
|
|||||||
'value': 'test light',
|
'value': 'test light',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
'fuzzy_match': False,
|
||||||
'intent': dict({
|
'intent': dict({
|
||||||
'name': 'HassLightSet',
|
'name': 'HassLightSet',
|
||||||
}),
|
}),
|
||||||
@@ -670,7 +672,6 @@
|
|||||||
'slots': dict({
|
'slots': dict({
|
||||||
'name': 'test light',
|
'name': 'test light',
|
||||||
}),
|
}),
|
||||||
'source': 'builtin',
|
|
||||||
'targets': dict({
|
'targets': dict({
|
||||||
}),
|
}),
|
||||||
'unmatched_slots': dict({
|
'unmatched_slots': dict({
|
||||||
|
@@ -25,7 +25,12 @@ from homeassistant.components.intent import (
|
|||||||
TimerInfo,
|
TimerInfo,
|
||||||
async_register_timer_handler,
|
async_register_timer_handler,
|
||||||
)
|
)
|
||||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
from homeassistant.components.light import (
|
||||||
|
ATTR_SUPPORTED_COLOR_MODES,
|
||||||
|
DOMAIN as LIGHT_DOMAIN,
|
||||||
|
ColorMode,
|
||||||
|
intent as light_intent,
|
||||||
|
)
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DEVICE_CLASS,
|
ATTR_DEVICE_CLASS,
|
||||||
ATTR_FRIENDLY_NAME,
|
ATTR_FRIENDLY_NAME,
|
||||||
@@ -81,6 +86,10 @@ async def init_components(hass: HomeAssistant) -> None:
|
|||||||
assert await async_setup_component(hass, "conversation", {})
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
assert await async_setup_component(hass, "intent", {})
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
|
||||||
|
# Disable fuzzy matching by default for tests
|
||||||
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
agent.fuzzy_matching = False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"er_kwargs",
|
"er_kwargs",
|
||||||
@@ -3287,3 +3296,97 @@ async def test_language_with_alternative_code(
|
|||||||
assert call.domain == LIGHT_DOMAIN
|
assert call.domain == LIGHT_DOMAIN
|
||||||
assert call.service == "turn_on"
|
assert call.service == "turn_on"
|
||||||
assert call.data == {"entity_id": [entity_id]}
|
assert call.data == {"entity_id": [entity_id]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("fuzzy_matching", [True, False])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("sentence", "intent_type", "slots"),
|
||||||
|
[
|
||||||
|
("time", "HassGetCurrentTime", {}),
|
||||||
|
("how about my timers", "HassTimerStatus", {}),
|
||||||
|
(
|
||||||
|
"the office needs more blue",
|
||||||
|
"HassLightSet",
|
||||||
|
{"area": "office", "color": "blue"},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"50% office light",
|
||||||
|
"HassLightSet",
|
||||||
|
{"name": "office light", "brightness": "50%"},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_fuzzy_matching(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
area_registry: ar.AreaRegistry,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
entity_registry: er.EntityRegistry,
|
||||||
|
fuzzy_matching: bool,
|
||||||
|
sentence: str,
|
||||||
|
intent_type: str,
|
||||||
|
slots: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Test fuzzy vs. non-fuzzy matching on some English sentences."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
assert await async_setup_component(hass, "intent", {})
|
||||||
|
await light_intent.async_setup_intents(hass)
|
||||||
|
|
||||||
|
agent = hass.data[DATA_DEFAULT_ENTITY]
|
||||||
|
agent.fuzzy_matching = fuzzy_matching
|
||||||
|
|
||||||
|
area_office = area_registry.async_get_or_create("office_id")
|
||||||
|
area_office = area_registry.async_update(area_office.id, name="office")
|
||||||
|
|
||||||
|
entry = MockConfigEntry()
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
office_satellite = device_registry.async_get_or_create(
|
||||||
|
config_entry_id=entry.entry_id,
|
||||||
|
connections=set(),
|
||||||
|
identifiers={("demo", "id-1234")},
|
||||||
|
)
|
||||||
|
device_registry.async_update_device(office_satellite.id, area_id=area_office.id)
|
||||||
|
|
||||||
|
office_light = entity_registry.async_get_or_create("light", "demo", "1234")
|
||||||
|
office_light = entity_registry.async_update_entity(
|
||||||
|
office_light.entity_id, area_id=area_office.id
|
||||||
|
)
|
||||||
|
hass.states.async_set(
|
||||||
|
office_light.entity_id,
|
||||||
|
"on",
|
||||||
|
attributes={
|
||||||
|
ATTR_FRIENDLY_NAME: "office light",
|
||||||
|
ATTR_SUPPORTED_COLOR_MODES: [ColorMode.BRIGHTNESS, ColorMode.RGB],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
_on_calls = async_mock_service(hass, LIGHT_DOMAIN, "turn_on")
|
||||||
|
|
||||||
|
result = await conversation.async_converse(
|
||||||
|
hass,
|
||||||
|
sentence,
|
||||||
|
None,
|
||||||
|
Context(),
|
||||||
|
language="en",
|
||||||
|
device_id=office_satellite.id,
|
||||||
|
)
|
||||||
|
response = result.response
|
||||||
|
|
||||||
|
if not fuzzy_matching:
|
||||||
|
# Should not match
|
||||||
|
assert response.response_type == intent.IntentResponseType.ERROR
|
||||||
|
return
|
||||||
|
|
||||||
|
assert response.response_type in (
|
||||||
|
intent.IntentResponseType.ACTION_DONE,
|
||||||
|
intent.IntentResponseType.QUERY_ANSWER,
|
||||||
|
)
|
||||||
|
assert response.intent is not None
|
||||||
|
assert response.intent.intent_type == intent_type
|
||||||
|
|
||||||
|
# Verify slot texts match
|
||||||
|
actual_slots = {
|
||||||
|
slot_name: slot_value["text"]
|
||||||
|
for slot_name, slot_value in response.intent.slots.items()
|
||||||
|
if slot_name != "preferred_area_id" # context area
|
||||||
|
}
|
||||||
|
assert actual_slots == slots
|
||||||
|
Reference in New Issue
Block a user