Skip to content

Commit

Permalink
Avoid closing bridges/sockets too early.
Browse files Browse the repository at this point in the history
Ensure that local addresses don't cause collisions.
  • Loading branch information
zlogic committed Jul 25, 2024
1 parent 51ef347 commit 7eb61a0
Showing 1 changed file with 70 additions and 61 deletions.
131 changes: 70 additions & 61 deletions src/network.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
error, fmt, io,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
Expand Down Expand Up @@ -27,6 +27,7 @@ pub struct Network<'a> {
bridges: HashMap<iface::SocketHandle, SocketTunnel>,
opening_connections:
HashMap<iface::SocketHandle, Option<oneshot::Sender<SocketConnectionResult>>>,
unique_ports: HashMap<UniqueConnection, iface::SocketHandle>,
canceled: bool,
cmd_sender: mpsc::Sender<Command>,
cmd_receiver: mpsc::Receiver<Command>,
Expand Down Expand Up @@ -60,6 +61,7 @@ impl Network<'_> {
sockets,
bridges: HashMap::new(),
opening_connections: HashMap::new(),
unique_ports: HashMap::new(),
canceled: false,
cmd_sender,
cmd_receiver,
Expand All @@ -68,6 +70,7 @@ impl Network<'_> {

pub async fn run(&mut self) -> Result<(), NetworkError> {
while !self.canceled {
self.device.process_keepalive().await?;
self.copy_all_data();
let timestamp = smoltcp::time::Instant::now();
self.iface
Expand All @@ -76,7 +79,6 @@ impl Network<'_> {
Some(poll_delay) => Duration::from_micros(poll_delay.total_micros()),
None => MAX_POLL_INTERVAL,
};
self.device.process_keepalive().await?;
let timeout = if self.device.send_data().await? > 0 {
Duration::from_millis(0)
} else {
Expand All @@ -97,64 +99,54 @@ impl Network<'_> {
// doesn't work well with smoltcp - as smoltcp keeps ownership of most of its data, and any writes
// need to be guarded.
use socket::tcp;
self.bridges.iter().for_each(|(handle, tunnel)| {
let socket = self.sockets.get_mut::<tcp::Socket>(*handle);
if socket.can_send() {
let result = socket.send(|dest| match tunnel.reader.try_read(dest) {
Ok(bytes) => {
if bytes > 0 && dest.len() > 0 {
(bytes, Ok::<(), NetworkError>(()))
} else {
// Zero bytes means the stream is closed.
// TODO: add a custom handler for this error.
(0, Err("Proxy reader is closed".into()))
let closed_bridges = self
.bridges
.iter()
.filter_map(|(handle, tunnel)| {
let socket = self.sockets.get_mut::<tcp::Socket>(*handle);
if socket.can_send() {
let result = socket.send(|dest| match tunnel.reader.try_read(dest) {
Ok(bytes) => (bytes, Ok::<(), NetworkError>(())),
Err(err) => match err.kind() {
io::ErrorKind::WouldBlock => (0, Ok(())),
_ => (0, Err(err.into())),
},
});
if let Ok(result) = result {
if let Err(err) = result {
debug!("Failed to read data from SOCKS socket: {}", err);
return Some(handle.to_owned());
}
} else if let Err(err) = result {
// Not critical if socket is still opening.
warn!("Failed to send data to virtual socket: {}", err);
}
Err(err) => match err.kind() {
io::ErrorKind::WouldBlock => (0, Ok(())),
_ => (0, Err(err.into())),
},
});
if let Ok(result) = result {
if let Err(err) = result {
debug!("Failed to read data from SOCKS socket: {}", err);
socket.close();
return;
}
} else if let Err(err) = result {
// Not critical if socket is still opening.
warn!("Failed to send data to virtual socket: {}", err);
}
}

if socket.can_recv() {
let result = socket.recv(|src| match tunnel.writer.try_write(src) {
Ok(bytes) => (bytes, Ok(())),
Err(err) => match err.kind() {
io::ErrorKind::WouldBlock => (0, Ok(())),
_ => (0, Err(err)),
},
});
if let Ok(result) = result {
if let Err(err) = result {
debug!("Failed to write data to SOCKS socket: {}", err);
socket.close();
return;
if socket.can_recv() {
let result = socket.recv(|src| match tunnel.writer.try_write(src) {
Ok(bytes) => (bytes, Ok(())),
Err(err) => match err.kind() {
io::ErrorKind::WouldBlock => (0, Ok(())),
_ => (0, Err(err)),
},
});
if let Ok(result) = result {
if let Err(err) = result {
debug!("Failed to write data to SOCKS socket: {}", err);
socket.close();
return Some(handle.to_owned());
}
} else if let Err(err) = result {
warn!("Failed to read data from virtual socket: {}", err);
}
} else if let Err(err) = result {
warn!("Failed to read data from virtual socket: {}", err);
}
}
});
None
})
.collect::<HashSet<_>>();

self.bridges.retain(|handle, _| {
let socket = self.sockets.get_mut::<tcp::Socket>(*handle);
if socket.is_open() {
return true;
}
self.sockets.remove(*handle);
false
});
self.bridges
.retain(|handle, _| !closed_bridges.contains(handle));
let closed_sockets = self
.sockets
.iter()
Expand All @@ -166,7 +158,7 @@ impl Network<'_> {
None
}
})
.collect::<Vec<_>>();
.collect::<HashSet<_>>();
closed_sockets.into_iter().for_each(|handle| {
self.sockets.remove(handle);
});
Expand Down Expand Up @@ -207,6 +199,10 @@ impl Network<'_> {
}
self.opening_connections
.retain(|_, response| response.is_some());

self.unique_ports.retain(|_, handle| {
self.opening_connections.contains_key(handle) || self.bridges.contains_key(handle)
});
}

pub fn create_command_sender(&self) -> mpsc::Sender<Command> {
Expand All @@ -228,12 +224,20 @@ impl Network<'_> {
let rx_buffer = tcp::SocketBuffer::new(vec![0; SOCKET_BUFFER_SIZE]);
let tx_buffer = tcp::SocketBuffer::new(vec![0; SOCKET_BUFFER_SIZE]);
let mut socket = tcp::Socket::new(rx_buffer, tx_buffer);
// TODO: check for collisions.
let local_port = rand::thread_rng().gen_range(49152..=65535);

let mut unique_connection = UniqueConnection {
remote_addr: addr,
local_port: rand::thread_rng().gen_range(49152..=65535),
};
while self.unique_ports.contains_key(&unique_connection) {
unique_connection.local_port = rand::thread_rng().gen_range(49152..=65535);
}
let remote_addr = wire::IpAddress::from(addr.ip());
if let Err(err) =
socket.connect(self.iface.context(), (remote_addr, addr.port()), local_port)
{
if let Err(err) = socket.connect(
self.iface.context(),
(remote_addr, addr.port()),
(self.device.vpn.ip_addr(), unique_connection.local_port),
) {
if let Err(_) = response.send(Err(err.into())) {
debug!("Proxy listener not listening for response");
}
Expand All @@ -243,6 +247,7 @@ impl Network<'_> {
let socket_handle = self.sockets.add(socket);
self.opening_connections
.insert(socket_handle, Some(response));
self.unique_ports.insert(unique_connection, socket_handle);
}
Command::Bridge(socket_handle, reader, writer) => {
let socket_tunnel = SocketTunnel { reader, writer };
Expand All @@ -260,13 +265,17 @@ impl Network<'_> {
}
}

#[derive(Eq, PartialEq, Hash)]
struct UniqueConnection {
remote_addr: SocketAddr,
local_port: u16,
}

struct SocketTunnel {
reader: tokio::net::tcp::OwnedReadHalf,
writer: tokio::net::tcp::OwnedWriteHalf,
}

impl SocketTunnel {}

type SocketConnectionResult = Result<(iface::SocketHandle, SocketAddr), NetworkError>;

pub enum Command {
Expand Down

0 comments on commit 7eb61a0

Please sign in to comment.