diff --git a/changelog/1074.improvement.md b/changelog/1074.improvement.md new file mode 100644 index 000000000..278f6601d --- /dev/null +++ b/changelog/1074.improvement.md @@ -0,0 +1 @@ +Instrument `ValidationAction.run` method and extract attributes `class_name`, `sender_id`, `action_name` and `slots_to_validate`. \ No newline at end of file diff --git a/rasa_sdk/tracing/config.py b/rasa_sdk/tracing/config.py index 46c64a408..a0079a0b1 100644 --- a/rasa_sdk/tracing/config.py +++ b/rasa_sdk/tracing/config.py @@ -14,7 +14,7 @@ from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config from rasa_sdk.tracing.instrumentation import instrumentation from rasa_sdk.executor import ActionExecutor - +from rasa_sdk.forms import ValidationAction TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk") @@ -38,6 +38,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None: instrumentation.instrument( tracer_provider=tracer_provider, action_executor_class=ActionExecutor, + validation_action_class=ValidationAction, ) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 8dbcc7954..80183ac3a 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -1,6 +1,10 @@ +import json + from typing import Any, Dict, Text -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.types import ActionCall +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.forms import ValidationAction +from rasa_sdk.types import ActionCall, DomainDict +from rasa_sdk import Tracker # This file contains all attribute extractors for tracing instrumentation. @@ -28,3 +32,27 @@ def extract_attrs_for_action_executor( attributes["action_name"] = action_name return attributes + + +def extract_attrs_for_validation_action( + self: ValidationAction, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", +) -> Dict[Text, Any]: + """Extract the attributes for `ValidationAction.run`. + + :param self: The `ValidationAction` on which `run` is called. + :param dispatcher: The `CollectingDispatcher` argument. + :param tracker: The `Tracker` argument. + :param domain: The `DomainDict` argument. + :return: A dictionary containing the attributes. + """ + slots_to_validate = tracker.slots_to_validate().keys() + + return { + "class_name": self.__class__.__name__, + "sender_id": tracker.sender_id, + "slots_to_validate": json.dumps(list(slots_to_validate)), + "action_name": self.name(), + } diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index dcc92b214..9e84a4c5d 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -15,6 +15,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import Tracer from rasa_sdk.executor import ActionExecutor +from rasa_sdk.forms import ValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors # The `TypeVar` representing the return type for a function to be wrapped. @@ -70,9 +71,11 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S: if attr_extractor and should_extract_args else {} ) - with tracer.start_as_current_span( - f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs - ): + if issubclass(self.__class__, ValidationAction): + span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" + else: + span_name = f"{self.__class__.__name__}.{fn.__name__}" + with tracer.start_as_current_span(span_name, attributes=attrs): return await fn(self, *args, **kwargs) return async_wrapper @@ -109,11 +112,13 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S: ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor) +ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction) def instrument( tracer_provider: TracerProvider, action_executor_class: Optional[Type[ActionExecutorType]] = None, + validation_action_class: Optional[Type[ValidationActionType]] = None, ) -> None: """Substitute methods to be traced by their traced counterparts. @@ -121,6 +126,8 @@ def instrument( on the substituted methods. :param action_executor_class: The `ActionExecutor` to be instrumented. If `None` is given, no `ActionExecutor` will be instrumented. + :param validation_action_class: The `ValidationAction` to be instrumented. If `None` + is given, no `ValidationAction` will be instrumented. """ if action_executor_class is not None and not class_is_instrumented( action_executor_class @@ -133,6 +140,17 @@ def instrument( ) mark_class_as_instrumented(action_executor_class) + if validation_action_class is not None and not class_is_instrumented( + validation_action_class + ): + _instrument_method( + tracer_provider.get_tracer(validation_action_class.__module__), + validation_action_class, + "run", + attribute_extractors.extract_attrs_for_validation_action, + ) + mark_class_as_instrumented(validation_action_class) + def _instrument_method( tracer: Tracer, diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 712abdea1..29678d3a9 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -23,7 +23,18 @@ def test_server_health_returns_200(): def test_server_list_actions_returns_200(): request, response = app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 3 + assert len(response.json) == 4 + + # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS + expected = [ + # defined in tests/test_actions.py + {"name": "custom_async_action"}, + {"name": "custom_action"}, + {"name": "custom_action_exception"}, + # defined in tests/tracing/instrumentation/conftest.py + {"name": "mock_validation_action"}, + ] + assert response.json == expected def test_server_webhook_unknown_action_returns_404(): diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 3a7cecd80..6435c1859 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -5,8 +5,10 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.types import ActionCall +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.forms import ValidationAction +from rasa_sdk.types import ActionCall, DomainDict +from rasa_sdk import Tracker @pytest.fixture(scope="session") @@ -44,3 +46,30 @@ def fail_if_undefined(self, method_name: Text) -> None: async def run(self, action_call: ActionCall) -> None: pass + + +class MockValidationAction(ValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + pass + + def name(self) -> Text: + return "mock_validation_action" diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py new file mode 100644 index 000000000..ce407d8db --- /dev/null +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -0,0 +1,63 @@ +from typing import List, Sequence + +import pytest +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from rasa_sdk.tracing.instrumentation import instrumentation +from tests.tracing.instrumentation.conftest import MockValidationAction +from rasa_sdk import Tracker +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.events import SlotSet, EventType + + +@pytest.mark.parametrize( + "events, expected_slots_to_validate", + [ + ([], "[]"), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + '["name", "address"]', + ), + ], +) +@pytest.mark.asyncio +async def test_tracing_action_executor_run( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: List[EventType], + expected_slots_to_validate: str, +) -> None: + component_class = MockValidationAction + + instrumentation.instrument( + tracer_provider, + validation_action_class=component_class, + ) + + mock_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_validation_action.run(dispatcher, tracker, {}) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + + assert captured_span.name == "ValidationAction.MockValidationAction.run" + + expected_attributes = { + "class_name": component_class.__name__, + "sender_id": "test", + "slots_to_validate": expected_slots_to_validate, + "action_name": "mock_validation_action", + } + + assert captured_span.attributes == expected_attributes