Skip to content

Commit

Permalink
Merge pull request #1117 from RasaHQ/improvement/ATO-2609-add-grpc-tr…
Browse files Browse the repository at this point in the history
…acing

[ATO-2609] Add tracing for custom actions invoked over gRPC
  • Loading branch information
radovanZRasa authored Jul 2, 2024
2 parents d647767 + 03dd122 commit 02840a3
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog/1117.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Tracing is supported for actions called over gRPC protocol.
42 changes: 38 additions & 4 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import zlib
import json
from functools import partial
from typing import List, Text, Union, Optional
from typing import List, Text, Union, Optional, Any
from ssl import SSLContext

from multidict import MultiDict
from sanic import Sanic, response
from sanic.compat import Header
from sanic.response import HTTPResponse
from sanic.worker.loader import AppLoader

Expand Down Expand Up @@ -127,8 +130,21 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
tracer, context, span_name = get_tracer_and_context(
request.app.ctx.tracer_provider, request
span_name = "create_app.webhook"

def header_to_multi_dict(headers: Header) -> MultiDict:
return MultiDict(
[
(key, value)
for key, value in headers.items()
if key.lower() not in ("content-length", "content-encoding")
]
)

tracer, context = get_tracer_and_context(
span_name=span_name,
tracer_provider=request.app.ctx.tracer_provider,
tracing_carrier=header_to_multi_dict(request.headers),
)

with tracer.start_as_current_span(span_name, context=context) as span:
Expand Down Expand Up @@ -162,7 +178,12 @@ async def webhook(request: Request) -> HTTPResponse:
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=449)

set_span_attributes(span, action_call)
set_http_span_attributes(
span,
action_call,
http_method="POST",
route="/webhook",
)

return response.json(result, status=200)

Expand Down Expand Up @@ -238,6 +259,19 @@ def run(
)


def set_http_span_attributes(
span: Any,
action_call: dict,
http_method: str,
route: str,
) -> None:
"""Sets http span attributes."""
set_span_attributes(span, action_call)
if span.is_recording():
span.set_attribute("http.method", http_method)
span.set_attribute("http.route", route)


if __name__ == "__main__":
import rasa_sdk.__main__

Expand Down
40 changes: 31 additions & 9 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import grpc
import logging
import types
from typing import Union, Optional
from typing import Union, Optional, Any, Dict
from concurrent import futures
from grpc import aio
from google.protobuf.json_format import MessageToDict, ParseDict
from opentelemetry import trace
from grpc.aio import Metadata
from multidict import MultiDict

from rasa_sdk.constants import (
DEFAULT_SERVER_PORT,
Expand Down Expand Up @@ -43,6 +44,8 @@
from rasa_sdk.tracing.utils import (
get_tracer_provider,
TracerProvider,
get_tracer_and_context,
set_span_attributes,
)
from rasa_sdk.utils import (
check_version_compatibility,
Expand Down Expand Up @@ -135,13 +138,23 @@ async def Webhook(
gRPC response.
"""
span_name = "GRPCActionServerWebhook.Webhook"
tracer = (
self.tracer_provider.get_tracer(span_name)
if self.tracer_provider
else trace.get_tracer(span_name)
invocation_metadata = context.invocation_metadata()

def convert_metadata_to_multidict(
metadata: Optional[Metadata],
) -> Optional[MultiDict]:
"""Convert list of tuples to multidict."""
if not metadata:
return None
return MultiDict(metadata)

tracer, tracing_context = get_tracer_and_context(
span_name=span_name,
tracer_provider=self.tracer_provider,
tracing_carrier=convert_metadata_to_multidict(invocation_metadata),
)

with tracer.start_as_current_span(span_name):
with tracer.start_as_current_span(span_name, context=tracing_context) as span:
check_version_compatibility(request.version)
if self.auto_reload:
self.executor.reload()
Expand Down Expand Up @@ -179,12 +192,21 @@ async def Webhook(
return action_webhook_pb2.WebhookResponse()
if not result:
return action_webhook_pb2.WebhookResponse()
# set_span_attributes(span, request)
response = action_webhook_pb2.WebhookResponse()

set_grpc_span_attributes(span, action_call, method_name="Webhook")
response = action_webhook_pb2.WebhookResponse()
return ParseDict(result, response)


def set_grpc_span_attributes(
span: Any, action_call: Dict[str, Any], method_name: str
) -> None:
"""Sets grpc span attributes."""
set_span_attributes(span, action_call)
if span.is_recording():
span.set_attribute("grpc.method", method_name)


def get_signal_name(signal_number: int) -> str:
"""Return the signal name for the given signal number."""
return signal.Signals(signal_number).name
Expand Down
27 changes: 15 additions & 12 deletions rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from multidict import MultiDict

from rasa_sdk.tracing import config
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from opentelemetry.sdk.trace import TracerProvider
from sanic.request import Request

from typing import Optional, Tuple, Any, Text, Union
from typing import Optional, Tuple, Any


def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]:
Expand All @@ -17,34 +18,36 @@ def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]:


def get_tracer_and_context(
tracer_provider: Optional[TracerProvider], request: Union[Request]
) -> Tuple[Any, Any, Text]:
span_name: str,
tracer_provider: Optional[TracerProvider],
tracing_carrier: Optional[MultiDict],
) -> Tuple[Any, Any]:
"""Gets tracer and context."""
span_name = "create_app.webhook"

if tracer_provider is None:
tracer = trace.get_tracer(span_name)
context = None
else:
tracer = tracer_provider.get_tracer(span_name)
context = TraceContextTextMapPropagator().extract(request.headers)
return (tracer, context, span_name)
context = (
TraceContextTextMapPropagator().extract(tracing_carrier)
if tracing_carrier
else None
)
return tracer, context


def set_span_attributes(span: Any, action_call: dict) -> None:
"""Sets span attributes."""
tracker = action_call.get("tracker", {})
set_span_attributes = {
"http.method": "POST",
"http.route": "/webhook",
span_attributes = {
"next_action": action_call.get("next_action"),
"version": action_call.get("version"),
"sender_id": tracker.get("sender_id"),
"message_id": tracker.get("latest_message", {}).get("message_id"),
}

if span.is_recording():
for key, value in set_span_attributes.items():
for key, value in span_attributes.items():
span.set_attribute(key, value)

return None
12 changes: 8 additions & 4 deletions tests/tracing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None

def test_get_tracer_provider_returns_provider() -> None:
"""Tests that get_tracer_provider returns a TracerProvider
if tracing is configured."""
if tracing is configured.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--endpoints", type=str, default=None)

Expand All @@ -58,7 +59,7 @@ def test_get_tracer_provider_returns_provider() -> None:


def test_get_tracer_and_context() -> None:
"""Tests that get_tracer_and_context returns a ProxyTracer and span name"""
"""Tests that get_tracer_and_context returns a ProxyTracer and span name."""
data = {
"next_action": "custom_action",
"version": "1.0.0",
Expand All @@ -70,8 +71,11 @@ def test_get_tracer_and_context() -> None:
}
app = ep.create_app(None)
request, _ = app.test_client.post("/webhook", data=json.dumps(data))
tracer, context, span_name = get_tracer_and_context(None, request)
tracer, context = get_tracer_and_context(
span_name="create_app.webhook",
tracer_provider=None,
tracing_carrier=request.headers,
)

assert isinstance(tracer, ProxyTracer)
assert span_name == "create_app.webhook"
assert context is None

0 comments on commit 02840a3

Please sign in to comment.