Skip to content

Commit

Permalink
Refactor reverse_proxy function to use upstream request builder
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Sep 17, 2024
1 parent ee659e4 commit 33b15c2
Showing 1 changed file with 46 additions and 50 deletions.
96 changes: 46 additions & 50 deletions rust_http_proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ impl ProxyHandler {
.or(self.config.reverse_proxy_config.get(DEFAULT_HOST))
{
if let Some(location_config) = pick_location(req.uri().path(), locations) {
let upstream_req = build_upstream_req(req, location_config)?;
return self
.reverse_proxy(req, &scheme_host_port, location_config, &client_socket_addr)
.reverse_proxy(upstream_req, &scheme_host_port, &client_socket_addr)
.await;
}
}
Expand Down Expand Up @@ -329,38 +330,13 @@ impl ProxyHandler {

async fn reverse_proxy(
&self,
req: Request<Incoming>,
scheme_host_port: &(String, String, u16),
location_config: &LocationConfig,
upstream_req: Request<Incoming>,
raw_scheme_host_port: &(String, String, u16),
client_socket_addr: &SocketAddr,
) -> Result<Response<BoxBody<Bytes, io::Error>>, io::Error> {
let mut upstream_req_builder = new_upstream_req_builder(&req, location_config);
let header_map = match upstream_req_builder.headers_mut() {
Some(header_map) => header_map,
None => {
return Err(io::Error::new(
ErrorKind::InvalidData,
"new_req.headers_mut() is None, which means error occurs in new request build. Check URL, method, version...",
));
}
};
for ele in req.headers() {
if ele.0 != header::HOST {
header_map.append(ele.0.clone(), ele.1.clone());
} else {
info!("remove host header: {:?}", ele.1);
}
}
let upstream_req = upstream_req_builder
.body(req.into_body())
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
) -> io::Result<Response<BoxBody<Bytes, io::Error>>> {
info!(
"[reverse proxy] {:^35} ---> {} ---> [{}] {:?} [{:?}]",
"[reverse proxy] {:^35} ==> [{}] {:?} [{:?}]",
SocketAddrFormat(client_socket_addr).to_string(),
format!(
"{}://{}:{}",
scheme_host_port.0, scheme_host_port.1, scheme_host_port.2
),
upstream_req.method(),
&upstream_req.uri(),
upstream_req.version()
Expand All @@ -371,7 +347,7 @@ impl ProxyHandler {
Ok(mut resp) => {
handle_redirect(
&mut resp,
scheme_host_port,
raw_scheme_host_port,
upstream_req_url,
&self.redirect_bachpaths,
);
Expand All @@ -391,10 +367,10 @@ impl ProxyHandler {
}
}

fn new_upstream_req_builder(
req: &Request<Incoming>,
fn build_upstream_req(
req: Request<Incoming>,
location_config: &LocationConfig,
) -> http::request::Builder {
) -> io::Result<Request<Incoming>> {
let method = req.method().clone();
let path_and_query = match req.uri().path_and_query() {
Some(path_and_query) => path_and_query.as_str(),
Expand All @@ -408,22 +384,42 @@ fn new_upstream_req_builder(
path_and_query
);

Request::builder()
.method(method)
.uri(url.clone())
.version(if !url.starts_with("https:") {
match location_config.upstream.version {
reverse::Version::H1 => Version::HTTP_11,
reverse::Version::H2 => Version::HTTP_2,
reverse::Version::Auto => Version::HTTP_11,
}
let mut builder =
Request::builder()
.method(method)
.uri(url.clone())
.version(if !url.starts_with("https:") {
match location_config.upstream.version {
reverse::Version::H1 => Version::HTTP_11,
reverse::Version::H2 => Version::HTTP_2,
reverse::Version::Auto => Version::HTTP_11,
}
} else {
match location_config.upstream.version {
reverse::Version::H1 => Version::HTTP_11,
reverse::Version::H2 => Version::HTTP_2,
reverse::Version::Auto => req.version(),
}
});
let header_map = match builder.headers_mut() {
Some(header_map) => header_map,
None => {
return Err(io::Error::new(
ErrorKind::InvalidData,
"new_req.headers_mut() is None, which means error occurs in new request build. Check URL, method, version...",
));
}
};
for ele in req.headers() {
if ele.0 != header::HOST {
header_map.append(ele.0.clone(), ele.1.clone());
} else {
match location_config.upstream.version {
reverse::Version::H1 => Version::HTTP_11,
reverse::Version::H2 => Version::HTTP_2,
reverse::Version::Auto => req.version(),
}
})
info!("remove host header: {:?}", ele.1);
}
}
builder
.body(req.into_body())
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))
}

fn scheme_host_port(req: &Request<Incoming>, server_over_tls: bool) -> (String, String, u16) {
Expand Down Expand Up @@ -509,7 +505,7 @@ fn lookup(
) -> Option<String> {
for ele in redirect_bachpaths.iter() {
if absolute_location.starts_with(ele.redirect_url.as_str()) {
debug!(
info!(
" {:<70} -> {}",
format!("[scheme]://{}:[port]{}", ele.host, ele.location),
ele.redirect_url,
Expand Down

0 comments on commit 33b15c2

Please sign in to comment.