Skip to content

Commit

Permalink
Merge pull request #1078 from RasaHQ/ATO-2188-allow-access-to-dialogu…
Browse files Browse the repository at this point in the history
…e-stack-from-custom-actions

[ATO-2188] Allow access to Dialogue stack from custom actions
  • Loading branch information
Tawakalt authored Mar 6, 2024
2 parents 93104bf + 1d85152 commit 5744372
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog/1078.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a `stack` property to the `Tracker` class which corresponds to the dialogue stack.
5 changes: 5 additions & 0 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def from_dict(cls, state: "TrackerState") -> "Tracker":
state.get("followup_action"),
state.get("active_loop", state.get("active_form", {})),
state.get("latest_action_name"),
state.get("stack", []),
)

def __init__(
Expand All @@ -45,6 +46,7 @@ def __init__(
followup_action: Optional[Text],
active_loop: Dict[Text, Any],
latest_action_name: Optional[Text],
stack: List[Dict[Text, Any]] = [],
) -> None:
"""Initialize the tracker."""

Expand All @@ -66,6 +68,7 @@ def __init__(
self.latest_message = latest_message if latest_message else {}
self.active_loop = active_loop
self.latest_action_name = latest_action_name
self.stack = stack

@property
def active_form(self) -> Dict[Text, Any]:
Expand Down Expand Up @@ -93,6 +96,7 @@ def current_state(self) -> Dict[Text, Any]:
"latest_input_channel": self.get_latest_input_channel(),
"active_loop": self.active_loop,
"latest_action_name": self.latest_action_name,
"stack": self.stack,
}

def current_slot_values(self) -> Dict[Text, Any]:
Expand Down Expand Up @@ -196,6 +200,7 @@ def copy(self) -> "Tracker":
self.followup_action,
self.active_loop,
self.latest_action_name,
self.stack,
)

def last_executed_action_has(self, name: Text, skip: int = 0) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions rasa_sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class TrackerState(TypedDict):
active_form: Dict[Text, Any]
# the name of the previously executed action or text of e2e action
latest_action_name: Optional[Text]
# the current dialogue stack
stack: List[Dict[Text, Any]]


class DomainDict(TypedDict):
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
from sanic import Sanic

Sanic.test_mode = True


def get_stack():
dialogue_stack = [
{
"frame_id": "CP6JP9GQ",
"flow_id": "check_balance",
"step_id": "0_check_balance",
"frame_type": "regular",
"type": "flow",
}
]
return dialogue_stack
13 changes: 13 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,16 @@ def run(
domain: DomainDict,
) -> List[Dict[Text, Any]]:
raise Exception("test exception")


class CustomActionWithDialogueStack(Action):
def name(cls) -> Text:
return "custom_action_with_dialogue_stack"

def run(
self,
dispatcher: CollectingDispatcher,
tracker: Tracker,
domain: DomainDict,
) -> List[Dict[Text, Any]]:
return [SlotSet("stack", tracker.stack)]
26 changes: 25 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Dict, List, Text
import json
import logging
import zlib
Expand All @@ -6,6 +7,7 @@

import rasa_sdk.endpoint as ep
from rasa_sdk.events import SlotSet
from tests.conftest import get_stack

# noinspection PyTypeChecker
app = ep.create_app(None)
Expand All @@ -23,14 +25,15 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 5
assert len(response.json) == 6

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
# defined in tests/test_actions.py
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
{"name": "custom_action_with_dialogue_stack"},
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
{"name": "mock_form_validation_action"},
Expand Down Expand Up @@ -119,6 +122,27 @@ def test_server_webhook_custom_action_encoded_data_returns_200():
assert response.status == 200


@pytest.mark.parametrize(
"stack_state, dialogue_stack",
[
({}, []),
({"stack": get_stack()}, get_stack()),
],
)
def test_server_webhook_custom_action_with_dialogue_stack_returns_200(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]]
):
data = {
"next_action": "custom_action_with_dialogue_stack",
"tracker": {"sender_id": "1", "conversation_id": "default", **stack_state},
}
_, response = app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")

assert events == [SlotSet("stack", dialogue_stack)]
assert response.status == 200


# ENSURE THIS IS ALWAYS THE LAST TEST FOR OTHER TESTS TO RUN
# because the call to sys.exit() terminates pytest process
def test_endpoint_exit_for_unknown_actions_package():
Expand Down
22 changes: 21 additions & 1 deletion tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict
from typing import Any, Dict, List, Text

import pytest
from rasa_sdk.events import SlotSet

from rasa_sdk.interfaces import Tracker
from tests.conftest import get_stack


@pytest.mark.parametrize(
Expand Down Expand Up @@ -61,3 +62,22 @@ def test_tracker_with_slots():

assert tracker.slots["my slot"] == 5
assert tracker.slots["my slot 2"] is None


@pytest.mark.parametrize(
"stack_state, dialogue_stack",
[
({}, []),
({"stack": get_stack()}, get_stack()),
],
)
def test_stack_in_tracker_state(
stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]]
):

state = {"events": [], "sender_id": "old", "active_loop": {}, **stack_state}
tracker = Tracker.from_dict(state)

assert tracker.stack == dialogue_stack
assert tracker.copy().stack == dialogue_stack
assert tracker.current_state()["stack"] == dialogue_stack
1 change: 0 additions & 1 deletion tests/tracing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_get_tracer_and_context() -> None:
app = ep.create_app(None)
request, _ = app.test_client.post("/webhook", data=json.dumps(data))
tracer, context, span_name = get_tracer_and_context(None, request)
print(type(tracer))

assert isinstance(tracer, ProxyTracer)
assert span_name == "create_app.webhook"
Expand Down

0 comments on commit 5744372

Please sign in to comment.