Skip to content

Commit

Permalink
fix abs uri priority in other request extractor as well
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Sep 16, 2024
1 parent 288fe7e commit 326d6c0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
13 changes: 10 additions & 3 deletions rama-http/src/service/web/endpoint/extract/authority.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ mod tests {
use crate::{Body, HeaderName, Request};
use rama_core::Service;

async fn test_authority_from_request(authority: &str, headers: Vec<(&HeaderName, &str)>) {
async fn test_authority_from_request(
uri: &str,
authority: &str,
headers: Vec<(&HeaderName, &str)>,
) {
let svc = GetForwardedHeadersService::x_forwarded_host(
WebService::default().get("/", |Authority(authority): Authority| async move {
authority.to_string()
}),
);

let mut builder = Request::builder().method("GET").uri("http://example.com/");
let mut builder = Request::builder().method("GET").uri(uri);
for (header, value) in headers {
builder = builder.header(header, value);
}
Expand All @@ -71,6 +75,7 @@ mod tests {
#[tokio::test]
async fn host_header() {
test_authority_from_request(
"/",
"some-domain:123",
vec![(&http::header::HOST, "some-domain:123")],
)
Expand All @@ -80,6 +85,7 @@ mod tests {
#[tokio::test]
async fn x_forwarded_host_header() {
test_authority_from_request(
"/",
"some-domain:456",
vec![(&X_FORWARDED_HOST, "some-domain:456")],
)
Expand All @@ -89,6 +95,7 @@ mod tests {
#[tokio::test]
async fn x_forwarded_host_precedence_over_host_header() {
test_authority_from_request(
"/",
"some-domain:456",
vec![
(&X_FORWARDED_HOST, "some-domain:456"),
Expand All @@ -100,6 +107,6 @@ mod tests {

#[tokio::test]
async fn uri_host() {
test_authority_from_request("example.com:80", vec![]).await;
test_authority_from_request("http://example.com", "example.com:80", vec![]).await;
}
}
15 changes: 11 additions & 4 deletions rama-http/src/service/web/endpoint/extract/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ mod tests {
use crate::{Body, HeaderName, Request};
use rama_core::Service;

async fn test_host_from_request(host: &str, headers: Vec<(&HeaderName, &str)>) {
async fn test_host_from_request(uri: &str, host: &str, headers: Vec<(&HeaderName, &str)>) {
let svc = GetForwardedHeadersService::x_forwarded_host(
WebService::default().get("/", |Host(host): Host| async move { host.to_string() }),
);

let mut builder = Request::builder().method("GET").uri("http://example.com/");
let mut builder = Request::builder().method("GET").uri(uri);
for (header, value) in headers {
builder = builder.header(header, value);
}
Expand All @@ -70,6 +70,7 @@ mod tests {
#[tokio::test]
async fn host_header() {
test_host_from_request(
"/",
"some-domain",
vec![(&http::header::HOST, "some-domain:123")],
)
Expand All @@ -78,12 +79,18 @@ mod tests {

#[tokio::test]
async fn x_forwarded_host_header() {
test_host_from_request("some-domain", vec![(&X_FORWARDED_HOST, "some-domain:456")]).await;
test_host_from_request(
"/",
"some-domain",
vec![(&X_FORWARDED_HOST, "some-domain:456")],
)
.await;
}

#[tokio::test]
async fn x_forwarded_host_precedence_over_host_header() {
test_host_from_request(
"/",
"some-domain",
vec![
(&X_FORWARDED_HOST, "some-domain:456"),
Expand All @@ -95,6 +102,6 @@ mod tests {

#[tokio::test]
async fn uri_host() {
test_host_from_request("example.com", vec![]).await;
test_host_from_request("http://example.com", "example.com", vec![]).await;
}
}
32 changes: 15 additions & 17 deletions rama-net/src/http/request_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,16 @@ impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {
let default_port = uri.port_u16().unwrap_or_else(|| protocol.default_port());
tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");

let authority = ctx
.get::<Forwarded>()
.and_then(|f| {
f.client_host().map(|fauth| {
let (host, port) = fauth.clone().into_parts();
let port = port.unwrap_or(default_port);
(host, port).into()
let authority = uri
.host()
.and_then(|h| Host::try_from(h).ok().map(|h| (h, default_port).into()))
.or_else(|| {
ctx.get::<Forwarded>().and_then(|f| {
f.client_host().map(|fauth| {
let (host, port) = fauth.clone().into_parts();
let port = port.unwrap_or(default_port);
(host, port).into()
})
})
})
.or_else(|| {
Expand All @@ -127,10 +130,6 @@ impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {
.ok()
})
})
.or_else(|| {
uri.host()
.and_then(|h| Host::try_from(h).ok().map(|h| (h, default_port).into()))
})
.ok_or_else(|| {
OpaqueError::from_display(
"RequestContext: no authourity found in http::request::Parts",
Expand Down Expand Up @@ -163,15 +162,14 @@ impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {

#[allow(clippy::unnecessary_lazy_evaluations)]
fn protocol_from_uri_or_context<State>(ctx: &Context<State>, uri: &Uri) -> Protocol {
ctx.get::<Forwarded>()
uri.scheme().map(|s| {
tracing::trace!(uri = %uri, "request context: detected protocol from scheme");
s.into()
}).or_else(|| ctx.get::<Forwarded>()
.and_then(|f| f.client_proto().map(|p| {
tracing::trace!(uri = %uri, "request context: detected protocol from forwarded client proto");
p.into()
}))
.or_else(|| uri.scheme().map(|s| {
tracing::trace!(uri = %uri, "request context: detected protocol from scheme");
s.into()
}))
})))
.or_else(|| {
#[cfg(feature = "tls")]
{
Expand Down

0 comments on commit 326d6c0

Please sign in to comment.