Skip to content

Commit

Permalink
feat: better flow control
Browse files Browse the repository at this point in the history
This commit includes sized channel, allowing bounded memory usage and more stable wireguard performance.
  • Loading branch information
XOR-op committed Jan 9, 2024
1 parent ed0f67b commit ad67654
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 22 deletions.
11 changes: 9 additions & 2 deletions boltconn/src/adapter/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl Endpoint {
// control conn
let (stop_send, mut stop_recv) = broadcast::channel(1);

let (mut wg_smol_tx, wg_smol_rx) = flume::unbounded();
let (mut wg_smol_tx, wg_smol_rx) = flume::bounded(4096);
let (smol_wg_tx, mut smol_wg_rx) = flume::unbounded();
let tunnel = Arc::new(
WireguardTunnel::new(outbound, config, endpoint_resolver, notify.clone()).await?,
Expand Down Expand Up @@ -175,10 +175,17 @@ impl Endpoint {
immediate_next_loop |= stack_handle.poll_all_udp().await;
stack_handle.purge_closed_tcp();
stack_handle.purge_timeout_udp();
let wait_time = if immediate_next_loop {
Duration::from_secs(0)
} else {
stack_handle
.suggested_wait_time()
.unwrap_or(Duration::from_secs(3))
};
drop(stack_handle);
if !immediate_next_loop {
select! {
_ = tokio::time::sleep(Duration::from_secs(3)) =>{}
_ = tokio::time::sleep(wait_time) =>{}
_ = notifier.notified() =>{}
}
}
Expand Down
4 changes: 2 additions & 2 deletions boltconn/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ impl App {
};

// Create TUN
let (tun_udp_tx, tun_udp_rx) = flume::unbounded();
let (udp_tun_tx, udp_tun_rx) = flume::unbounded();
let (tun_udp_tx, tun_udp_rx) = flume::bounded(4096);
let (udp_tun_tx, udp_tun_rx) = flume::bounded(4096);
let tun = {
let mut tun = TunDevice::open(
manager.clone(),
Expand Down
2 changes: 1 addition & 1 deletion boltconn/src/network/tun_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl TunDevice {
let start_offset = 0;
pkt.into_bytes_mut().freeze().slice(start_offset..)
};
let _ = self.udp_tx.send(pkt);
let _ = self.udp_tx.send_async(pkt).await;
}
}
IpProtocol::Icmp => {
Expand Down
50 changes: 33 additions & 17 deletions boltconn/src/transport/smol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ impl TcpConnTask {
tx: back_tx,
rx: mut back_rx,
} = connector;
let (tx, rx) = flume::unbounded();
let (tx, rx) = flume::bounded(4096);
// notify smol when new message comes
tokio::spawn(async move {
while let Some(buf) = back_rx.recv().await {
let _ = tx.send(buf);
let _ = tx.send_async(buf).await;
notify.notify_one();
}
});
Expand All @@ -77,7 +77,7 @@ impl TcpConnTask {
pub fn try_send(&mut self, socket: &mut SmolTcpSocket<'_>) -> Result<bool, SmolError> {
let mut has_activity = false;
// Send data
if socket.can_send() {
while socket.can_send() {
if let Some((buf, start)) = self.remain_to_send.take() {
if let Ok(sent) = socket.send_slice(&buf.as_ref()[start..]) {
// successfully sent
Expand All @@ -102,7 +102,7 @@ impl TcpConnTask {
return Err(SmolError::Aborted);
}
}
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => return Err(SmolError::Disconnected),
}
}
Expand All @@ -115,16 +115,17 @@ impl TcpConnTask {

pub async fn try_recv(&self, socket: &mut SmolTcpSocket<'_>) -> bool {
// Receive data
if socket.can_recv() && self.back_tx.capacity() > 0 {
let mut has_activity = false;
while socket.can_recv() && self.back_tx.capacity() > 0 {
let mut buf = BytesMut::with_capacity(MAX_PKT_SIZE);
if let Ok(size) = socket.recv_slice(unsafe { mut_buf(&mut buf) }) {
unsafe { buf.advance_mut(size) };
// must not fail because there is only 1 sender
let _ = self.back_tx.send(buf.freeze()).await;
return true;
has_activity = true;
}
}
false
has_activity
}
}

Expand All @@ -148,7 +149,7 @@ impl UdpConnTask {
tx: back_tx,
rx: mut back_rx,
} = connector;
let (tx, rx) = flume::unbounded();
let (tx, rx) = flume::bounded(4096);
tokio::spawn(async move {
while let Some((buf, dst)) = back_rx.recv().await {
if let Some(dst) = match dst {
Expand All @@ -158,7 +159,7 @@ impl UdpConnTask {
.await
.map(|ip| SocketAddr::new(ip, port)),
} {
let _ = tx.send((buf, dst));
let _ = tx.send_async((buf, dst)).await;
notify.notify_one();
}
}
Expand All @@ -175,7 +176,7 @@ impl UdpConnTask {
pub fn try_send(&mut self, socket: &mut SmolUdpSocket<'_>) -> Result<bool, SmolError> {
let mut has_activity = false;
// Send data
if socket.can_send() {
while socket.can_send() {
// fetch new data
match self.rx.try_recv() {
// todo: full-cone NAT
Expand All @@ -190,7 +191,7 @@ impl UdpConnTask {
return Err(SmolError::Aborted);
}
}
Err(TryRecvError::Empty) => {}
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => return Err(SmolError::Disconnected),
}
}
Expand All @@ -199,7 +200,8 @@ impl UdpConnTask {

pub async fn try_recv(&mut self, socket: &mut SmolUdpSocket<'_>) -> bool {
// Receive data
if socket.can_recv() && self.back_tx.capacity() > 0 {
let mut has_activity = false;
while socket.can_recv() && self.back_tx.capacity() > 0 {
let mut buf = BytesMut::with_capacity(MAX_PKT_SIZE);
if let Ok((size, ep)) = socket.recv_slice(unsafe { mut_buf(&mut buf) }) {
unsafe { buf.advance_mut(size) };
Expand All @@ -208,11 +210,11 @@ impl UdpConnTask {
NetworkAddr::Raw(SocketAddr::new(ep.endpoint.addr.into(), ep.endpoint.port));
// must not fail because there is only 1 sender
let _ = self.back_tx.send((buf.freeze(), src_addr)).await;
return true;
has_activity = true;
// discard mismatched packet
}
}
false
has_activity
}
}

Expand Down Expand Up @@ -270,6 +272,12 @@ impl SmolStack {
)
}

pub fn suggested_wait_time(&mut self) -> Option<Duration> {
self.iface
.poll_delay(SmolInstant::now(), &self.socket_set)
.map(|du| Duration::from_micros(du.total_micros()))
}

pub fn get_dns(&self) -> Arc<GenericDns<SmolDnsProvider>> {
self.dns.clone()
}
Expand Down Expand Up @@ -546,9 +554,17 @@ impl Device for VirtualIpDevice {
}

fn transmit(&mut self, _timestamp: SmolInstant) -> Option<Self::TxToken<'_>> {
Some(VirtualTxToken {
sender: self.outbound.clone(),
})
if self
.outbound
.capacity()
.map_or(self.outbound.len() < 4096, |cap| cap > self.outbound.len())
{
Some(VirtualTxToken {
sender: self.outbound.clone(),
})
} else {
None
}
}

fn capabilities(&self) -> DeviceCapabilities {
Expand Down

0 comments on commit ad67654

Please sign in to comment.