From 64b0337574ec5500fadfb0923095e8fa170a9254 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 15 Oct 2024 15:14:20 +0900 Subject: [PATCH] feature: get trace id from req headers --- router/src/logging.rs | 55 +++++++++++++++++++++++++++++++++++++++++++ router/src/server.rs | 35 +++++++++++++++++++++++---- router/src/vertex.rs | 6 +++++ 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/router/src/logging.rs b/router/src/logging.rs index 5a98ef57b93..e5721f26d36 100644 --- a/router/src/logging.rs +++ b/router/src/logging.rs @@ -1,13 +1,68 @@ +use axum::{extract::Request, middleware::Next, response::Response}; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; +use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId}; +use opentelemetry::Context; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; +struct TraceParent { + #[allow(dead_code)] + version: u8, + trace_id: TraceId, + parent_id: SpanId, + trace_flags: TraceFlags, +} + +fn parse_traceparent(header_value: &str) -> Option { + let parts: Vec<&str> = header_value.split('-').collect(); + if parts.len() != 4 { + return None; + } + + let version = u8::from_str_radix(parts[0], 16).ok()?; + if version == 0xff { + return None; + } + + let trace_id = TraceId::from_hex(parts[1]).ok()?; + let parent_id = SpanId::from_hex(parts[2]).ok()?; + let trace_flags = u8::from_str_radix(parts[3], 16).ok()?; + + Some(TraceParent { + version, + trace_id, + parent_id, + trace_flags: TraceFlags::new(trace_flags), + }) +} + +pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response { + let context = request + .headers() + .get("traceparent") + .and_then(|v| v.to_str().ok()) + .and_then(parse_traceparent) + .map(|traceparent| { + Context::new().with_remote_span_context(SpanContext::new( + traceparent.trace_id, + traceparent.parent_id, + traceparent.trace_flags, + true, + Default::default(), + )) + }); + + request.extensions_mut().insert(context); + + next.run(request).await +} + /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// - otlp_endpoint is an optional URL to an Open Telemetry collector /// - otlp_service_name service name to appear in APM diff --git a/router/src/server.rs b/router/src/server.rs index 5e6e696037e..3a45851b2f1 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -7,6 +7,7 @@ use crate::kserve::{ kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kserve_model_metadata, kserve_model_metadata_ready, }; +use crate::logging::trace_context_middleware; use crate::validation::ValidationError; use crate::vertex::vertex_compatibility; use crate::ChatTokenizeResponse; @@ -57,6 +58,7 @@ use tokio::sync::oneshot; use tokio::time::Instant; use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing::{info_span, instrument, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -87,6 +89,7 @@ async fn compat_generate( Extension(default_return_full_text): Extension, infer: Extension, compute_type: Extension, + context: Extension>, Json(mut req): Json, ) -> Result)> { // default return_full_text given the pipeline_tag @@ -96,11 +99,14 @@ async fn compat_generate( // switch on stream if req.stream { - Ok(generate_stream(infer, compute_type, Json(req.into())) - .await - .into_response()) + Ok( + generate_stream(infer, compute_type, context, Json(req.into())) + .await + .into_response(), + ) } else { - let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?; + let (headers, Json(generation)) = + generate(infer, compute_type, context, Json(req.into())).await?; // wrap generation inside a Vec to match api-inference Ok((headers, Json(vec![generation])).into_response()) } @@ -251,9 +257,14 @@ seed, async fn generate( infer: Extension, Extension(ComputeType(compute_type)): Extension, + Extension(context): Extension>, Json(req): Json, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + generate_internal(infer, ComputeType(compute_type), Json(req), span).await } @@ -447,12 +458,17 @@ seed, async fn generate_stream( Extension(infer): Extension, Extension(compute_type): Extension, + Extension(context): Extension>, Json(req): Json, ) -> ( HeaderMap, Sse>>, ) { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + let (headers, response_stream) = generate_stream_internal(infer, compute_type, Json(req), span).await; @@ -682,9 +698,14 @@ async fn completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, + Extension(context): Extension>, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { @@ -1206,9 +1227,14 @@ async fn chat_completions( Extension(infer): Extension, Extension(compute_type): Extension, Extension(info): Extension, + Extension(context): Extension>, Json(chat): Json, ) -> Result)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + metrics::counter!("tgi_request_count").increment(1); let ChatRequest { stream, @@ -2348,6 +2374,7 @@ async fn start( .layer(Extension(compute_type)) .layer(Extension(prom_handle.clone())) .layer(OtelAxumLayer::default()) + .layer(axum::middleware::from_fn(trace_context_middleware)) .layer(cors_layer); tracing::info!("Connected"); diff --git a/router/src/vertex.rs b/router/src/vertex.rs index 0c1467fe373..03aed63c0ea 100644 --- a/router/src/vertex.rs +++ b/router/src/vertex.rs @@ -10,6 +10,7 @@ use axum::response::{IntoResponse, Response}; use axum::Json; use serde::{Deserialize, Serialize}; use tracing::instrument; +use tracing_opentelemetry::OpenTelemetrySpanExt; use utoipa::ToSchema; #[derive(Clone, Deserialize, ToSchema)] @@ -223,9 +224,14 @@ example = json ! ({"error": "Incomplete generation"})), pub(crate) async fn vertex_compatibility( Extension(infer): Extension, Extension(compute_type): Extension, + Extension(context): Extension>, Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); + if let Some(context) = context { + span.set_parent(context); + } + metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance