From b41e9ad733ad3130a7538816614de7c6ab040c7d Mon Sep 17 00:00:00 2001 From: Radovan Zivkovic Date: Tue, 2 Jul 2024 11:49:26 +0200 Subject: [PATCH] Add gRPC tracing --- rasa_sdk/endpoint.py | 29 +++++++++++++++++++++++++---- rasa_sdk/grpc_server.py | 35 ++++++++++++++++++++++++++--------- rasa_sdk/tracing/utils.py | 23 +++++++++++------------ tests/tracing/test_utils.py | 12 ++++++++---- 4 files changed, 70 insertions(+), 29 deletions(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 0a0326042..adce34367 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -6,7 +6,7 @@ import zlib import json from functools import partial -from typing import List, Text, Union, Optional +from typing import List, Text, Union, Optional, Any from ssl import SSLContext from sanic import Sanic, response from sanic.response import HTTPResponse @@ -127,8 +127,11 @@ async def health(_) -> HTTPResponse: @app.post("/webhook") async def webhook(request: Request) -> HTTPResponse: """Webhook to retrieve action calls.""" - tracer, context, span_name = get_tracer_and_context( - request.app.ctx.tracer_provider, request + span_name = "create_app.webhook" + tracer, context = get_tracer_and_context( + span_name=span_name, + tracer_provider=request.app.ctx.tracer_provider, + tracing_carrier=request.headers, ) with tracer.start_as_current_span(span_name, context=context) as span: @@ -162,7 +165,12 @@ async def webhook(request: Request) -> HTTPResponse: body = {"error": e.message, "action_name": e.action_name} return response.json(body, status=449) - set_span_attributes(span, action_call) + set_http_span_attributes( + span, + action_call, + http_method="POST", + route="/webhook", + ) return response.json(result, status=200) @@ -238,6 +246,19 @@ def run( ) +def set_http_span_attributes( + span: Any, + action_call: dict, + http_method: str, + route: str, +) -> None: + """Sets http span attributes.""" + set_span_attributes(span, action_call) + if span.is_recording(): + span.set_attribute("http.method", http_method) + span.set_attribute("http.route", route) + + if __name__ == "__main__": import rasa_sdk.__main__ diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index 1b71060ee..f0396c59c 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -7,11 +7,11 @@ import grpc import logging import types -from typing import Union, Optional +from typing import Union, Optional, Any, Dict from concurrent import futures from grpc import aio from google.protobuf.json_format import MessageToDict, ParseDict -from opentelemetry import trace +from multidict import MultiDict from rasa_sdk.constants import ( DEFAULT_SERVER_PORT, @@ -43,6 +43,8 @@ from rasa_sdk.tracing.utils import ( get_tracer_provider, TracerProvider, + get_tracer_and_context, + set_span_attributes, ) from rasa_sdk.utils import ( check_version_compatibility, @@ -135,13 +137,19 @@ async def Webhook( gRPC response. """ span_name = "GRPCActionServerWebhook.Webhook" - tracer = ( - self.tracer_provider.get_tracer(span_name) - if self.tracer_provider - else trace.get_tracer(span_name) + invocation_metadata = context.invocation_metadata() + + def convert_list_tuple_to_multidict() -> MultiDict: + """Convert list of tuples to multidict.""" + return MultiDict(invocation_metadata) + + tracer, tracing_context = get_tracer_and_context( + span_name=span_name, + tracer_provider=self.tracer_provider, + tracing_carrier=convert_list_tuple_to_multidict(), ) - with tracer.start_as_current_span(span_name): + with tracer.start_as_current_span(span_name, context=tracing_context) as span: check_version_compatibility(request.version) if self.auto_reload: self.executor.reload() @@ -179,12 +187,21 @@ async def Webhook( return action_webhook_pb2.WebhookResponse() if not result: return action_webhook_pb2.WebhookResponse() - # set_span_attributes(span, request) - response = action_webhook_pb2.WebhookResponse() + set_grpc_span_attributes(span, action_call, method_name="Webhook") + response = action_webhook_pb2.WebhookResponse() return ParseDict(result, response) +def set_grpc_span_attributes( + span: Any, action_call: Dict[str, Any], method_name: str +) -> None: + """Sets grpc span attributes.""" + set_span_attributes(span, action_call) + if span.is_recording(): + span.set_attribute("grpc.method", method_name) + + def get_signal_name(signal_number: int) -> str: """Return the signal name for the given signal number.""" return signal.Signals(signal_number).name diff --git a/rasa_sdk/tracing/utils.py b/rasa_sdk/tracing/utils.py index 18c848439..6b10f82bc 100644 --- a/rasa_sdk/tracing/utils.py +++ b/rasa_sdk/tracing/utils.py @@ -1,11 +1,12 @@ +from multidict import MultiDict + from rasa_sdk.tracing import config from opentelemetry import trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.sdk.trace import TracerProvider -from sanic.request import Request -from typing import Optional, Tuple, Any, Text, Union +from typing import Optional, Tuple, Any def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: @@ -17,26 +18,24 @@ def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: def get_tracer_and_context( - tracer_provider: Optional[TracerProvider], request: Union[Request] -) -> Tuple[Any, Any, Text]: + span_name: str, + tracer_provider: Optional[TracerProvider], + tracing_carrier: MultiDict, +) -> Tuple[Any, Any]: """Gets tracer and context.""" - span_name = "create_app.webhook" - if tracer_provider is None: tracer = trace.get_tracer(span_name) context = None else: tracer = tracer_provider.get_tracer(span_name) - context = TraceContextTextMapPropagator().extract(request.headers) - return (tracer, context, span_name) + context = TraceContextTextMapPropagator().extract(tracing_carrier) + return tracer, context def set_span_attributes(span: Any, action_call: dict) -> None: """Sets span attributes.""" tracker = action_call.get("tracker", {}) - set_span_attributes = { - "http.method": "POST", - "http.route": "/webhook", + span_attributes = { "next_action": action_call.get("next_action"), "version": action_call.get("version"), "sender_id": tracker.get("sender_id"), @@ -44,7 +43,7 @@ def set_span_attributes(span: Any, action_call: dict) -> None: } if span.is_recording(): - for key, value in set_span_attributes.items(): + for key, value in span_attributes.items(): span.set_attribute(key, value) return None diff --git a/tests/tracing/test_utils.py b/tests/tracing/test_utils.py index 00683c040..ef7b82ee1 100644 --- a/tests/tracing/test_utils.py +++ b/tests/tracing/test_utils.py @@ -44,7 +44,8 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None def test_get_tracer_provider_returns_provider() -> None: """Tests that get_tracer_provider returns a TracerProvider - if tracing is configured.""" + if tracing is configured. + """ parser = argparse.ArgumentParser() parser.add_argument("--endpoints", type=str, default=None) @@ -58,7 +59,7 @@ def test_get_tracer_provider_returns_provider() -> None: def test_get_tracer_and_context() -> None: - """Tests that get_tracer_and_context returns a ProxyTracer and span name""" + """Tests that get_tracer_and_context returns a ProxyTracer and span name.""" data = { "next_action": "custom_action", "version": "1.0.0", @@ -70,8 +71,11 @@ def test_get_tracer_and_context() -> None: } app = ep.create_app(None) request, _ = app.test_client.post("/webhook", data=json.dumps(data)) - tracer, context, span_name = get_tracer_and_context(None, request) + tracer, context = get_tracer_and_context( + span_name="create_app.webhook", + tracer_provider=None, + tracing_carrier=request.headers, + ) assert isinstance(tracer, ProxyTracer) - assert span_name == "create_app.webhook" assert context is None