Skip to content

Commit

Permalink
Domain Payload Optimization to Action server
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksandarmijat committed Jun 5, 2024
1 parent 11aa42e commit 4ed415e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 9 deletions.
10 changes: 9 additions & 1 deletion rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
DEFAULT_SERVER_PORT,
)
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.interfaces import (
ActionExecutionRejection,
ActionNotFoundException,
ActionMissingDomainException,
)
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import (
get_tracer_and_context,
Expand Down Expand Up @@ -153,6 +157,10 @@ async def webhook(request: Request) -> HTTPResponse:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)
except ActionMissingDomainException as e:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=449)

set_span_attributes(span, action_call)

Expand Down
60 changes: 52 additions & 8 deletions rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@
import inspect
import logging
import pkgutil
import typing
import warnings
from typing import Text, List, Dict, Any, Type, Union, Callable, Optional, Set, cast
from collections import namedtuple
import types
import sys
import os

from rasa_sdk.interfaces import Tracker, ActionNotFoundException, Action
from rasa_sdk.interfaces import (
Tracker,
ActionNotFoundException,
Action,
ActionMissingDomainException,
)

from rasa_sdk import utils

if typing.TYPE_CHECKING: # pragma: no cover
from rasa_sdk.types import ActionCall

logger = logging.getLogger(__name__)


class CollectingDispatcher:
"""Send messages back to user"""

def __init__(self) -> None:

self.messages: List[Dict[Text, Any]] = []

def utter_message(
Expand Down Expand Up @@ -162,6 +162,8 @@ def __init__(self) -> None:
self.actions: Dict[Text, Callable] = {}
self._modules: Dict[Text, TimestampModule] = {}
self._loaded: Set[Type[Action]] = set()
self.domain: Optional[Dict[Text, Any]] = None
self.domain_digest: Optional[Text] = None

def register_action(self, action: Union[Type[Action], Action]) -> None:
if inspect.isclass(action):
Expand Down Expand Up @@ -380,7 +382,49 @@ def validate_events(events: List[Dict[Text, Any]], action_name: Text):
# we won't append this to validated events -> will be ignored
return validated

async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]:
def is_domain_digest_valid(self, domain_digest: Optional[Text]) -> bool:
"""Check if the domain_digest is valid
If the domain_digest is empty or different from the one provided, it is invalid.
Args:
domain_digest: latest value provided to compare the current value with.
Returns:
True if the domain_digest is valid, False otherwise.
"""
return bool(self.domain_digest) and self.domain_digest == domain_digest

def update_and_return_domain(
self, payload: Dict[Text, Any], action_name: Text
) -> Optional[Dict[Text, Any]]:
"""Validate the digest, store the domain if available, and return the domain.
This method validates the domain digest from the payload.
If the digest is invalid and no domain is provided, an exception is raised.
If domain data is available, it stores the domain and digest.
Finally, it returns the domain.
Args:
payload: Request payload containing the domain data.
action_name: Name of the action that should be executed.
Returns:
The domain dictionary.
Raises:
ActionMissingDomainException: Invalid digest and no domain data available.
"""
payload_domain = payload.get("domain")
payload_domain_digest = payload.get("domain_digest")

# If digest is invalid and no domain is available - raise the error
if (
not self.is_domain_digest_valid(payload_domain_digest)
and payload_domain is None
):
raise ActionMissingDomainException(action_name)

if payload_domain:
self.domain = payload_domain
self.domain_digest = payload_domain_digest

return self.domain

async def run(self, action_call: Dict[Text, Any]) -> Optional[Dict[Text, Any]]:
from rasa_sdk.interfaces import Tracker

action_name = action_call.get("next_action")
Expand All @@ -391,7 +435,7 @@ async def run(self, action_call: "ActionCall") -> Optional[Dict[Text, Any]]:
raise ActionNotFoundException(action_name)

tracker_json = action_call["tracker"]
domain = action_call.get("domain", {})
domain = self.update_and_return_domain(action_call, action_name)
tracker = Tracker.from_dict(tracker_json)
dispatcher = CollectingDispatcher()

Expand Down
11 changes: 11 additions & 0 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,14 @@ def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:

def __str__(self) -> Text:
return self.message


class ActionMissingDomainException(Exception):
"""Raising this exception when the domain is missing."""

def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:
self.action_name = action_name
self.message = message or "Domain context is missing."

def __str__(self) -> Text:
return self.message
4 changes: 4 additions & 0 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_server_webhook_handles_action_exception(sanic_app: Sanic):
data = {
"next_action": "custom_action_exception",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {},
}
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
assert response.status == 500
Expand All @@ -76,6 +77,7 @@ def test_server_webhook_custom_action_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {},
}
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")
Expand All @@ -88,6 +90,7 @@ def test_server_webhook_custom_async_action_returns_200(sanic_app: Sanic):
data = {
"next_action": "custom_async_action",
"tracker": {"sender_id": "1", "conversation_id": "default"},
"domain": {},
}
request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")
Expand Down Expand Up @@ -148,6 +151,7 @@ def test_server_webhook_custom_action_with_dialogue_stack_returns_200(
data = {
"next_action": "custom_action_with_dialogue_stack",
"tracker": {"sender_id": "1", "conversation_id": "default", **stack_state},
"domain": {},
}
_, response = sanic_app.test_client.post("/webhook", data=json.dumps(data))
events = response.json.get("events")
Expand Down
1 change: 1 addition & 0 deletions tests/tracing/instrumentation/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_server_webhook_custom_action_is_instrumented(
"rasa_sdk.endpoint.get_tracer_provider", lambda _: tracer_provider
)
data["next_action"] = action_name
data["domain"] = {}
app = ep.create_app(action_package)

app.register_listener(
Expand Down

0 comments on commit 4ed415e

Please sign in to comment.