Skip to content

Commit

Permalink
Merge pull request #1074 from RasaHQ/ATO-2102-instrument-ValidationAc…
Browse files Browse the repository at this point in the history
…tion.run

[ATO-2102] Instrument `ValidationAction.run`
  • Loading branch information
Tawakalt authored Feb 9, 2024
2 parents 552f409 + 66edd17 commit 71489ca
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 9 deletions.
1 change: 1 addition & 0 deletions changelog/1074.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Instrument `ValidationAction.run` method and extract attributes `class_name`, `sender_id`, `action_name` and `slots_to_validate`.
3 changes: 2 additions & 1 deletion rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor

from rasa_sdk.forms import ValidationAction

TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")

Expand All @@ -38,6 +38,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
instrumentation.instrument(
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
validation_action_class=ValidationAction,
)


Expand Down
32 changes: 30 additions & 2 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json

from typing import Any, Dict, Text
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.types import ActionCall, DomainDict
from rasa_sdk import Tracker


# This file contains all attribute extractors for tracing instrumentation.
Expand Down Expand Up @@ -28,3 +32,27 @@ def extract_attrs_for_action_executor(
attributes["action_name"] = action_name

return attributes


def extract_attrs_for_validation_action(
self: ValidationAction,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
"""Extract the attributes for `ValidationAction.run`.
:param self: The `ValidationAction` on which `run` is called.
:param dispatcher: The `CollectingDispatcher` argument.
:param tracker: The `Tracker` argument.
:param domain: The `DomainDict` argument.
:return: A dictionary containing the attributes.
"""
slots_to_validate = tracker.slots_to_validate().keys()

return {
"class_name": self.__class__.__name__,
"sender_id": tracker.sender_id,
"slots_to_validate": json.dumps(list(slots_to_validate)),
"action_name": self.name(),
}
24 changes: 21 additions & 3 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk.tracing.instrumentation import attribute_extractors

# The `TypeVar` representing the return type for a function to be wrapped.
Expand Down Expand Up @@ -70,9 +71,11 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
if issubclass(self.__class__, ValidationAction):
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"
else:
span_name = f"{self.__class__.__name__}.{fn.__name__}"
with tracer.start_as_current_span(span_name, attributes=attrs):
return await fn(self, *args, **kwargs)

return async_wrapper
Expand Down Expand Up @@ -109,18 +112,22 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S:


ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)
ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
validation_action_class: Optional[Type[ValidationActionType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.
:param tracer_provider: The `TracerProvider` to be used for configuring tracing
on the substituted methods.
:param action_executor_class: The `ActionExecutor` to be instrumented. If `None`
is given, no `ActionExecutor` will be instrumented.
:param validation_action_class: The `ValidationAction` to be instrumented. If `None`
is given, no `ValidationAction` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
Expand All @@ -133,6 +140,17 @@ def instrument(
)
mark_class_as_instrumented(action_executor_class)

if validation_action_class is not None and not class_is_instrumented(
validation_action_class
):
_instrument_method(
tracer_provider.get_tracer(validation_action_class.__module__),
validation_action_class,
"run",
attribute_extractors.extract_attrs_for_validation_action,
)
mark_class_as_instrumented(validation_action_class)


def _instrument_method(
tracer: Tracer,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 3
assert len(response.json) == 4

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
# defined in tests/test_actions.py
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
]
assert response.json == expected


def test_server_webhook_unknown_action_returns_404():
Expand Down
33 changes: 31 additions & 2 deletions tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.types import ActionCall, DomainDict
from rasa_sdk import Tracker


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -44,3 +46,30 @@ def fail_if_undefined(self, method_name: Text) -> None:

async def run(self, action_call: ActionCall) -> None:
pass


class MockValidationAction(ValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def run(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
pass

def name(self) -> Text:
return "mock_validation_action"
63 changes: 63 additions & 0 deletions tests/tracing/instrumentation/test_validation_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import List, Sequence

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import SlotSet, EventType


@pytest.mark.parametrize(
"events, expected_slots_to_validate",
[
([], "[]"),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
),
],
)
@pytest.mark.asyncio
async def test_tracing_action_executor_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: List[EventType],
expected_slots_to_validate: str,
) -> None:
component_class = MockValidationAction

instrumentation.instrument(
tracer_provider,
validation_action_class=component_class,
)

mock_validation_action = component_class()
dispatcher = CollectingDispatcher()
tracker = Tracker.from_dict({"sender_id": "test", "events": events})

await mock_validation_action.run(dispatcher, tracker, {})

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]

assert captured_span.name == "ValidationAction.MockValidationAction.run"

expected_attributes = {
"class_name": component_class.__name__,
"sender_id": "test",
"slots_to_validate": expected_slots_to_validate,
"action_name": "mock_validation_action",
}

assert captured_span.attributes == expected_attributes

0 comments on commit 71489ca

Please sign in to comment.