Skip to content

Commit

Permalink
Refactor build_upstream_req function to handle HTTP/2 and HTTPS correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
arloor committed Sep 19, 2024
1 parent cfeb6d1 commit a5394e3
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 80 deletions.
6 changes: 3 additions & 3 deletions rust_http_proxy/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl TryFrom<Param> for Config {
}
}

pub(crate) fn load_config() -> Config {
pub(crate) fn load_config() -> Result<Config, DynError> {
let mut param = Param::parse();
param.hostname = get_hostname();
if let Err(log_init_error) = init_log(&param.log_dir, &param.log_file) {
Expand All @@ -174,7 +174,7 @@ pub(crate) fn load_config() -> Config {
let _ = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider().install_default();
}
info!("hostname seems to be {}", param.hostname);
let config = Config::try_from(param).unwrap();
let config = Config::try_from(param)?;
for ele in &config.reverse_proxy_config {
for location_config in ele.1 {
match location_config.upstream.scheme_and_authority.parse::<Uri>() {
Expand Down Expand Up @@ -212,7 +212,7 @@ pub(crate) fn load_config() -> Config {
}
log_config(&config);
info!("auto close connection after idle for {:?}", IDLE_TIMEOUT);
config
Ok(config)
}

fn log_config(config: &Config) {
Expand Down
51 changes: 0 additions & 51 deletions rust_http_proxy/src/macros.rs

This file was deleted.

4 changes: 1 addition & 3 deletions rust_http_proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ mod net_monitor;
mod proxy;
mod tls_helper;
mod web_func;
#[macro_use]
mod macros;
mod address;
mod config;
mod http1_client;
Expand Down Expand Up @@ -54,7 +52,7 @@ static LOCAL_IP: LazyLock<String> = LazyLock::new(|| local_ip().unwrap_or("0.0.0

#[tokio::main]
async fn main() -> Result<(), DynError> {
let proxy_config: Config = load_config();
let proxy_config: Config = load_config()?;
let ports = proxy_config.port.clone();
let proxy_handler = Arc::new(ProxyHandler::new(proxy_config));
#[cfg(feature = "jemalloc")]
Expand Down
8 changes: 5 additions & 3 deletions rust_http_proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl ProxyHandler {
origin_req_basic,
&(upstream_uri_parts.scheme, upstream_uri_parts.authority),
&self.redirect_bachpaths,
);
)?;
Ok(resp.map(|body| {
body.map_err(|e| {
let e = e;
Expand Down Expand Up @@ -551,7 +551,7 @@ fn handle_redirect(
req_basic: &ReqBasic,
upstream_uri_parts: &(Option<Scheme>, Option<Authority>),
redirect_bachpaths: &[RedirectBackpaths],
) {
) -> io::Result<()> {
if resp.status().is_redirection() {
let headers = resp.headers_mut();
if let Some(redirect_location) = headers.get_mut(LOCATION) {
Expand All @@ -563,13 +563,15 @@ fn handle_redirect(
{
let origin = headers.insert(
LOCATION,
HeaderValue::from_str(replacement.as_str()).unwrap(),
HeaderValue::from_str(replacement.as_str())
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?,
);
info!("redirect to [{}], origin is [{:?}]", replacement, origin);
}
}
}
}
Ok(())
}

struct RedirectBackpaths {
Expand Down
56 changes: 36 additions & 20 deletions rust_http_proxy/src/web_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ fn incr_counter_if_need(
if is_outer_view_html && (res.status().is_success() || res.status().is_redirection()) {
http_req_counter
.get_or_create(&LabelImpl::new(ReqLabels {
referer: extract_search_engine_from_referer(referer_header),
referer: extract_search_engine_from_referer(referer_header)
.unwrap_or("parse_failed".to_string()),
path: path.to_string(),
}))
.inc();
Expand All @@ -165,23 +166,23 @@ fn incr_counter_if_need(
}
}

fn extract_search_engine_from_referer(referer: &str) -> String {
fn extract_search_engine_from_referer(referer: &str) -> io::Result<String> {
if let Some(caps) = Regex::new("^https?://(.+?)(/|$)")
.unwrap()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
.captures(referer)
{
let address = caps.get(1).map_or(referer, |g| g.as_str());
if let Some(caps) =
Regex::new("(google|baidu|bing|yandex|v2ex|github|stackoverflow|duckduckgo)")
.unwrap()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
.captures(address)
{
caps.get(1).map_or(address, |g| g.as_str()).to_string()
Ok(caps.get(1).map_or(address, |g| g.as_str()).to_string())
} else {
address.to_owned()
Ok(address.to_owned())
}
} else {
referer.to_string()
Ok(referer.to_string())
}
}

Expand Down Expand Up @@ -385,9 +386,12 @@ fn parse_range(
let mut start = 0;
let mut end = file_size - 1;
if let Some(range_value) = range_header {
let range_value = range_value.to_str().unwrap();
let range_value = range_value
.to_str()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
// 仅支持单个range,不支持多个range
let re = Regex::new(r"^bytes=(\d*)-(\d*)$").unwrap();
let re = Regex::new(r"^bytes=(\d*)-(\d*)$")
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
// 使用正则表达式匹配字符串并捕获组
let caps = re.captures(range_value);
match caps {
Expand All @@ -399,7 +403,9 @@ fn parse_range(
if left.is_empty() {
if !right.is_empty() {
// suffix-length格式,例如bytes=-100
let right = right.parse::<u64>().unwrap();
let right = right
.parse::<u64>()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
if right < file_size {
start = file_size - right;
} else {
Expand All @@ -409,9 +415,13 @@ fn parse_range(
}
} else {
// start-end格式,例如bytes=100-200或bytes=100-
start = left.parse::<u64>().unwrap();
start = left
.parse::<u64>()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
if !right.is_empty() {
end = right.parse::<u64>().unwrap();
end = right
.parse::<u64>()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
}
}
builder = builder
Expand Down Expand Up @@ -600,34 +610,40 @@ mod tests {
#[test]
fn test_extract_domain_from_url() {
assert_eq!(
extract_search_engine_from_referer("https://www.baidu.com/"),
extract_search_engine_from_referer("https://www.baidu.com/")
.unwrap_or("default".to_string()),
"baidu"
);
assert_eq!(
extract_search_engine_from_referer("https://www.baidu.com"),
extract_search_engine_from_referer("https://www.baidu.com")
.unwrap_or("default".to_string()),
"baidu"
);
assert_eq!(
extract_search_engine_from_referer("http://www.baidu.com/"),
extract_search_engine_from_referer("http://www.baidu.com/")
.unwrap_or("default".to_string()),
"baidu"
);
assert_eq!(
extract_search_engine_from_referer("sadasdasdsadas"),
extract_search_engine_from_referer("sadasdasdsadas").unwrap_or("default".to_string()),
"sadasdasdsadas"
);
assert_eq!(
extract_search_engine_from_referer("http://huaiwen.com/baidu.com/bing.com"),
extract_search_engine_from_referer("http://huaiwen.com/baidu.com/bing.com")
.unwrap_or("default".to_string()),
"huaiwen.com"
);
assert_eq!(
extract_search_engine_from_referer("http://huaiwenbaidu.com/baidu.com/bing.com"),
extract_search_engine_from_referer("http://huaiwenbaidu.com/baidu.com/bing.com")
.unwrap_or("default".to_string()),
"baidu"
);
assert_eq!(
extract_search_engine_from_referer("https://www.google.com.hk/"),
extract_search_engine_from_referer("https://www.google.com.hk/")
.unwrap_or("default".to_string()),
"google"
);
assert_eq!(extract_search_engine_from_referer("https://www.bing.com/search?q=google%E6%9C%8D%E5%8A%A1%E4%B8%8B%E8%BD%BD+anzhuo11&qs=ds&form=QBRE"), "bing");
assert_eq!(extract_search_engine_from_referer("https://www.bing.com/search?q=google%E6%9C%8D%E5%8A%A1%E4%B8%8B%E8%BD%BD+anzhuo11&qs=ds&form=QBRE").unwrap_or("default".to_string()), "bing");
}

#[test]
Expand Down

0 comments on commit a5394e3

Please sign in to comment.