Skip to content

Commit

Permalink
Merge pull request #67 from RasaHQ/validate-prefilled
Browse files Browse the repository at this point in the history
Validate prefilled slots
  • Loading branch information
erohmensing authored Apr 12, 2019
2 parents cfffaac + bd8518f commit c2c73f9
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ This project adheres to `Semantic Versioning`_ starting with version 0.11.0.
Added
-----
- add formatter 'black'
- Slots filled before the start of a form are now validated upon form start
- In debug mode, the values of required slots for a form are now printed
before submitting

Changed
-------
Expand Down
95 changes: 71 additions & 24 deletions rasa_core_sdk/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,31 @@ def extract_requested_slot(
logger.debug("Failed to extract requested slot '{}'".format(slot_to_fill))
return {}

def validate_slots(self, slot_dict, dispatcher, tracker, domain):
# type: (Dict[Text, Any], CollectingDispatcher, Tracker, Dict[Text, Any]) -> List[Dict]
"""Validate slots using helper validation functions.
Call validate_{slot} function for each slot, value pair to be validated.
If this function is not implemented, set the slot to the value.
"""

for slot, value in list(slot_dict.items()):
validate_func = getattr(
self, "validate_{}".format(slot), lambda *x: {slot: value}
)
validation_output = validate_func(value, dispatcher, tracker, domain)
if not isinstance(validation_output, dict):
logger.warning(
"Returning values in helper validation methods is deprecated. "
+ "Your `validate_{}()` method should return ".format(slot)
+ "a dict of {'slot_name': value} instead."
)
validation_output = {slot: validation_output}
slot_dict.update(validation_output)

# validation succeed, set slots to extracted values
return [SlotSet(slot, value) for slot, value in slot_dict.items()]

def validate(self, dispatcher, tracker, domain):
# type: (CollectingDispatcher, Tracker, Dict[Text, Any]) -> List[Dict]
"""Extract and validate value of requested slot.
Expand Down Expand Up @@ -339,22 +364,8 @@ def validate(self, dispatcher, tracker, domain):
"with action {1}"
"".format(slot_to_fill, self.name()),
)

for slot, value in list(slot_values.items()):
validate_func = getattr(
self, "validate_{}".format(slot), lambda *x: {slot: value}
)
validation_output = validate_func(value, dispatcher, tracker, domain)
if not isinstance(validation_output, dict):
logger.warning(
"Returning values in helper validation methods is deprecated. "
+ "Your method should return a dict of {'slot_name': value} instead."
)
validation_output = {slot: validation_output}
slot_values.update(validation_output)

# validation succeed, set slots to extracted values
return [SlotSet(slot, value) for slot, value in slot_values.items()]
logger.debug("Validating extracted slots: {}".format(slot_values))
return self.validate_slots(slot_values, dispatcher, tracker, domain)

# noinspection PyUnusedLocal
def request_next_slot(
Expand All @@ -378,7 +389,7 @@ def request_next_slot(
)
return [SlotSet(REQUESTED_SLOT, slot)]

logger.debug("No slots left to request")
# no more required slots to fill
return None

def deactivate(self):
Expand Down Expand Up @@ -425,10 +436,27 @@ def _list_intents(

return self._to_list(intent), self._to_list(not_intent)

def _activate_if_required(self, tracker):
# type: (Tracker) -> List[Dict]
"""Return `Form` event with the name of the form
if the form was called for the first time"""
def _log_form_slots(self, tracker):
"""Logs the values of all required slots before submitting the form."""

req_slots = self.required_slots(tracker)
slot_values = "\n".join(
["\t{}: {}".format(slot, tracker.get_slot(slot)) for slot in req_slots]
)
logger.debug(
"No slots left to request, all required slots are filled:\n{}".format(
slot_values
)
)

def _activate_if_required(self, dispatcher, tracker, domain):
# type: (CollectingDispatcher, Tracker, Dict[Text, Any]) -> List[Dict]
"""Activate form if the form is called for the first time.
If activating, validate any required slots that were filled before
form activation and return `Form` event with the name of the form, as well
as any `SlotSet` events from validation of pre-filled slots.
"""

if tracker.active_form.get("name") is not None:
logger.debug("The form '{}' is active".format(tracker.active_form))
Expand All @@ -439,7 +467,25 @@ def _activate_if_required(self, tracker):
return []
else:
logger.debug("Activated the form '{}'".format(self.name()))
return [Form(self.name())]
events = [Form(self.name())]

# collect values of required slots filled before activation
prefilled_slots = {}
for slot_name in self.required_slots(tracker):
if not self._should_request_slot(tracker, slot_name):
prefilled_slots[slot_name] = tracker.get_slot(slot_name)

if prefilled_slots:
logger.debug(
"Validating pre-filled required slots: {}".format(prefilled_slots)
)
events.extend(
self.validate_slots(prefilled_slots, dispatcher, tracker, domain)
)
else:
logger.debug("No pre-filled required slots to validate.")

return events

def _validate_if_required(self, dispatcher, tracker, domain):
# type: (CollectingDispatcher, Tracker, Dict[Text, Any]) -> List[Dict]
Expand Down Expand Up @@ -479,10 +525,9 @@ def run(self, dispatcher, tracker, domain):
"""

# activate the form
events = self._activate_if_required(tracker)
events = self._activate_if_required(dispatcher, tracker, domain)
# validate user input
events.extend(self._validate_if_required(dispatcher, tracker, domain))

# check that the form wasn't deactivated in validation
if Form(None) not in events:

Expand All @@ -499,6 +544,8 @@ def run(self, dispatcher, tracker, domain):
events.extend(next_slot_events)
else:
# there is nothing more to request, so we can submit
self._log_form_slots(tracker)
logger.debug("Submitting the form '{}'".format(self.name()))
events.extend(self.submit(dispatcher, temp_tracker, domain))
# deactivate the form after submission
events.extend(self.deactivate())
Expand Down
69 changes: 67 additions & 2 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,67 @@ def validate_some_slot(self, value, dispatcher, tracker, domain):
assert events == [SlotSet("some_slot", "validated_value")]


def test_validate_prefilled_slots():
# noinspection PyAbstractClass
class CustomFormAction(FormAction):
def name(self):
return "some_form"

@staticmethod
def required_slots(_tracker):
return ["some_slot", "some_other_slot"]

def validate_some_slot(self, value, dispatcher, tracker, domain):
if value == "some_value":
return {"some_slot": "validated_value"}
else:
return {"some_slot": None}

form = CustomFormAction()

tracker = Tracker(
"default",
{"some_slot": "some_value", "some_other_slot": "some_other_value"},
{
"entities": [{"entity": "some_slot", "value": "some_bad_value"}],
"text": "some text",
},
[],
False,
None,
{},
"action_listen",
)

events = form._activate_if_required(dispatcher=None, tracker=tracker, domain=None)
# check that the form was activated and prefilled slots were validated
assert events == [
Form("some_form"),
SlotSet("some_slot", "validated_value"),
SlotSet("some_other_slot", "some_other_value"),
] or events == [ # this 'or' is only necessary for python 2.7 and 3.5
Form("some_form"),
SlotSet("some_other_slot", "some_other_value"),
SlotSet("some_slot", "validated_value"),
]

events.extend(
form._validate_if_required(dispatcher=None, tracker=tracker, domain=None)
)
# check that entities picked up in input overwrite prefilled slots
assert events == [
Form("some_form"),
SlotSet("some_slot", "validated_value"),
SlotSet("some_other_slot", "some_other_value"),
SlotSet("some_slot", None),
] or events == [ # this 'or' is only necessary for python 2.7 and 3.5
Form("some_form"),
SlotSet("some_other_slot", "some_other_value"),
SlotSet("some_slot", "validated_value"),
SlotSet("some_slot", None),
]


def test_validate_trigger_slots():
"""Test validation results of from_trigger_intent slot mappings
"""
Expand Down Expand Up @@ -827,6 +888,10 @@ class CustomFormAction(FormAction):
def name(self):
return "some_form"

@staticmethod
def required_slots(_tracker):
return ["some_slot", "some_other_slot"]

form = CustomFormAction()

tracker = Tracker(
Expand All @@ -840,7 +905,7 @@ def name(self):
"action_listen",
)

events = form._activate_if_required(tracker)
events = form._activate_if_required(dispatcher=None, tracker=tracker, domain=None)
# check that the form was activated
assert events == [Form("some_form")]

Expand All @@ -855,7 +920,7 @@ def name(self):
"action_listen",
)

events = form._activate_if_required(tracker)
events = form._activate_if_required(dispatcher=None, tracker=tracker, domain=None)
# check that the form was not activated again
assert events == []

Expand Down

0 comments on commit c2c73f9

Please sign in to comment.