diff --git a/forwarder/src/socket/icmp.rs b/forwarder/src/socket/icmp.rs index 3964149..34e79a5 100644 --- a/forwarder/src/socket/icmp.rs +++ b/forwarder/src/socket/icmp.rs @@ -1,38 +1,20 @@ mod ether_helper; mod receiver; -use crate::MAX_PACKET_SIZE; - use super::SocketTrait; +use crate::MAX_PACKET_SIZE; use etherparse::{IcmpEchoHeader, Icmpv4Header, Icmpv4Type, Icmpv6Header, Icmpv6Type}; +use mio::{unix::SourceFd, Interest}; use parking_lot::{Mutex, RwLock}; use socket2::{Domain, Protocol, Type}; use std::{ - collections::{BTreeMap, VecDeque}, + collections::BTreeSet, io, mem::MaybeUninit, net::{SocketAddr, SocketAddrV6}, os::fd::AsRawFd, - sync::Arc, }; -/// Represents single packet -#[derive(Debug)] -struct Packet { - data: Vec, - from_addr: SocketAddr, -} - -/// Thread safe buffer for `Packet`s -type SharedPacketBuffer = Mutex>; - -/// `Controller` is passed to `IcmpReceiver` so it can communicate to `IcmpSocket` -#[derive(Debug)] -struct Controller { - packets: Arc, - waker: mio::Waker, -} - /// `IcmpSocket` that is very similiar to `UdpSocket` #[derive(Debug)] pub struct IcmpSocket { @@ -41,23 +23,27 @@ pub struct IcmpSocket { is_blocking: bool, /// udp socket that is kept alive for avoiding duplicate port udp_socket: std::net::UdpSocket, + /// address of udp socket + udp_socket_addr: SocketAddr, /// saves the socket that is connected to connected_addr: Option, - /// each `IcmpSocket` does not actually listen for new packets because - /// icmp protocol is on layer 2 and doesn't have any concept of ports - /// so each packet will wake up all `IcmpSocket`s, to fix that and remove - /// overheads of parsing each packet multiple times we listen to packets - /// only on one socket on another thread and after parsing port and packet - /// we put it in the corresponding controller `packets` - packets: Arc, } static IS_RECEIVER_STARTED: Mutex = Mutex::new(false); -static OPEN_PORTS: RwLock> = RwLock::new(BTreeMap::new()); + +/// each nonblocking `IcmpSocket` does not actually listen for new packets because +/// icmp protocol is on layer 2 and doesn't have any concept of ports +/// so each packet will wake up all `IcmpSocket`s, to fix that and remove +/// overheads of parsing each packet multiple times we listen to packets +/// only on one socket on another thread and after parsing port and packet +/// we put it in the corresponding controller `packets`, each nonblocking +/// `IcmpSocket` can register it's port via adding it to `OPEN_PORTS` +static OPEN_PORTS: RwLock> = RwLock::new(BTreeSet::new()); impl IcmpSocket { pub fn bind(addr: &SocketAddr) -> io::Result { let udp_socket = std::net::UdpSocket::bind(addr)?; + let udp_socket_addr = udp_socket.local_addr()?; let socket = IcmpSocket::inner_bind(*addr)?; // run the icmp receiver if it isn't running @@ -73,11 +59,10 @@ impl IcmpSocket { *is_receiver_alive = true; } - let packets = Mutex::new(VecDeque::with_capacity(10)); Ok(IcmpSocket { udp_socket, + udp_socket_addr, socket, - packets: Arc::new(packets), connected_addr: None, is_blocking: true, }) @@ -99,15 +84,24 @@ impl Drop for IcmpSocket { fn drop(&mut self) { // clear port let mut open_ports = OPEN_PORTS.write(); - let port = self.udp_socket.local_addr().unwrap().port(); - open_ports.remove(&port); + open_ports.remove(&self.udp_socket_addr.port()); } } impl SocketTrait for IcmpSocket { fn recv(&self, buffer: &mut [u8]) -> io::Result { - let (size, _) = self.recv_from(buffer)?; - Ok(size) + if self.is_blocking { + unimplemented!("currently IcmpSocket::recv in blocking mode is not being used") + } + // icmp receiver sends packets that it receives to udp socket of `IcmpSocket` + let (size, from_addr) = self.udp_socket.recv_from(buffer)?; + // make sure that the receiver sent the packet + // receiver is local so the packet ip is from loopback + if from_addr.ip().is_loopback() { + Ok(size) + } else { + Err(io::ErrorKind::ConnectionRefused.into()) + } } fn send(&self, buffer: &[u8]) -> io::Result { @@ -134,60 +128,52 @@ impl SocketTrait for IcmpSocket { } fn recv_from(&self, buffer: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - if self.is_blocking { - let mut second_buffer = [0u8; MAX_PACKET_SIZE]; - let local_addr = self.local_addr()?; - loop { - let (size, from_addr) = self.socket.recv_from(unsafe { - &mut *(&mut second_buffer as *mut [u8] as *mut [MaybeUninit]) - })?; - let Some(packet) = - receiver::parse_icmp_packet(&second_buffer[..size], local_addr.is_ipv6()) - else { - continue; - }; - if packet.dst_port != local_addr.port() { - continue; - } - let payload_len = packet.payload.len(); - buffer[..payload_len].copy_from_slice(packet.payload); - - let mut from_addr = from_addr.as_socket().unwrap(); - from_addr.set_port(packet.src_port); - return Ok((payload_len, from_addr)); - } - } else { - let mut packets = self.packets.lock(); - match packets.pop_front() { - Some(packet) => { - let len = packet.data.len(); - buffer[..len].copy_from_slice(&packet.data); - Ok((len, packet.from_addr)) - } - None => Err(io::ErrorKind::WouldBlock.into()), + if !self.is_blocking { + unimplemented!("currently IcmpSocket::recv_from in nonblocking mode is not being used") + } + let mut second_buffer = [0u8; MAX_PACKET_SIZE]; + let local_addr = self.local_addr()?; + loop { + let (size, from_addr) = self.socket.recv_from(unsafe { + &mut *(&mut second_buffer as *mut [u8] as *mut [MaybeUninit]) + })?; + let Some(packet) = + receiver::parse_icmp_packet(&second_buffer[..size], local_addr.is_ipv6()) + else { + continue; + }; + if packet.dst_port != local_addr.port() { + continue; } + let payload_len = packet.payload.len(); + buffer[..payload_len].copy_from_slice(packet.payload); + + let mut from_addr = from_addr.as_socket().unwrap(); + from_addr.set_port(packet.src_port); + return Ok((payload_len, from_addr)); } } fn local_addr(&self) -> io::Result { - self.udp_socket.local_addr() + Ok(self.udp_socket_addr) } fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { self.socket.set_nonblocking(nonblocking)?; + self.udp_socket.set_nonblocking(nonblocking)?; self.is_blocking = !nonblocking; Ok(()) } fn register(&mut self, registry: &mio::Registry, token: mio::Token) -> io::Result<()> { - let waker = mio::Waker::new(registry, token)?; let mut open_ports = OPEN_PORTS.write(); - let port = self.local_addr()?.port(); - let controller = Controller { - packets: self.packets.clone(), - waker, - }; - open_ports.insert(port, controller); + open_ports.insert(self.udp_socket_addr.port()); + + registry.register( + &mut SourceFd(&self.udp_socket.as_raw_fd()), + token, + Interest::READABLE, + )?; Ok(()) } } diff --git a/forwarder/src/socket/icmp/receiver.rs b/forwarder/src/socket/icmp/receiver.rs index 31130e5..cde2153 100644 --- a/forwarder/src/socket/icmp/receiver.rs +++ b/forwarder/src/socket/icmp/receiver.rs @@ -1,47 +1,33 @@ -use super::{ether_helper::IcmpSlice, IcmpSocket, Packet, OPEN_PORTS}; +use super::{ether_helper::IcmpSlice, IcmpSocket, OPEN_PORTS}; use crate::MAX_PACKET_SIZE; use etherparse::Ipv4HeaderSlice; -use socket2::SockAddr; use std::{mem::MaybeUninit, net::SocketAddr}; pub fn run_icmp_receiver(addr: SocketAddr) -> anyhow::Result<()> { let is_ipv6 = addr.is_ipv6(); let socket: socket2::Socket = IcmpSocket::inner_bind(addr)?; + let udp_socket = std::net::UdpSocket::bind(SocketAddr::new(addr.ip(), 0))?; + udp_socket.set_nonblocking(true)?; + let mut buffer = [0u8; MAX_PACKET_SIZE]; + let mut addr_buffer = addr; loop { - let Ok((size, from_addr)) = - socket.recv_from(unsafe { &mut *(&mut buffer as *mut [u8] as *mut [MaybeUninit]) }) + let Ok(size) = + socket.recv(unsafe { &mut *(&mut buffer as *mut [u8] as *mut [MaybeUninit]) }) else { continue; }; - - let Some(packet) = parse_icmp_packet(&buffer[..size], is_ipv6) else { + let Some(icmp_packet) = parse_icmp_packet(&buffer[..size], is_ipv6) else { continue; }; - handle_packet(packet, from_addr); - } -} - -fn handle_packet(icmp: IcmpPacket<'_>, from_addr: SockAddr) -> Option<()> { - let open_ports = OPEN_PORTS.write(); - let controller = open_ports.get(&icmp.dst_port)?; - - let mut source_addr = from_addr.as_socket().unwrap(); - source_addr.set_port(icmp.src_port); - - let packet = Packet { - data: icmp.payload.to_vec(), - from_addr: source_addr, - }; - { - let mut packets = controller.packets.lock(); - packets.push_back(packet); - } - if let Err(error) = controller.waker.wake() { - log::warn!("couldn't wake up icmp socket: {error:?}") + let open_ports = OPEN_PORTS.read(); + let port = icmp_packet.dst_port; + if open_ports.contains(&port) { + addr_buffer.set_port(port); + udp_socket.send_to(icmp_packet.payload, addr_buffer).ok(); + } } - Some(()) } pub struct IcmpPacket<'a> {