Skip to content

Commit

Permalink
Merge pull request #273 from RasaHQ/slot-mapping-roles-groups
Browse files Browse the repository at this point in the history
Only fill other slots if slot mapping contains a group/role restriction
  • Loading branch information
tabergma authored Sep 23, 2020
2 parents a1616b3 + 362d367 commit bded575
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 94 deletions.
2 changes: 2 additions & 0 deletions changelog/237.enhancement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Only fill other slots if slot mapping contains a role or group restriction and
the entity type matches.
3 changes: 1 addition & 2 deletions rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ def register_package(self, package: Union[Text, types.ModuleType]) -> None:
self._register_all_actions()

def _register_all_actions(self) -> None:
"""Scan for all user subclasses of `Action`, and register them.
"""
"""Scan for all user subclasses of `Action`, and register them."""
import inspect

actions = utils.all_subclasses(Action)
Expand Down
92 changes: 64 additions & 28 deletions rasa_sdk/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,33 +190,42 @@ def intent_is_desired(
return intent_not_blacklisted or intent in mapping_intents

def entity_is_desired(
self, requested_slot_mapping: Dict[Text, Any], slot: Text, tracker: "Tracker"
self,
other_slot_mapping: Dict[Text, Any],
other_slot: Text,
entity_type_of_slot_to_fill: Optional[Text],
tracker: "Tracker",
) -> bool:
"""Check whether slot should be filled by an entity in the input or not.
"""Check whether the other slot should be filled by an entity in the input or
not.
Args:
requested_slot_mapping: Slot mapping.
slot: The slot to be filled.
other_slot_mapping: Slot mapping.
other_slot: The other slot to be filled.
entity_type_of_slot_to_fill: Entity type of slot to fill.
tracker: The tracker.
Returns:
True, if slot should be filled, false otherwise.
True, if other slot should be filled, false otherwise.
"""

# slot name is equal to the entity type
slot_equals_entity = slot == requested_slot_mapping.get("entity")
other_slot_equals_entity = other_slot == other_slot_mapping.get("entity")

# use the custom slot mapping 'from_entity' defined by the user to check
# whether we can fill a slot with an entity
matching_values = self.get_entity_value(
requested_slot_mapping.get("entity"),
tracker,
requested_slot_mapping.get("role"),
requested_slot_mapping.get("group"),
)
slot_fulfils_entity_mapping = matching_values is not None
other_slot_fulfils_entity_mapping = False
if (
other_slot_mapping.get("role") is not None
or other_slot_mapping.get("group") is not None
) and entity_type_of_slot_to_fill == other_slot_mapping.get("entity"):
matching_values = self.get_entity_value(
other_slot_mapping.get("entity"),
tracker,
other_slot_mapping.get("role"),
other_slot_mapping.get("group"),
)
other_slot_fulfils_entity_mapping = matching_values is not None

return slot_equals_entity or slot_fulfils_entity_mapping
return other_slot_equals_entity or other_slot_fulfils_entity_mapping

@staticmethod
def get_entity_value(
Expand Down Expand Up @@ -254,11 +263,15 @@ def extract_other_slots(
domain: Dict[Text, Any],
) -> Dict[Text, Any]:
"""Extract the values of the other slots
if they are set by corresponding entities from the user input
else return None
if they are set by corresponding entities from the user input
else return None
"""
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)

entity_type_of_slot_to_fill = self._get_entity_type_of_slot_to_fill(
slot_to_fill
)

slot_values = {}
for slot in self.required_slots(tracker):
# look for other slots
Expand All @@ -271,7 +284,12 @@ def extract_other_slots(
should_fill_entity_slot = (
other_slot_mapping["type"] == "from_entity"
and self.intent_is_desired(other_slot_mapping, tracker)
and self.entity_is_desired(other_slot_mapping, slot, tracker)
and self.entity_is_desired(
other_slot_mapping,
slot,
entity_type_of_slot_to_fill,
tracker,
)
)
# check whether the slot should be
# filled from trigger intent mapping
Expand Down Expand Up @@ -308,7 +326,7 @@ def extract_requested_slot(
domain: Dict[Text, Any],
) -> Dict[Text, Any]:
"""Extract the value of requested slot from a user input
else return None
else return None
"""
slot_to_fill = tracker.get_slot(REQUESTED_SLOT)
logger.debug(f"Trying to extract requested slot '{slot_to_fill}' ...")
Expand Down Expand Up @@ -421,7 +439,7 @@ def request_next_slot(
domain: Dict[Text, Any],
) -> Optional[List[EventType]]:
"""Request the next slot and utter template if needed,
else return None"""
else return None"""

for slot in self.required_slots(tracker):
if self._should_request_slot(tracker, slot):
Expand All @@ -434,7 +452,7 @@ def request_next_slot(

def deactivate(self) -> List[EventType]:
"""Return `Form` event with `None` as name to deactivate the form
and reset the requested slot"""
and reset the requested slot"""

logger.debug(f"Deactivating the form '{self.name()}'")
return [Form(None), SlotSet(REQUESTED_SLOT, None)]
Expand All @@ -446,15 +464,15 @@ async def submit(
domain: Dict[Text, Any],
) -> List[EventType]:
"""Define what the form has to do
after all required slots are filled"""
after all required slots are filled"""

raise NotImplementedError("A form must implement a submit method")

# helpers
@staticmethod
def _to_list(x: Optional[Any]) -> List[Any]:
"""Convert object to a list if it is not a list,
None converted to empty list
None converted to empty list
"""
if x is None:
x = []
Expand Down Expand Up @@ -538,10 +556,10 @@ async def _validate_if_required(
domain: Dict[Text, Any],
) -> List[EventType]:
"""Return a list of events from `self.validate(...)`
if validation is required:
- the form is active
- the form is called after `action_listen`
- form validation was not cancelled
if validation is required:
- the form is active
- the form is called after `action_listen`
- form validation was not cancelled
"""
if tracker.latest_action_name == "action_listen" and tracker.active_form.get(
"validate", True
Expand Down Expand Up @@ -614,3 +632,21 @@ async def run(

def __str__(self) -> Text:
return f"FormAction('{self.name()}')"

def _get_entity_type_of_slot_to_fill(self, slot_to_fill: Text,) -> Optional[Text]:
if not slot_to_fill:
return None

mappings = self.get_mappings_for_slot(slot_to_fill)
mappings = [m for m in mappings if m.get("type") == "from_entity"]

if not mappings:
return None

entity_type = mappings[0].get("entity")

for i in range(1, len(mappings)):
if entity_type != mappings[i].get("entity"):
return None

return entity_type
4 changes: 2 additions & 2 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def applied_events(self) -> List[Dict[Text, Any]]:

def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]):
"""Removes events from `done_events` until the first
occurrence `event_type` is found which is also removed."""
occurrence `event_type` is found which is also removed."""
# list gets modified - hence we need to copy events!
for e in reversed(done_events[:]):
del done_events[-1]
Expand Down Expand Up @@ -259,7 +259,7 @@ def __str__(self) -> Text:

class ActionExecutionRejection(Exception):
"""Raising this exception will allow other policies
to predict another action"""
to predict another action"""

def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:
self.action_name = action_name
Expand Down
Loading

0 comments on commit bded575

Please sign in to comment.