diff --git a/Cargo.lock b/Cargo.lock index 1ff0afd9..4e882d88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1324,6 +1324,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "ryu" version = "1.0.11" @@ -1460,6 +1466,25 @@ dependencies = [ "syn 1.0.99", ] +[[package]] +name = "strum" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" + +[[package]] +name = "strum_macros" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.18", +] + [[package]] name = "syn" version = "1.0.99" @@ -1684,6 +1709,8 @@ dependencies = [ "serde", "serde_json", "serial_test", + "strum", + "strum_macros", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index a06794c1..ed23b588 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ protobuf = { version = "2.27", features = ["with-serde"] } thiserror = "1.0" regex = "1" radix_trie = "0.2.1" +strum = "0.26.2" +strum_macros = "0.26.2" [dev-dependencies] proxy-wasm-test-framework = { git = "https://github.com/Kuadrant/wasm-test-framework.git", branch = "kuadrant" } diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 4502a24d..0f884dc5 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -10,23 +10,38 @@ use crate::utils::tokenize_with_escaping; use log::{debug, info, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; -use proxy_wasm::types::Action; +use proxy_wasm::types::{Action, Bytes}; use std::rc::Rc; use std::time::Duration; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; // tracing headers -const TRACEPARENT_HEADER: &str = "traceparent"; -const TRACESTATE_HEADER: &str = "tracestate"; -const BAGGAGE_HEADER: &str = "baggage"; +#[derive(EnumIter)] +pub enum TracingHeader { + Traceparent, + Tracestate, + Baggage, +} + +impl TracingHeader { + fn as_str(&self) -> &'static str { + match self { + TracingHeader::Traceparent => "traceparent", + TracingHeader::Tracestate => "tracestate", + TracingHeader::Baggage => "baggage", + } + } +} pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub tracing_headers: Vec<(String, String)>, + pub tracing_headers: Vec<(TracingHeader, Bytes)>, } impl Filter { @@ -60,7 +75,7 @@ impl Filter { let rl_tracing_headers = self .tracing_headers .iter() - .map(|(header, value)| (header.as_str(), value.as_bytes())) + .map(|(header, value)| (header.as_str(), value.as_slice())) .collect(); match self.dispatch_grpc_call( @@ -219,13 +234,10 @@ impl HttpContext for Filter { fn on_http_request_headers(&mut self, _: usize, _: bool) -> Action { info!("on_http_request_headers #{}", self.context_id); - let req_headers = self.get_http_request_headers(); - for (header, value) in req_headers.iter() { - match header.to_lowercase().as_str() { - TRACEPARENT_HEADER | TRACESTATE_HEADER | BAGGAGE_HEADER => { - self.tracing_headers.push((header.clone(), value.clone())) - } - _ => (), + for header in TracingHeader::iter() { + match self.get_http_request_header_bytes(header.as_str()) { + Some(value) => self.tracing_headers.push((header, value)), + None => (), } } match self