From 131d32d94ca5ed8bf569d6149683252313adc213 Mon Sep 17 00:00:00 2001 From: neevek Date: Sun, 12 May 2024 19:45:39 +0800 Subject: [PATCH] support timeout and close fd in time to avoid FIN_WAIT2 --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/lib.rs | 4 +- src/server.rs | 182 ++++++++++++++++++++++++++++++++++++++------------ 4 files changed, 145 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ed598dc..071badf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1089,7 +1089,7 @@ dependencies = [ [[package]] name = "omnip" -version = "0.4.20" +version = "0.4.21" dependencies = [ "android_logger", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 0d05bd7..2249c83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "omnip" -version = "0.4.20" +version = "0.4.21" edition = "2021" [lib] diff --git a/src/lib.rs b/src/lib.rs index bd3009b..fd83164 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ pub enum ProxyError { IPv6NotSupported, // not supported by Socks4 InternalError, BadRequest, + Timeout, PayloadTooLarge, BadGateway(anyhow::Error), Disconnected(anyhow::Error), @@ -559,8 +560,7 @@ pub mod android { proxy_rules_file, jthreads as usize, false, - true, - jtcpNoDelay as bool, + jtcpNoDelay != 0, ) { Ok(config) => config, Err(e) => { diff --git a/src/server.rs b/src/server.rs index 4ecf9a6..b7e8657 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,7 +14,7 @@ use crate::{ QuicServerConfig, }; use anyhow::{anyhow, bail, Context, Result}; -use log::{debug, error, info}; +use log::{debug, error, info, warn}; use notify::event::ModifyKind; use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; use rs_utilities::dns::{ @@ -28,11 +28,14 @@ use std::path::Path; use std::str::{self, FromStr}; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::task::JoinHandle; +use tokio::time::error::Elapsed; use tokio::{ io::AsyncReadExt, + io::AsyncWriteExt, net::{TcpListener, TcpStream}, }; @@ -344,7 +347,7 @@ impl Server { loop { match proxy_listener.accept().await { - Ok((inbound_stream, addr)) => { + Ok((inbound_stream, _addr)) => { let psp = psp.clone(); let (prefer_upstream, upstream, dns_resolver) = copy_inner_state!(self, prefer_upstream, upstream, dns_resolver); @@ -373,26 +376,31 @@ impl Server { .await .ok(); } - return; - } - - Self::process_stream( - inbound_stream, - psp, - upstream, - prefer_upstream, - dns_resolver.unwrap(), - ) - .await - .map_err(|e| match e { - ProxyError::BadRequest | ProxyError::BadGateway(_) => { - error!("err: {e:?}"); + } else { + match Self::process_stream( + inbound_stream, + psp, + upstream, + prefer_upstream, + dns_resolver.unwrap(), + ) + .await + { + Ok(()) => {} + Err(ProxyError::BadRequest) => { + error!("BadRequest"); + } + Err(ProxyError::BadGateway(e)) => { + error!("BadGateway: {e:?}"); + } + Err(ProxyError::Timeout) => { + error!("Timeout"); + } + Err(e) => { + error!("generic error: {e:?}"); + } } - _ => {} - }) - .ok(); - - debug!("connection closed: {addr}"); + } }); } @@ -466,7 +474,7 @@ impl Server { tokio::time::timeout(Duration::from_secs(2), inbound_stream.read(&mut buffer)) .await .map_err(|_| ProxyError::BadRequest)? - .map_err(|_| ProxyError::BadRequest)?; + .map_err(|_| ProxyError::Timeout)?; if proxy_handler.is_none() { proxy_handler = Some(Self::create_proxy_handler( @@ -527,8 +535,8 @@ impl Server { }; debug!( - "forward payload to next proxy({:?}), {addr} -> {:?}", - outbound_type, upstream + "forward payload to next proxy({outbound_type:?}), {addr} -> {}", + upstream.unwrap() ); outbound_stream = @@ -620,6 +628,10 @@ impl Server { } None => { + warn!( + "failed to create outbound stream for: {addr} from {}", + inbound_stream.peer_addr().unwrap() + ); proxy_handler .handle_outbound_failure(&mut inbound_stream) .await?; @@ -682,31 +694,119 @@ impl Server { } async fn start_stream_transfer( - mut a_stream: TcpStream, - mut b_stream: TcpStream, + mut inbound_stream: TcpStream, + mut outbound_stream: TcpStream, stats_sender: &Sender, ) -> Result { stats_sender.send(ServerStats::NewConnection).await.ok(); - let result = match tokio::io::copy_bidirectional(&mut a_stream, &mut b_stream).await { - Ok((tx_bytes, rx_bytes)) => { - debug!( - "transfer, out:{tx_bytes}, in:{rx_bytes}, {} <-> {:?}", - a_stream.local_addr().unwrap(), - b_stream.local_addr().unwrap(), - ); - stats_sender - .send(ServerStats::Traffic(ProxyTraffic { tx_bytes, rx_bytes })) - .await - .ok(); + const BUFFER_SIZE: usize = 4096; + let mut inbound_buffer = [0u8; BUFFER_SIZE]; + let mut outbound_buffer = [0u8; BUFFER_SIZE]; + let (mut inbound_reader, mut inbound_writer) = inbound_stream.split(); + let (mut outbound_reader, mut outbound_writer) = outbound_stream.split(); + + let mut tx_bytes = 0u64; + let mut rx_bytes = 0u64; + let mut inbound_stream_eos = false; + let mut outbound_stream_eos = false; + let mut loop_count = 0; + + loop { + loop_count += 1; + let result = if !inbound_stream_eos && !outbound_stream_eos { + tokio::select! { + result = Self::transfer_data_with_timeout( + &mut inbound_reader, + &mut outbound_writer, + &mut inbound_buffer, + &mut tx_bytes, + &mut inbound_stream_eos) => result, + result = Self::transfer_data_with_timeout( + &mut outbound_reader, + &mut inbound_writer, + &mut outbound_buffer, + &mut rx_bytes, + &mut outbound_stream_eos) => result, + } + } else if !outbound_stream_eos { + Self::transfer_data_with_timeout( + &mut outbound_reader, + &mut inbound_writer, + &mut outbound_buffer, + &mut rx_bytes, + &mut inbound_stream_eos, + ) + .await + } else { + Self::transfer_data_with_timeout( + &mut inbound_reader, + &mut outbound_writer, + &mut inbound_buffer, + &mut tx_bytes, + &mut outbound_stream_eos, + ) + .await + }; - Ok(ProxyTraffic { rx_bytes, tx_bytes }) + match result { + Err(ProxyError::Timeout) | Ok(0) => { + if outbound_stream_eos || inbound_stream_eos { + break; + } + } + Err(_) => break, + Ok(_) => {} } - Err(e) => Err(ProxyError::Disconnected(anyhow!(e))), - }; + } + + debug!( + "data [{:<8}] = ↑ {:<10} ↓ {:<10} {} ↔ {}", + loop_count, + tx_bytes, + rx_bytes, + inbound_stream.peer_addr().unwrap(), + outbound_stream.local_addr().unwrap(), + ); + + stats_sender + .send(ServerStats::Traffic(ProxyTraffic { tx_bytes, rx_bytes })) + .await + .ok(); stats_sender.send(ServerStats::CloseConnection).await.ok(); - result + Ok(ProxyTraffic { rx_bytes, tx_bytes }) + } + + async fn transfer_data_with_timeout( + reader: &mut R, + writer: &mut W, + buffer: &mut [u8], + out_bytes: &mut u64, + eos_flag: &mut bool, + ) -> Result + where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, + { + match tokio::time::timeout(Duration::from_secs(15), reader.read(buffer)) + .await + .map_err(|_: Elapsed| ProxyError::Timeout)? + { + Ok(0) => { + *eos_flag = true; + Ok(0) + } + Ok(n) => { + *out_bytes += n as u64; + writer + .write_all(&buffer[..n]) + .await + .map_err(|_| ProxyError::InternalError)?; + Ok(n) + } + Err(_) => Err(ProxyError::InternalError), // Connection mostly reset by peer + } } fn collect_and_report_server_stats(&self, mut stats_receiver: Receiver) {