Skip to content

Commit

Permalink
refactor: Update reverse_proxy.yaml with new upstream locations
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Sep 18, 2024
1 parent 1d9a784 commit 12a8e5e
Showing 1 changed file with 47 additions and 34 deletions.
81 changes: 47 additions & 34 deletions rust_http_proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,25 @@ impl ProxyHandler {
let never_ask_for_auth = self.config.never_ask_for_auth;
// 1. serve stage (static files|reverse proxy)
if Method::CONNECT != req.method() {
let scheme_host_port = scheme_host_port(&req, self.config.over_tls);
let req_basic = extract_requst_basic_info(&req, self.config.over_tls);
if let Some(locations) = self
.config
.reverse_proxy_config
.get(&scheme_host_port.1)
.get(&req_basic.host)
.or(self.config.reverse_proxy_config.get(DEFAULT_HOST))
{
if let Some(location_config) = pick_location(req.uri().path(), locations) {
let upstream_req = build_upstream_req(req, location_config)?;
return self
.reverse_proxy(upstream_req, &scheme_host_port, &client_socket_addr)
.await;
info!(
"[reverse proxy] {:^35} => {}{}... ==> [{}] {:?} [{:?}]",
SocketAddrFormat(&client_socket_addr).to_string(),
req_basic,
location_config.location,
upstream_req.method(),
&upstream_req.uri(),
upstream_req.version(),
);
return self.reverse_proxy(upstream_req, &req_basic).await;
}
}
if req.version() == Version::HTTP_2 || req.uri().host().is_none() {
Expand Down Expand Up @@ -331,23 +338,15 @@ impl ProxyHandler {
async fn reverse_proxy(
&self,
upstream_req: Request<Incoming>,
raw_scheme_host_port: &(String, String, u16),
client_socket_addr: &SocketAddr,
req_basic: &ReqBasic,
) -> io::Result<Response<BoxBody<Bytes, io::Error>>> {
info!(
"[reverse proxy] {:^35} ==> [{}] {:?} [{:?}]",
SocketAddrFormat(client_socket_addr).to_string(),
upstream_req.method(),
&upstream_req.uri(),
upstream_req.version()
);
let upstream_req_url = upstream_req.uri().clone();
// debug!("reverse_proxy: {:?}", new_req);
match self.reverse_client.request(upstream_req).await {
Ok(mut resp) => {
handle_redirect(
&mut resp,
raw_scheme_host_port,
req_basic,
upstream_req_url,
&self.redirect_bachpaths,
);
Expand Down Expand Up @@ -424,23 +423,35 @@ fn build_upstream_req(
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))
}

fn scheme_host_port(req: &Request<Incoming>, server_over_tls: bool) -> (String, String, u16) {
struct ReqBasic {
scheme: String,
host: String,
port: u16,
}

impl Display for ReqBasic {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}://{}:{}", self.scheme, self.host, self.port)
}
}

fn extract_requst_basic_info(req: &Request<Incoming>, server_over_tls: bool) -> ReqBasic {
let uri = req.uri();
let scheme = uri.scheme_str().unwrap_or(match server_over_tls {
true => "https",
false => "http",
});
if req.version() == Version::HTTP_2 {
//H2,信息全在uri中
(
scheme.to_owned(),
uri.host().unwrap_or("").to_string(),
uri.port_u16().unwrap_or(match scheme {
ReqBasic {
scheme: scheme.to_owned(),
host: uri.host().unwrap_or("").to_string(),
port: uri.port_u16().unwrap_or(match scheme {
"https" => 443,
"http" => 80,
_ => 443,
}),
)
}
} else {
let mut split = req
.headers()
Expand All @@ -461,13 +472,17 @@ fn scheme_host_port(req: &Request<Incoming>, server_over_tls: bool) -> (String,
"http" => 80,
_ => 443,
});
(scheme.to_owned(), host, port)
ReqBasic {
scheme: scheme.to_owned(),
host,
port,
}
}
}

fn handle_redirect(
resp: &mut Response<Incoming>,
scheme_host_port: &(String, String, u16),
req_basic: &ReqBasic,
upstream_req_uri: Uri,
redirect_bachpaths: &[RedirectBackpaths],
) {
Expand All @@ -477,11 +492,9 @@ fn handle_redirect(
if let Some(absolute_redirect_location) =
ensure_absolute(redirect_location, &upstream_req_uri)
{
if let Some(replacement) = lookup(
scheme_host_port,
absolute_redirect_location,
redirect_bachpaths,
) {
if let Some(replacement) =
lookup(req_basic, absolute_redirect_location, redirect_bachpaths)
{
let _ = mem::replace(
redirect_location,
HeaderValue::from_str(replacement.as_str()).unwrap(),
Expand All @@ -500,27 +513,27 @@ struct RedirectBackpaths {
}

fn lookup(
scheme_host_port: &(String, String, u16),
req_basic: &ReqBasic,
absolute_location: String,
redirect_bachpaths: &[RedirectBackpaths],
) -> Option<String> {
for ele in redirect_bachpaths.iter() {
if absolute_location.starts_with(ele.redirect_url.as_str()) {
info!(
" {:<70} -> {}",
format!("[scheme]://{}:[port]{}", ele.host, ele.location),
"redirect back path for {} is {}",
ele.redirect_url,
format!("[scheme]://{}:[port]{}", ele.host, ele.location),
);
let host = match ele.host.as_str() {
DEFAULT_HOST => &scheme_host_port.1, // 如果是default_host,就用当前host
DEFAULT_HOST => &req_basic.host, // 如果是default_host,就用当前host
other => other,
};
return Some(
scheme_host_port.0.to_owned() // use raw requst's scheme
req_basic.scheme.to_owned() // use raw requst's scheme
+ "://"
+ host // if it's default_host, use raw requst's host
+ ":"
+ &scheme_host_port.2.to_string() // use raw requst's port
+ &req_basic.port.to_string() // use raw requst's port
+ &ele.location
+ &absolute_location[ele.redirect_url.len()..],
);
Expand Down

0 comments on commit 12a8e5e

Please sign in to comment.