From ad67654f7f8cd5af08ec12cb635c2925dfbdd893 Mon Sep 17 00:00:00 2001 From: XOR-op <17672363+XOR-op@users.noreply.github.com> Date: Tue, 9 Jan 2024 14:03:53 -0500 Subject: [PATCH] feat: better flow control This commit includes sized channel, allowing bounded memory usage and more stable wireguard performance. --- boltconn/src/adapter/wireguard.rs | 11 +++++-- boltconn/src/app.rs | 4 +-- boltconn/src/network/tun_device.rs | 2 +- boltconn/src/transport/smol.rs | 50 ++++++++++++++++++++---------- 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/boltconn/src/adapter/wireguard.rs b/boltconn/src/adapter/wireguard.rs index 5c1401f..8bc5374 100644 --- a/boltconn/src/adapter/wireguard.rs +++ b/boltconn/src/adapter/wireguard.rs @@ -48,7 +48,7 @@ impl Endpoint { // control conn let (stop_send, mut stop_recv) = broadcast::channel(1); - let (mut wg_smol_tx, wg_smol_rx) = flume::unbounded(); + let (mut wg_smol_tx, wg_smol_rx) = flume::bounded(4096); let (smol_wg_tx, mut smol_wg_rx) = flume::unbounded(); let tunnel = Arc::new( WireguardTunnel::new(outbound, config, endpoint_resolver, notify.clone()).await?, @@ -175,10 +175,17 @@ impl Endpoint { immediate_next_loop |= stack_handle.poll_all_udp().await; stack_handle.purge_closed_tcp(); stack_handle.purge_timeout_udp(); + let wait_time = if immediate_next_loop { + Duration::from_secs(0) + } else { + stack_handle + .suggested_wait_time() + .unwrap_or(Duration::from_secs(3)) + }; drop(stack_handle); if !immediate_next_loop { select! { - _ = tokio::time::sleep(Duration::from_secs(3)) =>{} + _ = tokio::time::sleep(wait_time) =>{} _ = notifier.notified() =>{} } } diff --git a/boltconn/src/app.rs b/boltconn/src/app.rs index 4db3482..27ca16f 100644 --- a/boltconn/src/app.rs +++ b/boltconn/src/app.rs @@ -116,8 +116,8 @@ impl App { }; // Create TUN - let (tun_udp_tx, tun_udp_rx) = flume::unbounded(); - let (udp_tun_tx, udp_tun_rx) = flume::unbounded(); + let (tun_udp_tx, tun_udp_rx) = flume::bounded(4096); + let (udp_tun_tx, udp_tun_rx) = flume::bounded(4096); let tun = { let mut tun = TunDevice::open( manager.clone(), diff --git a/boltconn/src/network/tun_device.rs b/boltconn/src/network/tun_device.rs index e9e7ef3..ec36e82 100644 --- a/boltconn/src/network/tun_device.rs +++ b/boltconn/src/network/tun_device.rs @@ -273,7 +273,7 @@ impl TunDevice { let start_offset = 0; pkt.into_bytes_mut().freeze().slice(start_offset..) }; - let _ = self.udp_tx.send(pkt); + let _ = self.udp_tx.send_async(pkt).await; } } IpProtocol::Icmp => { diff --git a/boltconn/src/transport/smol.rs b/boltconn/src/transport/smol.rs index 4643d87..b6ad63e 100644 --- a/boltconn/src/transport/smol.rs +++ b/boltconn/src/transport/smol.rs @@ -57,11 +57,11 @@ impl TcpConnTask { tx: back_tx, rx: mut back_rx, } = connector; - let (tx, rx) = flume::unbounded(); + let (tx, rx) = flume::bounded(4096); // notify smol when new message comes tokio::spawn(async move { while let Some(buf) = back_rx.recv().await { - let _ = tx.send(buf); + let _ = tx.send_async(buf).await; notify.notify_one(); } }); @@ -77,7 +77,7 @@ impl TcpConnTask { pub fn try_send(&mut self, socket: &mut SmolTcpSocket<'_>) -> Result { let mut has_activity = false; // Send data - if socket.can_send() { + while socket.can_send() { if let Some((buf, start)) = self.remain_to_send.take() { if let Ok(sent) = socket.send_slice(&buf.as_ref()[start..]) { // successfully sent @@ -102,7 +102,7 @@ impl TcpConnTask { return Err(SmolError::Aborted); } } - Err(TryRecvError::Empty) => {} + Err(TryRecvError::Empty) => break, Err(TryRecvError::Disconnected) => return Err(SmolError::Disconnected), } } @@ -115,16 +115,17 @@ impl TcpConnTask { pub async fn try_recv(&self, socket: &mut SmolTcpSocket<'_>) -> bool { // Receive data - if socket.can_recv() && self.back_tx.capacity() > 0 { + let mut has_activity = false; + while socket.can_recv() && self.back_tx.capacity() > 0 { let mut buf = BytesMut::with_capacity(MAX_PKT_SIZE); if let Ok(size) = socket.recv_slice(unsafe { mut_buf(&mut buf) }) { unsafe { buf.advance_mut(size) }; // must not fail because there is only 1 sender let _ = self.back_tx.send(buf.freeze()).await; - return true; + has_activity = true; } } - false + has_activity } } @@ -148,7 +149,7 @@ impl UdpConnTask { tx: back_tx, rx: mut back_rx, } = connector; - let (tx, rx) = flume::unbounded(); + let (tx, rx) = flume::bounded(4096); tokio::spawn(async move { while let Some((buf, dst)) = back_rx.recv().await { if let Some(dst) = match dst { @@ -158,7 +159,7 @@ impl UdpConnTask { .await .map(|ip| SocketAddr::new(ip, port)), } { - let _ = tx.send((buf, dst)); + let _ = tx.send_async((buf, dst)).await; notify.notify_one(); } } @@ -175,7 +176,7 @@ impl UdpConnTask { pub fn try_send(&mut self, socket: &mut SmolUdpSocket<'_>) -> Result { let mut has_activity = false; // Send data - if socket.can_send() { + while socket.can_send() { // fetch new data match self.rx.try_recv() { // todo: full-cone NAT @@ -190,7 +191,7 @@ impl UdpConnTask { return Err(SmolError::Aborted); } } - Err(TryRecvError::Empty) => {} + Err(TryRecvError::Empty) => break, Err(TryRecvError::Disconnected) => return Err(SmolError::Disconnected), } } @@ -199,7 +200,8 @@ impl UdpConnTask { pub async fn try_recv(&mut self, socket: &mut SmolUdpSocket<'_>) -> bool { // Receive data - if socket.can_recv() && self.back_tx.capacity() > 0 { + let mut has_activity = false; + while socket.can_recv() && self.back_tx.capacity() > 0 { let mut buf = BytesMut::with_capacity(MAX_PKT_SIZE); if let Ok((size, ep)) = socket.recv_slice(unsafe { mut_buf(&mut buf) }) { unsafe { buf.advance_mut(size) }; @@ -208,11 +210,11 @@ impl UdpConnTask { NetworkAddr::Raw(SocketAddr::new(ep.endpoint.addr.into(), ep.endpoint.port)); // must not fail because there is only 1 sender let _ = self.back_tx.send((buf.freeze(), src_addr)).await; - return true; + has_activity = true; // discard mismatched packet } } - false + has_activity } } @@ -270,6 +272,12 @@ impl SmolStack { ) } + pub fn suggested_wait_time(&mut self) -> Option { + self.iface + .poll_delay(SmolInstant::now(), &self.socket_set) + .map(|du| Duration::from_micros(du.total_micros())) + } + pub fn get_dns(&self) -> Arc> { self.dns.clone() } @@ -546,9 +554,17 @@ impl Device for VirtualIpDevice { } fn transmit(&mut self, _timestamp: SmolInstant) -> Option> { - Some(VirtualTxToken { - sender: self.outbound.clone(), - }) + if self + .outbound + .capacity() + .map_or(self.outbound.len() < 4096, |cap| cap > self.outbound.len()) + { + Some(VirtualTxToken { + sender: self.outbound.clone(), + }) + } else { + None + } } fn capabilities(&self) -> DeviceCapabilities {