diff --git a/router/src/logging.rs b/router/src/logging.rs index 5a98ef57b93..e5721f26d36 100644 --- a/router/src/logging.rs +++ b/router/src/logging.rs @@ -1,13 +1,68 @@ +use axum::{extract::Request, middleware::Next, response::Response}; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; +use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId}; +use opentelemetry::Context; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; +struct TraceParent { + #[allow(dead_code)] + version: u8, + trace_id: TraceId, + parent_id: SpanId, + trace_flags: TraceFlags, +} + +fn parse_traceparent(header_value: &str) -> Option { + let parts: Vec<&str> = header_value.split('-').collect(); + if parts.len() != 4 { + return None; + } + + let version = u8::from_str_radix(parts[0], 16).ok()?; + if version == 0xff { + return None; + } + + let trace_id = TraceId::from_hex(parts[1]).ok()?; + let parent_id = SpanId::from_hex(parts[2]).ok()?; + let trace_flags = u8::from_str_radix(parts[3], 16).ok()?; + + Some(TraceParent { + version, + trace_id, + parent_id, + trace_flags: TraceFlags::new(trace_flags), + }) +} + +pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response { + let context = request + .headers() + .get("traceparent") + .and_then(|v| v.to_str().ok()) + .and_then(parse_traceparent) + .map(|traceparent| { + Context::new().with_remote_span_context(SpanContext::new( + traceparent.trace_id, + traceparent.parent_id, + traceparent.trace_flags, + true, + Default::default(), + )) + }); + + request.extensions_mut().insert(context); + + next.run(request).await +} + /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// - otlp_endpoint is an optional URL to an Open Telemetry collector /// - otlp_service_name service name to appear in APM diff --git a/router/src/server.rs b/router/src/server.rs index 6001e2dd09c..5737631e453 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -6,6 +6,7 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; +use crate::logging::trace_context_middleware; use crate::sagemaker::{ sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse, __path_sagemaker_compatibility, @@ -61,6 +62,7 @@ use tokio::sync::oneshot; use tokio::time::Instant; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -126,6 +128,7 @@ pub(crate) async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, compute_type: Extension, + context: Extension>, Json(mut req): Json, ) -> Result)> { // default return_full_text given the pipeline_tag @@ -135,11 +138,14 @@ pub(crate) async fn compat_generate( // switch on stream if req.stream { - Ok(generate_stream(infer, compute_type, Json(req.into())) - .await - .into_response()) + Ok( + generate_stream(infer, compute_type, context, Json(req.into())) + .await + .into_response(), + ) } else { - let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?; + let (headers, Json(generation)) = + generate(infer, compute_type, context, Json(req.into())).await?; // wrap generation inside a Vec to match api-inference Ok((headers, Json(vec![generation])).into_response()) } @@ -268,9 +274,14 @@ seed, async fn generate( infer: Extension, Extension(ComputeType(compute_type)): Extension, + Extension(context): Extension>, Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + generate_internal(infer, ComputeType(compute_type), Json(req), span).await } @@ -464,12 +475,17 @@ seed, async fn generate_stream( Extension(infer): Extension, Extension(compute_type): Extension, + Extension(context): Extension>, Json(req): Json, ) -> ( HeaderMap, Sse>>, ) { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + let (headers, response_stream) = generate_stream_internal(infer, compute_type, Json(req), span).await; @@ -699,9 +715,14 @@ pub(crate) async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, + Extension(context): Extension>, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { @@ -1223,9 +1244,14 @@ pub(crate) async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, + Extension(context): Extension>, Json(chat): Json, ) -> Result)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + metrics::counter!("tgi_request_count").increment(1); let ChatRequest { stream, @@ -2388,6 +2414,7 @@ async fn start( .layer(Extension(prom_handle.clone())) .layer(OtelAxumLayer::default()) .layer(DefaultBodyLimit::max(payload_limit)) + .layer(axum::middleware::from_fn(trace_context_middleware)) .layer(cors_layer); tracing::info!("Connected"); diff --git a/router/src/vertex.rs b/router/src/vertex.rs index a532c9eca8d..4cf5f5c73e4 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -7,6 +7,7 @@ use axum::response::{IntoResponse, Response}; use axum::Json; use serde::{Deserialize, Serialize}; use tracing::instrument; +use tracing_opentelemetry::OpenTelemetrySpanExt; use utoipa::ToSchema; #[derive(Clone, Deserialize, ToSchema)] @@ -70,9 +71,14 @@ example = json ! ({"error": "Incomplete generation"})), pub(crate) async fn vertex_compatibility( Extension(infer): Extension, Extension(compute_type): Extension, + Extension(context): Extension>, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance