Skip to content

Commit

Permalink
Extra rule while file upload (#1597)
Browse files Browse the repository at this point in the history
* added code related to handling the extra creation of nlu fallback data while uploading the files.

* added code related to handling the extra creation of nlu fallback data while uploading the files.

* Added test cases.
  • Loading branch information
maheshsattala authored Nov 14, 2024
1 parent 8632333 commit 38d1819
Show file tree
Hide file tree
Showing 11 changed files with 346 additions and 17 deletions.
3 changes: 2 additions & 1 deletion kairon/importer/data_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@ def import_data(self):
self.validator.bot_content,
self.validator.chat_client_config.get('config'),
self.validator.other_collections,
self.overwrite, self.files_to_save)
self.overwrite, self.files_to_save,
default_fallback_data=True)
41 changes: 36 additions & 5 deletions kairon/shared/data/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,36 @@ def save_training_data(
other_collections: dict = None,
overwrite: bool = False,
what: set = REQUIREMENTS.copy(),
default_fallback_data: bool = False
):
"""
Save various components of the bot's training data to the database. Components can include the bot's domain,
stories, NLU data, actions, configuration, multiflow stories, bot content, and chat client configuration.
Args:
bot (Text): The unique identifier of the bot.
user (Text): The identifier of the user making the request.
config (dict, optional): Configuration settings for the bot.
domain (Domain, optional): The domain data for the bot, defining intents, entities, slots, and actions.
story_graph (StoryGraph, optional): Graph representation of the bot's stories and rules.
nlu (TrainingData, optional): NLU training data for the bot, including intents, entities, and examples.
actions (dict, optional): Action data for the bot, containing details of custom actions.
multiflow_stories (dict, optional): Multi-step story flows used in complex conversations.
bot_content (list, optional): Additional content for the bot, such as FAQs or responses.
chat_client_config (dict, optional): Configuration settings specific to the chat client.
other_collections (dict, optional): Other related data collections for extended functionalities.
overwrite (bool, optional): If True, existing data will be overwritten; otherwise, new data is appended.
what (set, optional): A set of data types to save, e.g., {"domain", "stories", "nlu"}.
default_fallback_data (bool, optional): If True, default fallback data is included.
Behavior:
- Deletes the specified existing data if `overwrite` is True.
- Saves each specified component in `what` to the database, invoking relevant helper functions for each data type.
Raises:
Exception: Raises exceptions if saving any component fails.
"""
if overwrite:
self.delete_bot_data(bot, user, what)

Expand All @@ -585,7 +614,7 @@ def save_training_data(
if "rules" in what:
self.save_rules(story_graph.story_steps, bot, user)
if "config" in what:
self.add_or_overwrite_config(config, bot, user)
self.add_or_overwrite_config(config, bot, user, default_fallback_data)
if "chat_client_config" in what:
self.save_chat_client_config(chat_client_config, bot, user)
if "multiflow_stories" in what:
Expand Down Expand Up @@ -2131,18 +2160,19 @@ def save_config(self, configs: dict, bot: Text, user: Text):
logging.info(e)
raise AppException(e)

def add_or_overwrite_config(self, configs: dict, bot: Text, user: Text):
def add_or_overwrite_config(self, configs: dict, bot: Text, user: Text, default_fallback_data: bool = False):
"""
saves bot configuration
:param configs: configuration
:param bot: bot id
:param user: user id
:param default_fallback_data: If True, default fallback data is included
:return: config unique id
"""
for custom_component in Utility.environment["model"]["pipeline"]["custom"]:
self.__insert_bot_id(configs, bot, custom_component)
self.add_default_fallback_config(configs, bot, user)
self.add_default_fallback_config(configs, bot, user, default_fallback_data)
try:
config_obj = Configs.objects().get(bot=bot)
except DoesNotExist:
Expand Down Expand Up @@ -5402,7 +5432,7 @@ def prepare_training_data_for_validation(
rules = self.get_rules_for_training(bot)
YAMLStoryWriter().dump(rules_path, rules.story_steps)

def add_default_fallback_config(self, config_obj: dict, bot: Text, user: Text):
def add_default_fallback_config(self, config_obj: dict, bot: Text, user: Text, default_fallback_data: bool = False):
idx = next(
(
idx
Expand Down Expand Up @@ -5449,7 +5479,8 @@ def add_default_fallback_config(self, config_obj: dict, bot: Text, user: Text):
fallback = {"name": "FallbackClassifier", "threshold": 0.7}
config_obj["pipeline"].insert(property_idx + 1, fallback)

self.add_default_fallback_data(bot, user, True, True)
if not default_fallback_data:
self.add_default_fallback_data(bot, user, True, True)

def add_default_fallback_data(
self,
Expand Down
25 changes: 25 additions & 0 deletions tests/testing_data/validator/valid_data/actions.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
http_action:
- action_name: action_say_hello
content_type: json
headers: []
http_url: https://jsonplaceholder.typicode.com/posts/1
params_list: []
request_method: GET
response:
dispatch: true
dispatch_type: text
evaluation_type: expression
value: ${data}
set_slots: []
- action_name: action_say_goodbye
content_type: json
headers: []
http_url: https://jsonplaceholder.typicode.com/posts/1
params_list: []
request_method: GET
response:
dispatch: true
dispatch_type: text
evaluation_type: expression
value: ${data}
set_slots: []
44 changes: 44 additions & 0 deletions tests/testing_data/validator/valid_data/chat_client_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
config:
api_server_host_url: http://testserver
botClassName: ''
buttonType: button
chatContainerClassName: ''
chat_server_base_url: null
container: '#root'
containerClassName: ''
formClassName: ''
headerClassName: ''
name: kairon_testing
openButtonClassName: ''
styles:
botStyle:
backgroundColor: '#e0e0e0'
color: '#000000'
fontFamily: '''Roboto'', sans-serif'
fontSize: 14px
iconSrc: ''
showIcon: 'false'
buttonStyle:
backgroundColor: '#2b3595'
color: '#ffffff'
containerStyles:
background: '#ffffff'
height: 500px
width: 350px
headerStyle:
backgroundColor: '#2b3595'
color: '#ffffff'
height: 60px
userStyle:
backgroundColor: '#2b3595'
color: '#ffffff'
fontFamily: '''Roboto'', sans-serif'
fontSize: 14px
iconSrc: ''
showIcon: 'false'
userClassName: ''
userStorage: ls
userType: custom
welcomeMessage: Hello! How are you? This is Testing Welcome Message.
whitelist:
- '*'
26 changes: 26 additions & 0 deletions tests/testing_data/validator/valid_data/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
recipe: default.v1
language: en
pipeline:
- name: WhitespaceTokenizer
- name: RegexFeaturizer
- name: LexicalSyntacticFeaturizer
- name: CountVectorsFeaturizer
- analyzer: char_wb
max_ngram: 4
min_ngram: 1
name: CountVectorsFeaturizer
- name: FallbackClassifier
threshold: 0.75
- epochs: 5
name: DIETClassifier
- name: EntitySynonymMapper
- epochs: 5
name: ResponseSelector
policies:
- name: MemoizationPolicy
- epochs: 5
max_history: 5
name: TEDPolicy
- name: RulePolicy
core_fallback_threshold: 0.3
core_fallback_action_name: action_small_talk
13 changes: 13 additions & 0 deletions tests/testing_data/validator/valid_data/data/nlu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
version: "3.1"
nlu:
- intent: greet
examples: |
- hey
- hello
- hi
- intent: deny
examples: |
- no
- never
- I don't think so
- don't like that
25 changes: 25 additions & 0 deletions tests/testing_data/validator/valid_data/data/rules.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
rules:
- rule: ask the user to rephrase whenever they send a message with low nlu confidence
steps:
- intent: nlu_fallback
- action: utter_please_rephrase

- rule: Only say `hello` if the user provided a location
condition:
- slot_was_set:
- location: true
steps:
- intent: greet
- action: utter_greet

- rule: Say `hello` when the user starts a conversation with intent `greet`
conversation_start: true
steps:
- intent: greet
- action: utter_greet

- rule: Rule which will not wait for user message once it was applied
steps:
- intent: greet
- action: utter_greet
wait_for_user_input: false
14 changes: 14 additions & 0 deletions tests/testing_data/validator/valid_data/data/stories.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
version: "3.1"
stories:
- story: greet
steps:
- intent: greet
- action: utter_greet
- action: action_say_hello
- action: action_restart
- story: say goodbye
steps:
- intent: deny
- action: utter_goodbye
- action: action_say_goodbye
- action: action_restart
26 changes: 26 additions & 0 deletions tests/testing_data/validator/valid_data/domain.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
version: "3.1"
config:
store_entities_as_slots: true
session_config:
session_expiration_time: 60
carry_over_slots_to_new_session: true
intents:
- greet:
use_entities: true
- deny:
use_entities: true
responses:
utter_goodbye:
- text: Bye
utter_greet:
- text: Hey! How are you?
utter_default:
- text: Can you rephrase!
utter_please_rephrase:
- text: I'm sorry, I didn't quite understand that. Could you rephrase?

actions:
- action_say_hello
- action_say_goodbye
- utter_greet
- utter_goodbye
94 changes: 87 additions & 7 deletions tests/unit_test/events/events_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ def _path(*args, **kwargs):
assert 'deny' in processor.fetch_intents(bot)
assert len(processor.fetch_stories(bot)) == 2
assert len(list(processor.fetch_training_examples(bot))) == 7
assert len(list(processor.fetch_responses(bot))) == 4
assert len(list(processor.fetch_responses(bot))) == 3
assert len(processor.fetch_actions(bot)) == 2
assert len(processor.fetch_rule_block_names(bot)) == 4
assert len(processor.fetch_rule_block_names(bot)) == 3

def test_trigger_data_importer_validate_and_save_append(self, monkeypatch):
bot = 'test_events'
Expand Down Expand Up @@ -302,7 +302,7 @@ def _path(*args, **kwargs):
assert len(list(processor.fetch_training_examples(bot))) == 13
assert len(list(processor.fetch_responses(bot))) == 6
assert len(processor.fetch_actions(bot)) == 2
assert len(processor.fetch_rule_block_names(bot)) == 4
assert len(processor.fetch_rule_block_names(bot)) == 3

def test_trigger_data_importer_validate_and_save_overwrite_same_user(self, monkeypatch):
bot = 'test_events'
Expand Down Expand Up @@ -338,9 +338,9 @@ def _path(*args, **kwargs):
assert 'deny' in processor.fetch_intents(bot)
assert len(processor.fetch_stories(bot)) == 2
assert len(list(processor.fetch_training_examples(bot))) == 7
assert len(list(processor.fetch_responses(bot))) == 4
assert len(list(processor.fetch_responses(bot))) == 3
assert len(processor.fetch_actions(bot)) == 2
assert len(processor.fetch_rule_block_names(bot)) == 4
assert len(processor.fetch_rule_block_names(bot)) == 3

@responses.activate
def test_trigger_data_importer_validate_event(self, monkeypatch):
Expand Down Expand Up @@ -836,9 +836,89 @@ def _path(*args, **kwargs):
mongo_processor = MongoProcessor()
assert len(mongo_processor.fetch_stories(bot)) == 3
assert len(list(mongo_processor.fetch_training_examples(bot))) == 21
assert len(list(mongo_processor.fetch_responses(bot))) == 14
print(list(mongo_processor.fetch_responses(bot)))
assert len(list(mongo_processor.fetch_responses(bot))) == 12
assert len(mongo_processor.fetch_actions(bot)) == 0
assert len(mongo_processor.fetch_rule_block_names(bot)) == 1
print(mongo_processor.fetch_rule_block_names(bot))
assert len(mongo_processor.fetch_rule_block_names(bot)) == 0

def test_trigger_data_importer_with_valid_data(self, monkeypatch):
bot = 'test_events_with_valid_data'
user = 'test'
test_data_path = os.path.join(pytest.tmp_dir, str(uuid.uuid4()))
shutil.copytree('tests/testing_data/validator/valid_data', test_data_path)

def _path(*args, **kwargs):
return test_data_path

monkeypatch.setattr(Utility, "get_latest_file", _path)

DataImporterLogProcessor.add_log(bot, user,
files_received=REQUIREMENTS - {"http_actions", "chat_client_config"})
TrainingDataImporterEvent(bot, user, import_data=True, overwrite=False).execute()
logs = list(DataImporterLogProcessor.get_logs(bot))
assert len(logs) == 1
assert not logs[0].get('intents').get('data')
assert not logs[0].get('stories').get('data')
assert not logs[0].get('utterances').get('data')
assert [action.get('data') for action in logs[0].get('actions') if action.get('type') == 'http_actions']
assert not logs[0].get('training_examples').get('data')
assert not logs[0].get('domain').get('data')
assert not logs[0].get('config').get('data')
assert not logs[0].get('exception')
assert logs[0]['start_timestamp']
assert logs[0]['end_timestamp']
assert logs[0]['status'] == 'Success'
assert logs[0]['event_status'] == EVENT_STATUS.COMPLETED.value

processor = MongoProcessor()
assert 'greet' in processor.fetch_intents(bot)
assert 'deny' in processor.fetch_intents(bot)
assert len(processor.fetch_stories(bot)) == 2
assert len(list(processor.fetch_training_examples(bot))) == 7
assert len(list(processor.fetch_responses(bot))) == 4
assert len(processor.fetch_actions(bot)) == 2
assert len(processor.fetch_rule_block_names(bot)) == 4

def test_trigger_data_importer_with_actions(self, monkeypatch):
bot = 'test_events_with_valid_data'
user = 'test'
actions = 'tests/testing_data/valid_yml/actions.yml'
test_data_path = os.path.join(pytest.tmp_dir, str(uuid.uuid4()))
shutil.copytree('tests/testing_data/validator/valid_data', test_data_path)
shutil.copy2(actions, test_data_path)

def _path(*args, **kwargs):
return test_data_path

monkeypatch.setattr(Utility, "get_latest_file", _path)

DataImporterLogProcessor.add_log(bot, user,
files_received=REQUIREMENTS - {"http_actions", "chat_client_config"})
TrainingDataImporterEvent(bot, user, import_data=True, overwrite=False).execute()
logs = list(DataImporterLogProcessor.get_logs(bot))
assert len(logs) == 2
assert not logs[0].get('intents').get('data')
assert not logs[0].get('stories').get('data')
assert not logs[0].get('utterances').get('data')
assert [action.get('data') for action in logs[0].get('actions') if action.get('type') == 'http_actions']
assert not logs[0].get('training_examples').get('data')
assert not logs[0].get('domain').get('data')
assert not logs[0].get('config').get('data')
assert not logs[0].get('exception')
assert logs[0]['start_timestamp']
assert logs[0]['end_timestamp']
assert logs[0]['status'] == 'Success'
assert logs[0]['event_status'] == EVENT_STATUS.COMPLETED.value

processor = MongoProcessor()
assert 'greet' in processor.fetch_intents(bot)
assert 'deny' in processor.fetch_intents(bot)
assert len(processor.fetch_stories(bot)) == 2
assert len(list(processor.fetch_training_examples(bot))) == 7
assert len(list(processor.fetch_responses(bot))) == 4
assert len(processor.fetch_actions(bot)) == 16
assert len(processor.fetch_rule_block_names(bot)) == 4

def test_trigger_faq_importer_validate_only(self, monkeypatch):
def _mock_execution(*args, **kwargs):
Expand Down
Loading

0 comments on commit 38d1819

Please sign in to comment.