Skip to content

Commit

Permalink
Add check loop def
Browse files Browse the repository at this point in the history
Co-authored-by: WashingtonBispo <[email protected]>
  • Loading branch information
emysdias and WashingtonBispo committed Mar 8, 2022
1 parent f104371 commit e3d62e6
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog/10925.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a check to "rasa data validate" that detects loops caused by checkpoints.
56 changes: 55 additions & 1 deletion rasa/validator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import itertools
import logging
import queue
from collections import defaultdict
from platform import node
from re import A
from typing import Set, Text, Optional, Dict, Any, List

from numpy import False_

import rasa.core.training.story_conflict
import rasa.shared.nlu.constants
from rasa.shared.constants import (
Expand All @@ -24,6 +29,10 @@
from rasa.shared.nlu.training_data.training_data import TrainingData
import rasa.shared.utils.io

from rasa.shared.core.training_data.structures import (
STORY_START,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -91,6 +100,47 @@ def verify_intents(self, ignore_warnings: bool = True) -> bool:
everything_is_alright = False

return everything_is_alright

def verify_loop_in_intents(
    self, ignore_warnings: bool = True
) -> bool:
    """Checks whether the story checkpoints form a loop.

    Every story step contributes one directed edge from each of its start
    checkpoints to each of its end checkpoints. The resulting graph is
    explored depth-first from `STORY_START`; an edge that leads back to a
    checkpoint which is still on the path currently being explored closes
    a loop and is reported.

    Branches that merely converge on the same checkpoint are not loops and
    are not reported. Loops that cannot be reached from `STORY_START` are
    not inspected.

    Args:
        ignore_warnings: Value to return when at least one loop is found.

    Returns:
        `True` if no loop was found, otherwise the truth value of
        `ignore_warnings`.
    """
    # Adjacency list: checkpoint name -> names of the checkpoints that a
    # story step starting there can end in (duplicates are dropped so each
    # looping edge is reported once).
    successors: Dict[Text, List[Text]] = defaultdict(list)
    for story_step in self.story_graph.story_steps:
        for start_cp in story_step.start_checkpoints:
            targets = successors[start_cp.name]
            for end_cp in story_step.end_checkpoints:
                if end_cp.name not in targets:
                    targets.append(end_cp.name)

    loops_cp: List[Text] = []

    if STORY_START in successors:
        # Iterative depth-first search (no recursion, so long checkpoint
        # chains cannot hit the interpreter's recursion limit).
        on_path: Set[Text] = {STORY_START}  # checkpoints on the current path
        finished: Set[Text] = set()  # checkpoints that are fully explored
        stack = [(STORY_START, iter(successors.get(STORY_START, ())))]

        while stack:
            current, pending = stack[-1]
            for target in pending:
                if target in on_path:
                    # Back edge: `target` is an ancestor of `current`.
                    loops_cp.append(f"{current} => {target}")
                elif target not in finished:
                    # Unseen checkpoint: descend into it before continuing
                    # with the remaining successors of `current`.
                    on_path.add(target)
                    stack.append((target, iter(successors.get(target, ()))))
                    break
            else:
                # All successors of `current` were handled; backtrack.
                on_path.discard(current)
                finished.add(current)
                stack.pop()

    if loops_cp:
        rasa.shared.utils.io.raise_warning(
            f"These checkpoints '{loops_cp}' is causing loop"
        )
        return bool(ignore_warnings)

    return True

def verify_example_repetition_in_intents(
self, ignore_warnings: bool = True
Expand Down Expand Up @@ -327,10 +377,14 @@ def verify_nlu(self, ignore_warnings: bool = True) -> bool:
there_is_no_duplication = self.verify_example_repetition_in_intents(
ignore_warnings
)

logger.info("Validating loop of checkpoints...")
loop_in_checkpoint = self.verify_loop_in_intents(ignore_warnings)

logger.info("Validating utterances...")
stories_are_valid = self.verify_utterances_in_stories(ignore_warnings)
return intents_are_valid and stories_are_valid and there_is_no_duplication
return (intents_are_valid and stories_are_valid and there_is_no_duplication
and loop_in_checkpoint)

def verify_form_slots(self) -> bool:
"""Verifies that form slots match the slot mappings in domain."""
Expand Down

0 comments on commit e3d62e6

Please sign in to comment.