Skip to content

Commit

Permalink
Snapshot server fixes (#6706)
Browse files Browse the repository at this point in the history
  • Loading branch information
pubmodmatt authored Jan 30, 2025
1 parent 077f52b commit 1138a16
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions apollo-router/src/test_harness/http_snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,17 @@ use std::str::FromStr;
use std::sync::Arc;

use axum::extract::Path as AxumPath;
use axum::extract::RawQuery;
use axum::extract::State;
use axum::routing::any;
use axum::Router;
use base64::Engine;
use http::header::CONNECTION;
use http::header::CONTENT_LENGTH;
use http::header::HOST;
use http::header::TRAILER;
use http::header::TRANSFER_ENCODING;
use http::header::UPGRADE;
use http::HeaderMap;
use http::HeaderName;
use http::HeaderValue;
Expand All @@ -78,6 +85,16 @@ use crate::services::http::HttpRequest;
use crate::services::router;
use crate::services::router::body::RouterBody;

/// Headers that will not be passed on to the upstream API
static FILTERED_HEADERS: [HeaderName; 6] = [
CONNECTION,
TRAILER,
TRANSFER_ENCODING,
UPGRADE,
HOST,
HeaderName::from_static("keep-alive"),
];

/// An error from the snapshot server
#[derive(Debug, thiserror::Error)]
enum SnapshotError {
Expand Down Expand Up @@ -112,26 +129,46 @@ async fn root_handler(
State(state): State<SnapshotServerState>,
req: http::Request<axum::body::Body>,
) -> Result<http::Response<RouterBody>, StatusCode> {
handle(State(state), req, "/".to_string()).await
handle(State(state), req, "/".to_string(), None).await
}

async fn handler(
State(state): State<SnapshotServerState>,
AxumPath(path): AxumPath<String>,
RawQuery(query): RawQuery,
req: http::Request<axum::body::Body>,
) -> Result<http::Response<RouterBody>, StatusCode> {
handle(State(state), req, path).await
handle(State(state), req, path, query).await
}

async fn handle(
State(state): State<SnapshotServerState>,
req: http::Request<axum::body::Body>,
path: String,
query: Option<String>,
) -> Result<http::Response<RouterBody>, StatusCode> {
let path = if let Some(query) = query {
format!("{path}?{query}")
} else {
path
};
let uri = [state.base_url.to_string(), path.clone()].concat();
let method = req.method().clone();
let version = req.version();
let request_headers = req.headers().clone();
let request_headers: HeaderMap = req
.headers()
.clone()
.drain()
.filter_map(|(name, value)| {
name.and_then(|name| {
if !FILTERED_HEADERS.contains(&name) {
Some((name, value))
} else {
None
}
})
})
.collect();
let body_bytes = axum::body::to_bytes(req.into_body(), usize::MAX)
.await
.unwrap();
Expand Down Expand Up @@ -163,7 +200,7 @@ async fn handle(
.uri(uri.clone())
.body(router::body::from_bytes(body_bytes))
.unwrap();
*request.headers_mut() = request_headers.clone();
*request.headers_mut() = request_headers;
let response = state
.client
.oneshot(HttpRequest {
Expand Down Expand Up @@ -206,7 +243,7 @@ async fn handle(
);
}
}
if let Ok(response) = snapshot.into_body() {
if let Ok(response) = snapshot.into_response() {
Ok(response)
} else {
fail(uri, method, "Unable to convert snapshot into response body")
Expand Down Expand Up @@ -239,7 +276,7 @@ fn response_from_snapshot(
);
snapshot
.clone()
.into_body()
.into_response()
.map_err(|e| error!("Unable to convert snapshot into HTTP response: {:?}", e))
.ok()
})
Expand Down Expand Up @@ -442,8 +479,9 @@ struct Snapshot {
}

impl Snapshot {
fn into_body(self) -> Result<http::Response<RouterBody>, ()> {
fn into_response(self) -> Result<http::Response<RouterBody>, ()> {
let mut response = http::Response::builder().status(self.response.status);
let body_string = self.response.body.to_string();
if let Some(headers) = response.headers_mut() {
for (name, values) in self.response.headers.into_iter() {
if let Ok(name) = HeaderName::from_str(&name.clone()) {
Expand All @@ -456,8 +494,12 @@ impl Snapshot {
warn!("Invalid header name `{}` in snapshot", name);
}
}

// Rewrite the content length header to the actual body length. Serializing and
// deserializing the snapshot may result in a different length due to formatting
// differences.
headers.insert(CONTENT_LENGTH, HeaderValue::from(body_string.len()));
}
let body_string = self.response.body.to_string();
if let Ok(response) = response.body(router::body::from_bytes(body_string)) {
return Ok(response);
}
Expand Down

0 comments on commit 1138a16

Please sign in to comment.