diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 9feedb7a..ba2fb690 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,42 +1,19 @@ use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; -use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; use crate::service::rate_limit::RateLimitService; -use crate::service::GrpcServiceHandler; +use crate::service::{GrpcServiceHandler, HeaderResolver}; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; -use proxy_wasm::types::{Action, Bytes}; +use proxy_wasm::types::Action; use std::rc::Rc; -// tracing headers -#[derive(Clone)] -pub enum TracingHeader { - Traceparent, - Tracestate, - Baggage, -} - -impl TracingHeader { - pub fn all() -> [Self; 3] { - [Traceparent, Tracestate, Baggage] - } - - pub fn as_str(&self) -> &'static str { - match self { - Traceparent => "traceparent", - Tracestate => "tracestate", - Baggage => "baggage", - } - } -} - pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub tracing_headers: Vec<(TracingHeader, Bytes)>, + pub header_resolver: Rc, } impl Filter { @@ -66,7 +43,7 @@ impl Filter { let rls = GrpcServiceHandler::new( ExtensionType::RateLimit, rlp.service.clone(), - self.tracing_headers.clone(), + Rc::clone(&self.header_resolver), ); let message = RateLimitService::message(rlp.domain.clone(), descriptors); @@ -102,12 +79,6 @@ impl HttpContext for Filter { fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { debug!("#{} on_http_request_headers", self.context_id); - for header in TracingHeader::all() { - if let Some(value) = self.get_http_request_header_bytes(header.as_str()) { - self.tracing_headers.push((header, value)) - } - } - match self .config .index diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index ab28c72c..90774e1c 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -1,5 +1,6 @@ use crate::configuration::{FilterConfig, PluginConfiguration}; use crate::filter::http_context::Filter; +use crate::service::HeaderResolver; use const_format::formatcp; use log::{debug, error, info}; use proxy_wasm::traits::{Context, HttpContext, RootContext}; @@ -40,7 +41,7 @@ impl RootContext for FilterRoot { context_id, config: Rc::clone(&self.config), response_headers_to_add: Vec::default(), - tracing_headers: Vec::default(), + header_resolver: Rc::new(HeaderResolver::new()), })) } diff --git a/src/service.rs b/src/service.rs index 3f358550..0695145f 100644 --- a/src/service.rs +++ b/src/service.rs @@ -2,21 +2,22 @@ pub(crate) mod auth; pub(crate) mod rate_limit; use crate::configuration::ExtensionType; -use crate::filter::http_context::TracingHeader; use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; use protobuf::Message; use proxy_wasm::hostcalls; use proxy_wasm::hostcalls::dispatch_grpc_call; use proxy_wasm::types::{Bytes, MapType, Status}; use std::cell::OnceCell; +use std::rc::Rc; use std::time::Duration; pub struct GrpcServiceHandler { endpoint: String, service_name: String, method_name: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, + header_resolver: Rc, } impl GrpcServiceHandler { @@ -24,33 +25,33 @@ impl GrpcServiceHandler { endpoint: String, service_name: &str, method_name: &str, - tracing_headers: Vec<(TracingHeader, Bytes)>, + header_resolver: Rc, ) -> Self { Self { endpoint: endpoint.to_owned(), service_name: service_name.to_owned(), method_name: method_name.to_owned(), - tracing_headers, + header_resolver, } } pub fn new( extension_type: ExtensionType, endpoint: String, - tracing_headers: Vec<(TracingHeader, Bytes)>, + header_resolver: Rc, ) -> Self { match extension_type { ExtensionType::Auth => Self::new_base( endpoint, AUTH_SERVICE_NAME, AUTH_METHOD_NAME, - tracing_headers, + header_resolver, ), ExtensionType::RateLimit => Self::new_base( endpoint, RATELIMIT_SERVICE_NAME, RATELIMIT_METHOD_NAME, - tracing_headers, + header_resolver, ), } } @@ -58,9 +59,10 @@ impl GrpcServiceHandler { pub fn send(&self, message: M) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self - .tracing_headers + .header_resolver + .get() .iter() - .map(|(header, value)| (header.as_str(), value.as_slice())) + .map(|(header, value)| (*header, value.as_slice())) .collect(); dispatch_grpc_call( @@ -74,23 +76,49 @@ impl GrpcServiceHandler { } } -pub struct TracingHeaderResolver { - tracing_headers: OnceCell>, +pub struct HeaderResolver { + headers: OnceCell>, } -impl TracingHeaderResolver { - pub fn get(&self) -> &Vec<(TracingHeader, Bytes)> { - self.tracing_headers.get_or_init(|| { +impl HeaderResolver { + pub fn new() -> Self { + Self { + headers: OnceCell::new(), + } + } + + pub fn get(&self) -> &Vec<(&'static str, Bytes)> { + self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { - if let Some(value) = - hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, header.as_str()) - .unwrap() + if let Ok(Some(value)) = + hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str()) { - headers.push((header, value)); + headers.push(((*header).as_str(), value)); } } headers }) } } + +// tracing headers +pub enum TracingHeader { + Traceparent, + Tracestate, + Baggage, +} + +impl TracingHeader { + fn all() -> &'static [Self; 3] { + &[Traceparent, Tracestate, Baggage] + } + + pub fn as_str(&self) -> &'static str { + match self { + Traceparent => "traceparent", + Tracestate => "tracestate", + Baggage => "baggage", + } + } +} diff --git a/src/service/auth.rs b/src/service/auth.rs index 36220f59..1e7c7344 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -5,7 +5,6 @@ use crate::envoy::{ }; use chrono::{DateTime, FixedOffset, Timelike}; use protobuf::well_known_types::Timestamp; -use protobuf::Message; use proxy_wasm::hostcalls; use proxy_wasm::types::MapType; use std::collections::HashMap;