diff --git a/apollo-router/src/test_harness/http_snapshot.rs b/apollo-router/src/test_harness/http_snapshot.rs index 66f67ac95e..526527066e 100644 --- a/apollo-router/src/test_harness/http_snapshot.rs +++ b/apollo-router/src/test_harness/http_snapshot.rs @@ -49,10 +49,17 @@ use std::str::FromStr; use std::sync::Arc; use axum::extract::Path as AxumPath; +use axum::extract::RawQuery; use axum::extract::State; use axum::routing::any; use axum::Router; use base64::Engine; +use http::header::CONNECTION; +use http::header::CONTENT_LENGTH; +use http::header::HOST; +use http::header::TRAILER; +use http::header::TRANSFER_ENCODING; +use http::header::UPGRADE; use http::HeaderMap; use http::HeaderName; use http::HeaderValue; @@ -78,6 +85,16 @@ use crate::services::http::HttpRequest; use crate::services::router; use crate::services::router::body::RouterBody; +/// Headers that will not be passed on to the upstream API +static FILTERED_HEADERS: [HeaderName; 6] = [ + CONNECTION, + TRAILER, + TRANSFER_ENCODING, + UPGRADE, + HOST, + HeaderName::from_static("keep-alive"), +]; + /// An error from the snapshot server #[derive(Debug, thiserror::Error)] enum SnapshotError { @@ -112,26 +129,46 @@ async fn root_handler( State(state): State, req: http::Request, ) -> Result, StatusCode> { - handle(State(state), req, "/".to_string()).await + handle(State(state), req, "/".to_string(), None).await } async fn handler( State(state): State, AxumPath(path): AxumPath, + RawQuery(query): RawQuery, req: http::Request, ) -> Result, StatusCode> { - handle(State(state), req, path).await + handle(State(state), req, path, query).await } async fn handle( State(state): State, req: http::Request, path: String, + query: Option, ) -> Result, StatusCode> { + let path = if let Some(query) = query { + format!("{path}?{query}") + } else { + path + }; let uri = [state.base_url.to_string(), path.clone()].concat(); let method = req.method().clone(); let version = req.version(); - let request_headers = req.headers().clone(); + let request_headers: HeaderMap = req + .headers() + .clone() + .drain() + .filter_map(|(name, value)| { + name.and_then(|name| { + if !FILTERED_HEADERS.contains(&name) { + Some((name, value)) + } else { + None + } + }) + }) + .collect(); let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX) .await .unwrap(); @@ -163,7 +200,7 @@ async fn handle( .uri(uri.clone()) .body(router::body::from_bytes(body_bytes)) .unwrap(); - *request.headers_mut() = request_headers.clone(); + *request.headers_mut() = request_headers; let response = state .client .oneshot(HttpRequest { @@ -206,7 +243,7 @@ async fn handle( ); } } - if let Ok(response) = snapshot.into_body() { + if let Ok(response) = snapshot.into_response() { Ok(response) } else { fail(uri, method, "Unable to convert snapshot into response body") @@ -239,7 +276,7 @@ fn response_from_snapshot( ); snapshot .clone() - .into_body() + .into_response() .map_err(|e| error!("Unable to convert snapshot into HTTP response: {:?}", e)) .ok() }) @@ -442,8 +479,9 @@ struct Snapshot { } impl Snapshot { - fn into_body(self) -> Result, ()> { + fn into_response(self) -> Result, ()> { let mut response = http::Response::builder().status(self.response.status); + let body_string = self.response.body.to_string(); if let Some(headers) = response.headers_mut() { for (name, values) in self.response.headers.into_iter() { if let Ok(name) = HeaderName::from_str(&name.clone()) { @@ -456,8 +494,12 @@ impl Snapshot { warn!("Invalid header name `{}` in snapshot", name); } } + + // Rewrite the content length header to the actual body length. Serializing and + // deserializing the snapshot may result in a different length due to formatting + // differences. + headers.insert(CONTENT_LENGTH, HeaderValue::from(body_string.len())); } - let body_string = self.response.body.to_string(); if let Ok(response) = response.body(router::body::from_bytes(body_string)) { return Ok(response); }