Skip to content

Commit

Permalink
fix/adapt_config (#515)
Browse files Browse the repository at this point in the history
* refactor/adjust_adapt_conf_levels

"what's the weather" was failing with 0.7 confidence, while adapt_high was capped at 0.8

lower the adapt thresholds by default

* fix/read_adapt_from_conf

* Update adapt_service.py

* better/padatious defaults

- remove artificial training delay
- default to multithreading

* more logs to help in pipeline matching
  • Loading branch information
JarbasAl authored Jun 22, 2024
1 parent 0be30d8 commit 5b6cb2f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 35 deletions.
24 changes: 13 additions & 11 deletions ovos_core/intent_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@
# limitations under the License.
#
from collections import namedtuple

from typing import Tuple, Callable
from ovos_bus_client.message import Message
from ovos_bus_client.session import SessionManager
from ovos_bus_client.util import get_message_lang
from ovos_config.config import Configuration
from ovos_config.locale import setup_locale, get_valid_languages, get_full_lang_code

from ovos_core.intent_services.adapt_service import AdaptService
from ovos_core.intent_services.commonqa_service import CommonQAService
from ovos_core.intent_services.converse_service import ConverseService
from ovos_core.intent_services.fallback_service import FallbackService
from ovos_core.intent_services.ocp_service import OCPPipelineMatcher
from ovos_core.intent_services.padacioso_service import PadaciosoService
from ovos_core.intent_services.stop_service import StopService
from ovos_core.intent_services.ocp_service import OCPPipelineMatcher
from ovos_core.transformers import MetadataTransformersService, UtteranceTransformersService
from ovos_utils.log import LOG, deprecated, log_deprecation
from ovos_utils.metrics import Stopwatch
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(self, bus, config=None):
self.skill_names = {}

# TODO - replace with plugins
self.adapt_service = AdaptService()
self.adapt_service = AdaptService(config=self.config.get("adapt", {}))
self.padatious_service = None
try:
if self.config["padatious"].get("disabled"):
Expand Down Expand Up @@ -198,7 +199,7 @@ def disambiguate_lang(message):

return default_lang

def get_pipeline(self, skips=None, session=None):
def get_pipeline(self, skips=None, session=None) -> Tuple[str, Callable]:
"""return a list of matcher functions ordered by priority
utterances will be sent to each matcher in order until one can handle the utterance
the list can be configured in mycroft.conf under intents.pipeline,
Expand Down Expand Up @@ -248,7 +249,7 @@ def get_pipeline(self, skips=None, session=None):
f"filtered {[k for k in pipeline if k not in matchers]}")
pipeline = [k for k in pipeline if k in matchers]
LOG.debug(f"Session pipeline: {pipeline}")
return [matchers[k] for k in pipeline]
return [(k, matchers[k]) for k in pipeline]

def _validate_session(self, message, lang):
# get session
Expand Down Expand Up @@ -373,9 +374,10 @@ def handle_utterance(self, message: Message):
match = None
with stopwatch:
# Loop through the matching functions until a match is found.
for match_func in self.get_pipeline(session=sess):
for pipeline, match_func in self.get_pipeline(session=sess):
match = match_func(utterances, lang, message)
if match:
LOG.info(f"{pipeline} match: {match}")
if match.skill_id and match.skill_id in sess.blacklisted_skills:
LOG.debug(f"ignoring match, skill_id '{match.skill_id}' blacklisted by Session '{sess.session_id}'")
continue
Expand Down Expand Up @@ -509,11 +511,11 @@ def handle_get_intent(self, message):
sess = SessionManager.get(message)

# Loop through the matching functions until a match is found.
for match_func in self.get_pipeline(skips=["converse",
"fallback_high",
"fallback_medium",
"fallback_low"],
session=sess):
for pipeline, match_func in self.get_pipeline(skips=["converse",
"fallback_high",
"fallback_medium",
"fallback_low"],
session=sess):
match = match_func([utterance], lang, message)
if match:
if match.intent_type:
Expand Down
15 changes: 7 additions & 8 deletions ovos_core/intent_services/adapt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(self, config=None):
self.max_words = 50 # if an utterance contains more words than this, don't attempt to match

# TODO sanitize config option
self.conf_high = self.config.get("conf_high") or 0.8
self.conf_med = self.config.get("conf_med") or 0.5
self.conf_low = self.config.get("conf_low") or 0.3
self.conf_high = self.config.get("conf_high") or 0.65
self.conf_med = self.config.get("conf_med") or 0.45
self.conf_low = self.config.get("conf_low") or 0.25

@property
def context_keywords(self):
Expand Down Expand Up @@ -145,7 +145,7 @@ def match_high(self, utterances: List[str],
with optional normalized version.
"""
match = self.match_intent(tuple(utterances), lang, message.serialize())
if match and match.intent_data.get("confidence", 0.0) > self.conf_high:
if match and match.intent_data.get("confidence", 0.0) >= self.conf_high:
return match
return None

Expand All @@ -159,7 +159,7 @@ def match_medium(self, utterances: List[str],
with optional normalized version.
"""
match = self.match_intent(tuple(utterances), lang, message.serialize())
if match and match.intent_data.get("confidence", 0.0) > self.conf_med:
if match and match.intent_data.get("confidence", 0.0) >= self.conf_med:
return match
return None

Expand All @@ -173,7 +173,7 @@ def match_low(self, utterances: List[str],
with optional normalized version.
"""
match = self.match_intent(tuple(utterances), lang, message.serialize())
if match and match.intent_data.get("confidence", 0.0) > self.conf_low:
if match and match.intent_data.get("confidence", 0.0) >= self.conf_low:
return match
return None

Expand Down Expand Up @@ -227,7 +227,6 @@ def take_best(intent, utt):
# TODO - Shouldn't Adapt do this?
best_intent['utterance'] = utt

sess = SessionManager.get(message)
for utt in utterances:
try:
intents = [i for i in self.engines[lang].determine_intent(
Expand Down Expand Up @@ -358,4 +357,4 @@ def detach_intent(self, intent_name):
def shutdown(self):
for lang in self.engines:
parsers = self.engines[lang].intent_parsers
self.engines[lang].drop_intent_parser(parsers)
self.engines[lang].drop_intent_parser(parsers)
6 changes: 3 additions & 3 deletions ovos_core/intent_services/ocp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ def handle_m(m):

return self.ocp_sessions[sess.session_id]

def normalize_results(self, results: list) -> List[Union[MediaEntry, Playlist]]:
def normalize_results(self, results: list) -> List[Union[MediaEntry, Playlist, PluginStream]]:
# support Playlist and MediaEntry objects in tracks
for idx, track in enumerate(results):
if isinstance(track, dict):
Expand Down Expand Up @@ -944,7 +944,7 @@ def _execute_query(self, phrase: str,
LOG.debug(f'Returning {len(results)} search results')
return results

def select_best(self, results: list, message: Message) -> MediaEntry:
def select_best(self, results: list, message: Message) -> Union[MediaEntry, Playlist, PluginStream]:

sess = SessionManager.get(message)

Expand Down Expand Up @@ -980,7 +980,7 @@ def select_best(self, results: list, message: Message) -> MediaEntry:

##################
# Legacy Audio subsystem API
def legacy_play(self, results: List[MediaEntry], phrase="",
def legacy_play(self, results: List[Union[MediaEntry, Playlist, PluginStream]], phrase="",
message: Optional[Message] = None):
res = []
for r in results:
Expand Down
16 changes: 3 additions & 13 deletions ovos_core/intent_services/padatious_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,6 @@ def __init__(self, bus, config):
self.finished_training_event = Event()
self.finished_initial_train = False

self.train_delay = self.padatious_config.get('train_delay', 4)
self.train_time = get_time() + self.train_delay

self.registered_intents = []
self.registered_entities = []
self.max_words = 50 # if an utterance contains more words than this, don't attempt to match
Expand All @@ -133,7 +130,7 @@ def train(self, message=None):
message (Message): optional triggering message
"""
self.finished_training_event.clear()
padatious_single_thread = self.padatious_config.get('single_thread', True)
padatious_single_thread = self.padatious_config.get('single_thread', False)
if message is None:
single_thread = padatious_single_thread
else:
Expand All @@ -142,7 +139,7 @@ def train(self, message=None):
for lang in self.containers:
self.containers[lang].train(single_thread=single_thread)

LOG.info('Training complete.')
LOG.debug('Training complete.')
self.finished_training_event.set()
if not self.finished_initial_train:
self.bus.emit(Message('mycroft.skills.trained'))
Expand All @@ -152,13 +149,7 @@ def wait_and_train(self):
"""Wait for minimum time between training and start training."""
if not self.finished_initial_train:
return
sleep(self.train_delay)
if self.train_time < 0.0:
return

if self.train_time <= get_time() + 0.01:
self.train_time = -1.0
self.train()
self.train()

def __detach_intent(self, intent_name):
""" Remove an intent if it has been registered.
Expand Down Expand Up @@ -214,7 +205,6 @@ def _register_object(self, message, object_name, register_func):

register_func(name, samples)

self.train_time = get_time() + self.train_delay
self.wait_and_train()

def register_intent(self, message):
Expand Down

0 comments on commit 5b6cb2f

Please sign in to comment.