diff --git a/changelog/1117.bugfix.md b/changelog/1117.bugfix.md new file mode 100644 index 000000000..f8414f7d3 --- /dev/null +++ b/changelog/1117.bugfix.md @@ -0,0 +1 @@ +Tracing is supported for actions called over gRPC protocol. \ No newline at end of file diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 46b978621..a310ebe07 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -6,9 +6,12 @@ import zlib import json from functools import partial -from typing import List, Text, Union, Optional +from typing import List, Text, Union, Optional, Any from ssl import SSLContext + +from multidict import MultiDict from sanic import Sanic, response +from sanic.compat import Header from sanic.response import HTTPResponse from sanic.worker.loader import AppLoader @@ -127,8 +130,21 @@ async def health(_) -> HTTPResponse: @app.post("/webhook") async def webhook(request: Request) -> HTTPResponse: """Webhook to retrieve action calls.""" - tracer, context, span_name = get_tracer_and_context( - request.app.ctx.tracer_provider, request + span_name = "create_app.webhook" + + def header_to_multi_dict(headers: Header) -> MultiDict: + return MultiDict( + [ + (key, value) + for key, value in headers.items() + if key.lower() not in ("content-length", "content-encoding") + ] + ) + + tracer, context = get_tracer_and_context( + span_name=span_name, + tracer_provider=request.app.ctx.tracer_provider, + tracing_carrier=header_to_multi_dict(request.headers), ) with tracer.start_as_current_span(span_name, context=context) as span: @@ -162,7 +178,12 @@ async def webhook(request: Request) -> HTTPResponse: body = {"error": e.message, "action_name": e.action_name} return response.json(body, status=449) - set_span_attributes(span, action_call) + set_http_span_attributes( + span, + action_call, + http_method="POST", + route="/webhook", + ) return response.json(result, status=200) @@ -238,6 +259,19 @@ def run( ) +def set_http_span_attributes( + span: Any, + action_call: dict, + http_method: str, + route: str, +) -> None: + """Sets http span attributes.""" + set_span_attributes(span, action_call) + if span.is_recording(): + span.set_attribute("http.method", http_method) + span.set_attribute("http.route", route) + + if __name__ == "__main__": import rasa_sdk.__main__ diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index 1b71060ee..71a9fc681 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -7,11 +7,12 @@ import grpc import logging import types -from typing import Union, Optional +from typing import Union, Optional, Any, Dict from concurrent import futures from grpc import aio from google.protobuf.json_format import MessageToDict, ParseDict -from opentelemetry import trace +from grpc.aio import Metadata +from multidict import MultiDict from rasa_sdk.constants import ( DEFAULT_SERVER_PORT, @@ -43,6 +44,8 @@ from rasa_sdk.tracing.utils import ( get_tracer_provider, TracerProvider, + get_tracer_and_context, + set_span_attributes, ) from rasa_sdk.utils import ( check_version_compatibility, @@ -135,13 +138,23 @@ async def Webhook( gRPC response. """ span_name = "GRPCActionServerWebhook.Webhook" - tracer = ( - self.tracer_provider.get_tracer(span_name) - if self.tracer_provider - else trace.get_tracer(span_name) + invocation_metadata = context.invocation_metadata() + + def convert_metadata_to_multidict( + metadata: Optional[Metadata], + ) -> Optional[MultiDict]: + """Convert list of tuples to multidict.""" + if not metadata: + return None + return MultiDict(metadata) + + tracer, tracing_context = get_tracer_and_context( + span_name=span_name, + tracer_provider=self.tracer_provider, + tracing_carrier=convert_metadata_to_multidict(invocation_metadata), ) - with tracer.start_as_current_span(span_name): + with tracer.start_as_current_span(span_name, context=tracing_context) as span: check_version_compatibility(request.version) if self.auto_reload: self.executor.reload() @@ -179,12 +192,21 @@ async def Webhook( return action_webhook_pb2.WebhookResponse() if not result: return action_webhook_pb2.WebhookResponse() - # set_span_attributes(span, request) - response = action_webhook_pb2.WebhookResponse() + set_grpc_span_attributes(span, action_call, method_name="Webhook") + response = action_webhook_pb2.WebhookResponse() return ParseDict(result, response) +def set_grpc_span_attributes( + span: Any, action_call: Dict[str, Any], method_name: str +) -> None: + """Sets grpc span attributes.""" + set_span_attributes(span, action_call) + if span.is_recording(): + span.set_attribute("grpc.method", method_name) + + def get_signal_name(signal_number: int) -> str: """Return the signal name for the given signal number.""" return signal.Signals(signal_number).name diff --git a/rasa_sdk/tracing/utils.py b/rasa_sdk/tracing/utils.py index 18c848439..719b9f1ac 100644 --- a/rasa_sdk/tracing/utils.py +++ b/rasa_sdk/tracing/utils.py @@ -1,11 +1,12 @@ +from multidict import MultiDict + from rasa_sdk.tracing import config from opentelemetry import trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.sdk.trace import TracerProvider -from sanic.request import Request -from typing import Optional, Tuple, Any, Text, Union +from typing import Optional, Tuple, Any def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: @@ -17,26 +18,28 @@ def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: def get_tracer_and_context( - tracer_provider: Optional[TracerProvider], request: Union[Request] -) -> Tuple[Any, Any, Text]: + span_name: str, + tracer_provider: Optional[TracerProvider], + tracing_carrier: Optional[MultiDict], +) -> Tuple[Any, Any]: """Gets tracer and context.""" - span_name = "create_app.webhook" - if tracer_provider is None: tracer = trace.get_tracer(span_name) context = None else: tracer = tracer_provider.get_tracer(span_name) - context = TraceContextTextMapPropagator().extract(request.headers) - return (tracer, context, span_name) + context = ( + TraceContextTextMapPropagator().extract(tracing_carrier) + if tracing_carrier + else None + ) + return tracer, context def set_span_attributes(span: Any, action_call: dict) -> None: """Sets span attributes.""" tracker = action_call.get("tracker", {}) - set_span_attributes = { - "http.method": "POST", - "http.route": "/webhook", + span_attributes = { "next_action": action_call.get("next_action"), "version": action_call.get("version"), "sender_id": tracker.get("sender_id"), @@ -44,7 +47,7 @@ def set_span_attributes(span: Any, action_call: dict) -> None: } if span.is_recording(): - for key, value in set_span_attributes.items(): + for key, value in span_attributes.items(): span.set_attribute(key, value) return None diff --git a/tests/tracing/test_utils.py b/tests/tracing/test_utils.py index 00683c040..ef7b82ee1 100644 --- a/tests/tracing/test_utils.py +++ b/tests/tracing/test_utils.py @@ -44,7 +44,8 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None def test_get_tracer_provider_returns_provider() -> None: """Tests that get_tracer_provider returns a TracerProvider - if tracing is configured.""" + if tracing is configured. + """ parser = argparse.ArgumentParser() parser.add_argument("--endpoints", type=str, default=None) @@ -58,7 +59,7 @@ def test_get_tracer_provider_returns_provider() -> None: def test_get_tracer_and_context() -> None: - """Tests that get_tracer_and_context returns a ProxyTracer and span name""" + """Tests that get_tracer_and_context returns a ProxyTracer and span name.""" data = { "next_action": "custom_action", "version": "1.0.0", @@ -70,8 +71,11 @@ def test_get_tracer_and_context() -> None: } app = ep.create_app(None) request, _ = app.test_client.post("/webhook", data=json.dumps(data)) - tracer, context, span_name = get_tracer_and_context(None, request) + tracer, context = get_tracer_and_context( + span_name="create_app.webhook", + tracer_provider=None, + tracing_carrier=request.headers, + ) assert isinstance(tracer, ProxyTracer) - assert span_name == "create_app.webhook" assert context is None