From 9216c026546aeed8cdbd884109aa061765aebf7e Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 9 Feb 2024 13:56:40 +0100 Subject: [PATCH 1/6] instrument ValidationAction.run --- rasa_sdk/tracing/config.py | 3 +- .../instrumentation/attribute_extractors.py | 30 +++++++++++++++++-- .../instrumentation/instrumentation.py | 24 +++++++++++++-- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/rasa_sdk/tracing/config.py b/rasa_sdk/tracing/config.py index 46c64a408..a0079a0b1 100644 --- a/rasa_sdk/tracing/config.py +++ b/rasa_sdk/tracing/config.py @@ -14,7 +14,7 @@ from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config from rasa_sdk.tracing.instrumentation import instrumentation from rasa_sdk.executor import ActionExecutor - +from rasa_sdk.forms import ValidationAction TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk") @@ -38,6 +38,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None: instrumentation.instrument( tracer_provider=tracer_provider, action_executor_class=ActionExecutor, + validation_action_class=ValidationAction, ) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 8dbcc7954..74804d600 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -1,6 +1,8 @@ from typing import Any, Dict, Text -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.types import ActionCall +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.forms import ValidationAction +from rasa_sdk.types import ActionCall, DomainDict +from rasa_sdk import Tracker # This file contains all attribute extractors for tracing instrumentation. @@ -28,3 +30,27 @@ def extract_attrs_for_action_executor( attributes["action_name"] = action_name return attributes + + +def extract_attrs_for_validation_action( + self: ValidationAction, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", +) -> Dict[Text, Any]: + """Extract the attributes for `ValidationAction.run`. + + :param self: The `ValidationAction` on which `run` is called. + :param dispatcher: The `CollectingDispatcher` argument. + :param tracker: The `Tracker` argument. + :param domain: The `DomainDict` argument. + :return: A dictionary containing the attributes. + """ + slots_to_validate = tracker.slots_to_validate().keys() + + return { + "class_name": self.__class__.__name__, + "sender_id": tracker.sender_id, + "slots_to_validate": str(slots_to_validate), + "action_name": self.name(), + } diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index dcc92b214..9e84a4c5d 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -15,6 +15,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import Tracer from rasa_sdk.executor import ActionExecutor +from rasa_sdk.forms import ValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors # The `TypeVar` representing the return type for a function to be wrapped. @@ -70,9 +71,11 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S: if attr_extractor and should_extract_args else {} ) - with tracer.start_as_current_span( - f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs - ): + if issubclass(self.__class__, ValidationAction): + span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}" + else: + span_name = f"{self.__class__.__name__}.{fn.__name__}" + with tracer.start_as_current_span(span_name, attributes=attrs): return await fn(self, *args, **kwargs) return async_wrapper @@ -109,11 +112,13 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S: ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor) +ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction) def instrument( tracer_provider: TracerProvider, action_executor_class: Optional[Type[ActionExecutorType]] = None, + validation_action_class: Optional[Type[ValidationActionType]] = None, ) -> None: """Substitute methods to be traced by their traced counterparts. @@ -121,6 +126,8 @@ def instrument( on the substituted methods. :param action_executor_class: The `ActionExecutor` to be instrumented. If `None` is given, no `ActionExecutor` will be instrumented. + :param validation_action_class: The `ValidationAction` to be instrumented. If `None` + is given, no `ValidationAction` will be instrumented. """ if action_executor_class is not None and not class_is_instrumented( action_executor_class @@ -133,6 +140,17 @@ def instrument( ) mark_class_as_instrumented(action_executor_class) + if validation_action_class is not None and not class_is_instrumented( + validation_action_class + ): + _instrument_method( + tracer_provider.get_tracer(validation_action_class.__module__), + validation_action_class, + "run", + attribute_extractors.extract_attrs_for_validation_action, + ) + mark_class_as_instrumented(validation_action_class) + def _instrument_method( tracer: Tracer, From 3645b862621302b7b05f7227adb61d949a3286fe Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 9 Feb 2024 13:56:52 +0100 Subject: [PATCH 2/6] add test --- tests/tracing/instrumentation/conftest.py | 33 +++++++++- .../instrumentation/test_validation_action.py | 66 +++++++++++++++++++ 2 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 tests/tracing/instrumentation/test_validation_action.py diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 3a7cecd80..6435c1859 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -5,8 +5,10 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from rasa_sdk.executor import ActionExecutor -from rasa_sdk.types import ActionCall +from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from rasa_sdk.forms import ValidationAction +from rasa_sdk.types import ActionCall, DomainDict +from rasa_sdk import Tracker @pytest.fixture(scope="session") @@ -44,3 +46,30 @@ def fail_if_undefined(self, method_name: Text) -> None: async def run(self, action_call: ActionCall) -> None: pass + + +class MockValidationAction(ValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + pass + + def name(self) -> Text: + return "mock_validation_action" diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py new file mode 100644 index 000000000..f8ffebdeb --- /dev/null +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -0,0 +1,66 @@ +from typing import Sequence, Optional + +import pytest +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from rasa_sdk.tracing.instrumentation import instrumentation +from tests.tracing.instrumentation.conftest import ( + MockValidationAction, + # MockValidationActionSubClass, +) +from rasa_sdk import Tracker +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.events import SlotSet + + +@pytest.mark.parametrize( + "events, expected_slots_to_validate", + [ + ([], "dict_keys([])"), + ( + [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], + "dict_keys(['name', 'address'])", + ), + ], +) +@pytest.mark.asyncio +async def test_tracing_action_executor_run( + tracer_provider: TracerProvider, + span_exporter: InMemorySpanExporter, + previous_num_captured_spans: int, + events: Optional[str], + expected_slots_to_validate: Optional[str], +) -> None: + component_class = MockValidationAction + + instrumentation.instrument( + tracer_provider, + validation_action_class=component_class, + ) + + mock_validation_action = component_class() + dispatcher = CollectingDispatcher() + tracker = Tracker.from_dict({"sender_id": "test", "events": events}) + + await mock_validation_action.run(dispatcher, tracker, {}) + + captured_spans: Sequence[ + ReadableSpan + ] = span_exporter.get_finished_spans() # type: ignore + + num_captured_spans = len(captured_spans) - previous_num_captured_spans + assert num_captured_spans == 1 + + captured_span = captured_spans[-1] + + assert captured_span.name == "ValidationAction.MockValidationAction.run" + + expected_attributes = { + "class_name": component_class.__name__, + "sender_id": "test", + "slots_to_validate": expected_slots_to_validate, + "action_name": "mock_validation_action", + } + + assert captured_span.attributes == expected_attributes From cd7a9123bdeb2167b8c4afc56c95fc31c0b4c08d Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 9 Feb 2024 14:09:38 +0100 Subject: [PATCH 3/6] fix failing test --- tests/test_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 712abdea1..6e7d0d465 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -23,7 +23,7 @@ def test_server_health_returns_200(): def test_server_list_actions_returns_200(): request, response = app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 3 + assert len(response.json) == 4 def test_server_webhook_unknown_action_returns_404(): From 58a319996c2b34e6215f48d283833c50c8ecc6a6 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 9 Feb 2024 14:16:40 +0100 Subject: [PATCH 4/6] add changelog entry --- changelog/1074.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1074.improvement.md diff --git a/changelog/1074.improvement.md b/changelog/1074.improvement.md new file mode 100644 index 000000000..278f6601d --- /dev/null +++ b/changelog/1074.improvement.md @@ -0,0 +1 @@ +Instrument `ValidationAction.run` method and extract attributes `class_name`, `sender_id`, `action_name` and `slots_to_validate`. \ No newline at end of file From 5a7c1b88b399f1929be45821259ed32e396f0f4c Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 9 Feb 2024 14:57:31 +0100 Subject: [PATCH 5/6] use json.dumps() instead of str() --- rasa_sdk/tracing/instrumentation/attribute_extractors.py | 4 +++- tests/tracing/instrumentation/test_validation_action.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/rasa_sdk/tracing/instrumentation/attribute_extractors.py b/rasa_sdk/tracing/instrumentation/attribute_extractors.py index 74804d600..80183ac3a 100644 --- a/rasa_sdk/tracing/instrumentation/attribute_extractors.py +++ b/rasa_sdk/tracing/instrumentation/attribute_extractors.py @@ -1,3 +1,5 @@ +import json + from typing import Any, Dict, Text from rasa_sdk.executor import ActionExecutor, CollectingDispatcher from rasa_sdk.forms import ValidationAction @@ -51,6 +53,6 @@ def extract_attrs_for_validation_action( return { "class_name": self.__class__.__name__, "sender_id": tracker.sender_id, - "slots_to_validate": str(slots_to_validate), + "slots_to_validate": json.dumps(list(slots_to_validate)), "action_name": self.name(), } diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index f8ffebdeb..5f745dd06 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -17,10 +17,10 @@ @pytest.mark.parametrize( "events, expected_slots_to_validate", [ - ([], "dict_keys([])"), + ([], "[]"), ( [SlotSet("name", "Tom"), SlotSet("address", "Berlin")], - "dict_keys(['name', 'address'])", + '["name", "address"]', ), ], ) From 66edd17496d58b8bd292fb737a69c38683236ac0 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Fri, 9 Feb 2024 16:20:19 +0100 Subject: [PATCH 6/6] implement PR feedback --- tests/test_endpoint.py | 11 +++++++++++ .../instrumentation/test_validation_action.py | 13 +++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 6e7d0d465..29678d3a9 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -25,6 +25,17 @@ def test_server_list_actions_returns_200(): assert response.status == 200 assert len(response.json) == 4 + # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS + expected = [ + # defined in tests/test_actions.py + {"name": "custom_async_action"}, + {"name": "custom_action"}, + {"name": "custom_action_exception"}, + # defined in tests/tracing/instrumentation/conftest.py + {"name": "mock_validation_action"}, + ] + assert response.json == expected + def test_server_webhook_unknown_action_returns_404(): data = { diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 5f745dd06..ce407d8db 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -1,17 +1,14 @@ -from typing import Sequence, Optional +from typing import List, Sequence import pytest from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import ( - MockValidationAction, - # MockValidationActionSubClass, -) +from tests.tracing.instrumentation.conftest import MockValidationAction from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.events import SlotSet +from rasa_sdk.events import SlotSet, EventType @pytest.mark.parametrize( @@ -29,8 +26,8 @@ async def test_tracing_action_executor_run( tracer_provider: TracerProvider, span_exporter: InMemorySpanExporter, previous_num_captured_spans: int, - events: Optional[str], - expected_slots_to_validate: Optional[str], + events: List[EventType], + expected_slots_to_validate: str, ) -> None: component_class = MockValidationAction