Skip to content

Commit

Permalink
set nodelay for tcp streams
Browse files Browse the repository at this point in the history
  • Loading branch information
neevek committed May 12, 2024
1 parent 9631e27 commit bac2770
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 37 deletions.
5 changes: 5 additions & 0 deletions src/bin/omnip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ fn main() -> Result<()> {
args.proxy_rules_file,
args.threads,
args.watch_proxy_rules_change,
args.tcp_nodelay,
)?;

let common_quic_config = CommonQuicConfig {
Expand Down Expand Up @@ -98,6 +99,10 @@ struct OmnipArgs {
#[arg(short = 'R', long, default_value = "5000")]
retry_interval_ms: u64,

/// set TCP_NODELAY
#[arg(long, action)]
tcp_nodelay: bool,

/// reload proxy rules if updated
#[arg(short = 'w', long, action)]
watch_proxy_rules_change: bool,
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ pub struct Config {
pub dot_server: String,
pub name_servers: String,
pub watch_proxy_rules_change: bool,
pub tcp_nodelay: bool,
}

#[derive(Debug)]
Expand Down Expand Up @@ -427,6 +428,7 @@ pub fn create_config(
proxy_rules_file: String,
threads: usize,
watch_proxy_rules_change: bool,
tcp_nodelay: bool,
) -> Result<Config> {
let (server_type, orig_server_addr, is_layered_proto) = parse_server_addr(addr.as_str());

Expand Down Expand Up @@ -476,6 +478,7 @@ pub fn create_config(
dot_server,
name_servers,
watch_proxy_rules_change,
tcp_nodelay,
})
}

Expand Down Expand Up @@ -532,6 +535,7 @@ pub mod android {
jmaxIdleTimeoutMs: jint,
jretryIntervalMs: jint,
jthreads: jint,
jtcpNoDelay: jboolean,
) -> jlong {
if jaddr.is_null() {
return 0;
Expand All @@ -555,6 +559,8 @@ pub mod android {
proxy_rules_file,
jthreads as usize,
false,
true,
jtcpNoDelay as bool,
) {
Ok(config) => config,
Err(e) => {
Expand Down
90 changes: 53 additions & 37 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ impl Server {
}

async fn run_internal(self: &mut Arc<Self>) -> Result<()> {
info!("tcp_nodelay:{}", self.config.tcp_nodelay);
self.set_and_post_server_state(ServerState::Preparing);

// start the dashboard server
Expand Down Expand Up @@ -338,6 +339,7 @@ impl Server {
dashboard_addr,
proxy_rule_manager: inner_state!(self, proxy_rule_manager).clone(),
stats_sender,
tcp_nodelay: self.config.tcp_nodelay,
});

loop {
Expand All @@ -346,31 +348,31 @@ impl Server {
let psp = psp.clone();
let (prefer_upstream, upstream, dns_resolver) =
copy_inner_state!(self, prefer_upstream, upstream, dns_resolver);
if psp.tcp_nodelay {
inbound_stream
.set_nodelay(true)
.map_err(|e| error!("failed to call set_nodelay: {e}"))
.ok();
}

tokio::spawn(async move {
if let Some(ProtoType::Tcp) = psp.server_type {
if upstream.is_none() {
error!("tcp connection requires an upstream");
return;
}

match TcpStream::connect(upstream.unwrap()).await {
Ok(outbound_stream) => {
Self::start_stream_transfer(
inbound_stream,
outbound_stream,
&psp.stats_sender,
)
.await
.ok();
}
Err(e) => {
error!(
"failed to connect to upstream: {}, err: {e}",
upstream.unwrap()
);
}
};

if let Some(outbound_stream) =
Self::create_tcp_stream(upstream.unwrap(), psp.tcp_nodelay).await
{
Self::start_stream_transfer(
inbound_stream,
outbound_stream,
&psp.stats_sender,
)
.await
.ok();
}
return;
}

Expand Down Expand Up @@ -529,16 +531,8 @@ impl Server {
outbound_type, upstream
);

outbound_stream = match TcpStream::connect(upstream.unwrap()).await {
Ok(stream) => Some(stream),
Err(e) => {
error!(
"failed to connect to upstream: {}, err: {e}",
upstream.unwrap()
);
None
}
}
outbound_stream =
Self::create_tcp_stream(upstream.unwrap(), params.tcp_nodelay).await;
}

_ => {}
Expand Down Expand Up @@ -571,13 +565,17 @@ impl Server {
outbound_stream = Self::connect_to_dashboard(
params.dashboard_addr,
&inbound_stream,
params.tcp_nodelay,
)
.await?;
break;
}

let stream =
Self::create_tcp_stream(SocketAddr::new(ip, addr.port)).await;
let stream = Self::create_tcp_stream(
SocketAddr::new(ip, addr.port),
params.tcp_nodelay,
)
.await;
if stream.is_some() {
outbound_stream = stream;
break;
Expand All @@ -590,10 +588,18 @@ impl Server {
Host::IP(ip) => {
let inbound_addr = inbound_stream.local_addr().unwrap();
if ip == &inbound_addr.ip() && addr.port == inbound_addr.port() {
Self::connect_to_dashboard(params.dashboard_addr, &inbound_stream)
.await?
Self::connect_to_dashboard(
params.dashboard_addr,
&inbound_stream,
params.tcp_nodelay,
)
.await?
} else {
Self::create_tcp_stream(addr.to_socket_addr().unwrap()).await
Self::create_tcp_stream(
addr.to_socket_addr().unwrap(),
params.tcp_nodelay,
)
.await
}
}
};
Expand Down Expand Up @@ -628,11 +634,12 @@ impl Server {
async fn connect_to_dashboard(
dashboard_addr: Option<SocketAddr>,
inbound_stream: &TcpStream,
tcp_nodelay: bool,
) -> Result<Option<TcpStream>, ProxyError> {
match dashboard_addr {
Some(addr) => {
debug!("dashboard request: {}", inbound_stream.peer_addr().unwrap());
Ok(Self::create_tcp_stream(addr).await)
Ok(Self::create_tcp_stream(addr, tcp_nodelay).await)
}
None => {
log::warn!(
Expand All @@ -644,19 +651,27 @@ impl Server {
}
}

async fn create_tcp_stream(addr: SocketAddr) -> Option<TcpStream> {
async fn create_tcp_stream(addr: SocketAddr, nodelay: bool) -> Option<TcpStream> {
if addr.ip().is_unspecified() {
error!("address is unspecified: {addr}");
return None;
}

TcpStream::connect(addr)
let stream = TcpStream::connect(addr)
.await
.map_err(|e| {
error!("failed to connect to address: {addr}, err: {e}");
e
})
.ok()
.ok()?;

if nodelay {
stream
.set_nodelay(true)
.map_err(|e| error!("failed to call set_nodelay: {e}"))
.ok();
}
Some(stream)
}

fn server_type_as_string(&self) -> String {
Expand Down Expand Up @@ -974,4 +989,5 @@ struct ProxySupportParams {
dashboard_addr: Option<SocketAddr>,
proxy_rule_manager: Option<ProxyRuleManager>,
stats_sender: Sender<ServerStats>,
tcp_nodelay: bool,
}

0 comments on commit bac2770

Please sign in to comment.