Skip to content

Commit

Permalink
Merge pull request #253 from RasaHQ/extract-slots-helper
Browse files Browse the repository at this point in the history
add helper to extract form slot candidates
  • Loading branch information
wochinge authored Aug 31, 2020
2 parents 4e5fb5f + 80657d7 commit 4fb112a
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
32 changes: 32 additions & 0 deletions changelog/238.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
Added the method ``form_slots_to_validate`` to ``Tracker``. This method is helpful
when using a custom action to validate slots which were extracted by a Form as shown
by the following example.

.. code-block:: python
class ValidateSlots(Action):
def name(self) -> Text:
return "validate_your_form"
def run(
self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict
) -> List[EventType]:
extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate()
validation_events = []
for slot_name, slot_value in extracted_slots:
# Check if slot is valid.
if is_valid(slot_value):
validation_events.append(SlotSet(slot_name, slot_value))
else:
# Return a `SlotSet` event with value `None` to indicate that this
# slot still needs to be filled.
validation_events.append(SlotSet(slot_name, None))
return validation_events
def is_valid(slot_value: Any) -> bool:
# Implementation of the validate function.
Please note that ``tracker.form_slots_to_validate`` only works with Rasa Open Source 2.
28 changes: 28 additions & 0 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,34 @@ def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]):
applied_events.append(event)
return applied_events

def form_slots_to_validate(self) -> Dict[Text, Any]:
"""Get form slots which need validation.
You can use a custom action to validate slots which were extracted during the
latest form execution. This method provides you all extracted candidates for
form slots.
Returns:
A mapping of extracted slot candidates and their values.
"""

slots_to_validate = {}

if not self.active_loop:
return slots_to_validate

for event in reversed(self.events):
# The `FormAction` in Rasa Open Source will append all slot candidates
# at the end of the tracker events.
if event["event"] == "slot":
slots_to_validate[event["name"]] = event["value"]
else:
# Stop as soon as there is another event type as this means that we
# checked all potential slot candidates.
break

return slots_to_validate


class Action:
"""Next action to be taken in response to a dialogue state."""
Expand Down
43 changes: 42 additions & 1 deletion tests/test_tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from rasa_sdk import Tracker
from rasa_sdk.events import ActionExecuted, UserUttered, ActionReverted
from rasa_sdk.events import ActionExecuted, UserUttered, ActionReverted, SlotSet
from rasa_sdk.interfaces import ACTION_LISTEN_NAME
from typing import List, Dict, Text, Any

Expand Down Expand Up @@ -70,3 +72,42 @@ def test_last_executed_has_not_name():
tracker = get_tracker(events)

assert tracker.last_executed_action_has("another") is False


@pytest.mark.parametrize(
"events, expected_extracted_slots",
[
([], {}),
([ActionExecuted("my_form")], {}),
(
[ActionExecuted("my_form"), SlotSet("my_slot", "some_value")],
{"my_slot": "some_value"},
),
(
[
ActionExecuted("my_form"),
SlotSet("my_slot", "some_value"),
SlotSet("some_other", "some_value2"),
],
{"my_slot": "some_value", "some_other": "some_value2"},
),
([SlotSet("my_slot", "some_value"), ActionExecuted("my_form")], {},),
],
)
def test_get_extracted_slots(
events: List[Dict[Text, Any]], expected_extracted_slots: Dict[Text, Any]
):
tracker = get_tracker(events)
tracker.active_loop = {"name": "my form"}
assert tracker.form_slots_to_validate() == expected_extracted_slots


def test_get_extracted_slots_with_no_active_loop():
events = [
ActionExecuted("my_form"),
SlotSet("my_slot", "some_value"),
SlotSet("some_other", "some_value2"),
]
tracker = get_tracker(events)

assert tracker.form_slots_to_validate() == {}

0 comments on commit 4fb112a

Please sign in to comment.