From 6317c25fcb36209ad1563133c6de047f85388f2b Mon Sep 17 00:00:00 2001
From: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
Date: Wed, 9 Oct 2024 17:03:55 -0700
Subject: [PATCH] Add distributed tracing middleware

---
Reviewer notes (text between the "---" marker and the first "diff --git" line
is commentary; it is not part of the commit message or of the change itself):

* TracingMiddleware only acts on scopes whose type is "http" and that carry a
  b"langsmith-trace" header; every other request is handed to the wrapped app
  without touching the tracing context.
* RunTree.from_headers and _Baggage.from_headers look up the str key first and
  fall back to the bytes key, decoding the bytes value as UTF-8.
* NOTE(review): a bytes value stored under a str key is not decoded by either
  lookup -- confirm no caller passes headers in that mixed shape.

 python/langsmith/middleware.py                | 40 +++++++++++++++++
 python/langsmith/run_trees.py                 | 43 ++++++++++++++-----
 python/tests/integration_tests/fake_server.py |  4 +-
 3 files changed, 74 insertions(+), 13 deletions(-)
 create mode 100644 python/langsmith/middleware.py

diff --git a/python/langsmith/middleware.py b/python/langsmith/middleware.py
new file mode 100644
index 000000000..a7dc5a93b
--- /dev/null
+++ b/python/langsmith/middleware.py
@@ -0,0 +1,40 @@
+"""Middleware for making it easier to do distributed tracing."""
+
+
+class TracingMiddleware:
+    """Middleware for propagating distributed tracing context using LangSmith.
+
+    This middleware checks for the 'langsmith-trace' header and propagates the
+    tracing context if present. It does not start new traces by default.
+    It is designed to work with ASGI applications.
+
+    Attributes:
+        app: The ASGI application being wrapped.
+    """
+
+    def __init__(self, app):
+        """Initialize the middleware."""
+        from langsmith.run_helpers import tracing_context  # type: ignore
+
+        self._with_headers = tracing_context
+        self.app = app
+
+    async def __call__(self, scope: dict, receive, send):
+        """Handle incoming requests and propagate tracing context if applicable.
+
+        Args:
+            scope: A dict containing ASGI connection scope.
+            receive: An awaitable callable for receiving ASGI events.
+            send: An awaitable callable for sending ASGI events.
+
+        If the request is HTTP and contains the 'langsmith-trace' header,
+        it propagates the tracing context before calling the wrapped application.
+        Otherwise, it calls the application directly without modifying the context.
+        """
+        if scope["type"] == "http" and "headers" in scope:
+            headers = dict(scope["headers"])
+            if b"langsmith-trace" in headers:
+                with self._with_headers(parent=headers):
+                    await self.app(scope, receive, send)
+                return
+        await self.app(scope, receive, send)
diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py
index 85bb9bcd3..4bfae0e83 100644
--- a/python/langsmith/run_trees.py
+++ b/python/langsmith/run_trees.py
@@ -4,8 +4,9 @@
 
 import json
 import logging
+import sys
 from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
 from uuid import UUID, uuid4
 
 try:
@@ -26,7 +27,11 @@
 logger = logging.getLogger(__name__)
 
 LANGSMITH_PREFIX = "langsmith-"
-LANGSMITH_DOTTED_ORDER = f"{LANGSMITH_PREFIX}trace"
+LANGSMITH_DOTTED_ORDER = sys.intern(f"{LANGSMITH_PREFIX}trace")
+LANGSMITH_DOTTED_ORDER_BYTES = LANGSMITH_DOTTED_ORDER.encode("utf-8")
+LANGSMITH_METADATA = sys.intern(f"{LANGSMITH_PREFIX}metadata")
+LANGSMITH_TAGS = sys.intern(f"{LANGSMITH_PREFIX}tags")
+LANGSMITH_PROJECT = sys.intern(f"{LANGSMITH_PREFIX}project")
 _CLIENT: Optional[Client] = None
 _LOCK = threading.Lock()  # Keeping around for a while for backwards compat
 
@@ -332,9 +337,9 @@ def from_dotted_order(
             RunTree: The new span.
         """
         headers = {
-            f"{LANGSMITH_DOTTED_ORDER}": dotted_order,
+            LANGSMITH_DOTTED_ORDER: dotted_order,
         }
-        return cast(RunTree, cls.from_headers(headers, **kwargs))
+        return cast(RunTree, cls.from_headers(headers, **kwargs))  # type: ignore[arg-type]
 
     @classmethod
     def from_runnable_config(
@@ -402,7 +407,9 @@ def from_runnable_config(
         return None
 
     @classmethod
-    def from_headers(cls, headers: Dict[str, str], **kwargs: Any) -> Optional[RunTree]:
+    def from_headers(
+        cls, headers: Mapping[Union[str, bytes], Union[str, bytes]], **kwargs: Any
+    ) -> Optional[RunTree]:
         """Create a new 'parent' span from the provided headers.
 
         Extracts parent span information from the headers and creates a new span.
@@ -415,9 +422,14 @@ def from_headers(cls, headers: Dict[str, str], **kwargs: Any) -> Optional[RunTre
         """
         init_args = kwargs.copy()
 
-        langsmith_trace = headers.get(f"{LANGSMITH_DOTTED_ORDER}")
+        langsmith_trace = cast(Optional[str], headers.get(LANGSMITH_DOTTED_ORDER))
         if not langsmith_trace:
-            return  # type: ignore[return-value]
+            langsmith_trace_bytes = cast(
+                Optional[bytes], headers.get(LANGSMITH_DOTTED_ORDER_BYTES)
+            )
+            if not langsmith_trace_bytes:
+                return  # type: ignore[return-value]
+            langsmith_trace = langsmith_trace_bytes.decode("utf-8")
 
         parent_dotted_order = langsmith_trace.strip()
         parsed_dotted_order = _parse_dotted_order(parent_dotted_order)
@@ -436,7 +448,7 @@ def from_headers(cls, headers: Dict[str, str], **kwargs: Any) -> Optional[RunTre
         init_args["run_type"] = init_args.get("run_type") or "chain"
         init_args["name"] = init_args.get("name") or "parent"
 
-        baggage = _Baggage.from_header(headers.get("baggage"))
+        baggage = _Baggage.from_headers(headers)
         if baggage.metadata or baggage.tags:
             init_args["extra"] = init_args.setdefault("extra", {})
             init_args["extra"]["metadata"] = init_args["extra"].setdefault(
@@ -490,17 +502,26 @@ def from_header(cls, header_value: Optional[str]) -> _Baggage:
         try:
             for item in header_value.split(","):
                 key, value = item.split("=", 1)
-                if key == f"{LANGSMITH_PREFIX}metadata":
+                if key == LANGSMITH_METADATA:
                     metadata = json.loads(urllib.parse.unquote(value))
-                elif key == f"{LANGSMITH_PREFIX}tags":
+                elif key == LANGSMITH_TAGS:
                     tags = urllib.parse.unquote(value).split(",")
-                elif key == f"{LANGSMITH_PREFIX}project":
+                elif key == LANGSMITH_PROJECT:
                     project_name = urllib.parse.unquote(value)
         except Exception as e:
             logger.warning(f"Error parsing baggage header: {e}")
 
         return cls(metadata=metadata, tags=tags, project_name=project_name)
 
+    @classmethod
+    def from_headers(cls, headers: Mapping[Union[str, bytes], Any]) -> _Baggage:
+        if "baggage" in headers:
+            return cls.from_header(headers["baggage"])
+        elif b"baggage" in headers:
+            return cls.from_header(cast(bytes, headers[b"baggage"]).decode("utf-8"))
+        else:
+            return cls.from_header(None)
+
     def to_header(self) -> str:
         """Return the Baggage object as a header value."""
         items = []
diff --git a/python/tests/integration_tests/fake_server.py b/python/tests/integration_tests/fake_server.py
index f42f328f2..c028103bb 100644
--- a/python/tests/integration_tests/fake_server.py
+++ b/python/tests/integration_tests/fake_server.py
@@ -1,9 +1,11 @@
 from fastapi import FastAPI, Request
 
 from langsmith import traceable
+from langsmith.middleware import TracingMiddleware
 from langsmith.run_helpers import get_current_run_tree, trace, tracing_context
 
 fake_app = FastAPI()
+fake_app.add_middleware(TracingMiddleware)
 
 
 @traceable
@@ -47,13 +49,11 @@ async def fake_route(request: Request):
     with trace(
         "Trace",
         project_name="Definitely-not-your-grandpas-project",
-        parent=request.headers,
     ):
         fake_function()
     fake_function_two(
         "foo",
         langsmith_extra={
-            "parent": request.headers,
             "project_name": "Definitely-not-your-grandpas-project",
         },
     )