From 55f914e153a6f0947148dd3cbebc821a5a5b9a00 Mon Sep 17 00:00:00 2001 From: Dmitry Zolotukhin Date: Sun, 1 Dec 2024 20:01:17 +0100 Subject: [PATCH] Completed refactoring from previous commit. Instead of building queues, use a poll function to check status of all futures: * No queue to overfill (or self-lock) * Use a stack-based buffer instead of allocating memory for each packet (no-alloc for most of the IP traffic) * More direct calls when sending responses. --- src/futures.rs | 41 --- src/ikev2/mod.rs | 881 +++++++++++++++++++++++------------------------ src/main.rs | 7 +- 3 files changed, 441 insertions(+), 488 deletions(-) delete mode 100644 src/futures.rs diff --git a/src/futures.rs b/src/futures.rs deleted file mode 100644 index b27ab65..0000000 --- a/src/futures.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::{ - future::{self, Future}, - task::Poll, -}; - -pub struct RoundRobinSelector { - a_first: bool, -} - -impl RoundRobinSelector { - pub fn new() -> RoundRobinSelector { - RoundRobinSelector { a_first: true } - } - - pub fn select( - &mut self, - a: impl Future, - b: impl Future, - ) -> impl Future { - let (mut a, mut b) = (Box::pin(a), Box::pin(b)); - let a_first = self.a_first; - self.a_first = !self.a_first; - future::poll_fn(move |cx| { - if a_first { - if let Poll::Ready(r) = a.as_mut().poll(cx) { - Poll::Ready(r) - } else if let Poll::Ready(r) = b.as_mut().poll(cx) { - Poll::Ready(r) - } else { - Poll::Pending - } - } else if let Poll::Ready(r) = b.as_mut().poll(cx) { - Poll::Ready(r) - } else if let Poll::Ready(r) = a.as_mut().poll(cx) { - Poll::Ready(r) - } else { - Poll::Pending - } - }) - } -} diff --git a/src/ikev2/mod.rs b/src/ikev2/mod.rs index 04dfa4f..a9ea28d 100644 --- a/src/ikev2/mod.rs +++ b/src/ikev2/mod.rs @@ -6,7 +6,7 @@ use std::{ future::{self, Future}, io, net::{IpAddr, Ipv4Addr, SocketAddr}, - pin::pin, + pin::{pin, Pin}, sync::Arc, task::Poll, time::{Duration, Instant}, @@ -15,7 +15,7 @@ use tokio::{ net::UdpSocket, runtime, sync::{mpsc, oneshot}, - task::{JoinHandle, JoinSet}, + task::JoinSet, time, }; @@ -52,32 +52,10 @@ pub struct Server { nat_port: u16, pki_processing: Arc, tunnel_domains: Vec, -} - -pub struct ServerHandle { - command_sender: mpsc::Sender, + cancel_sender: Option>, join_set: JoinSet>, } -impl ServerHandle { - pub async fn terminate(&mut self) -> Result<(), IKEv2Error> { - if self - .command_sender - .send(SessionMessage::Shutdown) - .await - .is_err() - { - return Err("Command channel closed".into()); - } - while let Some(res) = self.join_set.join_next().await { - if let Err(err) = res { - warn!("Error returned when shutting down: {}", err); - } - } - Ok(()) - } -} - impl Server { pub fn new(config: Config) -> Result { let pki_processing = pki::PkiProcessing::new( @@ -94,10 +72,12 @@ impl Server { nat_port: config.nat_port, pki_processing: Arc::new(pki_processing), tunnel_domains: config.tunnel_domains, + cancel_sender: None, + join_set: JoinSet::new(), }) } - async fn send_timer_ticks( + async fn send_cleanup_ticks( duration: Duration, dest: mpsc::Sender, ) -> Result<(), IKEv2Error> { @@ -111,39 +91,48 @@ impl Server { } } - pub async fn start( - self, - fortivpn_config: fortivpn::Config, - ) -> Result { + async fn send_echo_ticks(dest: mpsc::Sender) -> Result<(), IKEv2Error> { + let mut interval = fortivpn::echo_send_interval(); + loop { + interval.tick().await; + dest.send(SessionMessage::SendVpnKeepalive) + .await + .map_err(|_| "Channel closed")?; + } + } + + pub async fn terminate(&mut self) -> Result<(), IKEv2Error> { + match self.cancel_sender.take() { + Some(cancel_sender) => { + if cancel_sender.send(()).is_err() { + return Err("Cancel channel closed".into()); + } + } + None => return Err("Shutdown already in progress".into()), + } + while let Some(res) = self.join_set.join_next().await { + if let Err(err) = res { + warn!("Error returned when shutting down: {}", err); + } + } + Ok(()) + } + + pub async fn start(&mut self, fortivpn_config: fortivpn::Config) -> Result<(), IKEv2Error> { let sockets = Arc::new(Sockets::new(&self.listen_ips, self.port, self.nat_port).await?); let mut split_routes = SplitRouteRegistry::new(self.tunnel_domains.clone()); let (tunnel_ips, traffic_selectors) = split_routes.refresh_addresses().await?; let vpn_service = FortiService::new(fortivpn_config); - let mut sessions = Sessions::new( - self.pki_processing.clone(), - sockets.clone(), - vpn_service, - tunnel_ips, - traffic_selectors, - ); + let rt = runtime::Handle::current(); - // Non-critical futures sockets will be terminated by Tokio during the shutdown_timeout phase. - /* - sockets.iter_sockets().for_each(|(listen_addr, socket)| { - let is_nat_port = listen_addr.port() == self.nat_port; - rt.spawn(Server::listen_socket( - socket.clone(), - *listen_addr, - is_nat_port, - sessions.create_sender(), - )); - }); - */ - rt.spawn(Server::send_timer_ticks( + let (command_sender, command_receiver) = mpsc::channel(32); + // Non-critical futures will be terminated by Tokio during the shutdown_timeout phase. + rt.spawn(Server::send_cleanup_ticks( CLEANUP_INTERVAL, - sessions.create_sender(), + command_sender.clone(), )); - let command_sender = sessions.create_sender(); + rt.spawn(Server::send_echo_ticks(command_sender.clone())); + let routes_sender = command_sender.clone(); rt.spawn(async move { let mut delay = tokio::time::interval(SPLIT_TUNNEL_REFRESH_INTERVAL); delay.set_missed_tick_behavior(time::MissedTickBehavior::Skip); @@ -156,7 +145,7 @@ impl Server { continue; } }; - let _ = command_sender + let _ = routes_sender .send(SessionMessage::UpdateSplitRoutes( tunnel_ips, traffic_selectors, @@ -164,21 +153,223 @@ impl Server { .await; } }); - let command_sender = sessions.create_sender(); - let mut join_set = JoinSet::new(); - join_set.spawn_on( - async move { - loop { - let datagram = sockets.read_datagram().await?; - warn!("Received datagram {:?}", datagram.bytes); - } - }, + + let sessions = Sessions::new( + self.pki_processing.clone(), + command_sender.clone(), + sockets.clone(), + tunnel_ips, + traffic_selectors, + ); + let (cancel_sender, cancel_receiver) = oneshot::channel(); + self.cancel_sender = Some(cancel_sender); + rt.spawn(async move { + if cancel_receiver.await.is_ok() + && command_sender.send(SessionMessage::Shutdown).await.is_err() + { + warn!("Command channel closed"); + } + }); + + self.join_set.spawn_on( + Self::run(command_receiver, sockets, sessions, vpn_service), &rt, ); - //join_set.spawn_on(async move { sessions.process_messages().await }, &rt); - Ok(ServerHandle { - command_sender, - join_set, + Ok(()) + } + + async fn run( + mut command_receiver: mpsc::Receiver, + sockets: Arc, + mut sessions: Sessions, + mut vpn_service: FortiService, + ) -> Result<(), IKEv2Error> { + let rt = runtime::Handle::current(); + let poller = MultiPoller::new(&sockets); + let mut udp_buffer = [0u8; MAX_DATAGRAM_SIZE]; + let mut vpn_buffer = [0u8; MAX_ESP_PACKET_SIZE]; + let nat_port = sockets.nat_port; + let mut vpn_connected = false; + let mut shutdown = false; + loop { + if shutdown && sessions.is_empty() && !vpn_service.is_connected() { + debug!("Shutdown completed"); + return Ok(()); + } + // Wait until something is ready. + let poll_result = { + let ignore_vpn = shutdown && !vpn_service.is_connected(); + let next_vpn_packet = vpn_service.next_packet(); + let next_vpn_packet_pin = pin!(next_vpn_packet); + let next_command = command_receiver.recv(); + let next_command_pin = pin!(next_command); + poller + .ready_list(next_command_pin, next_vpn_packet_pin, ignore_vpn) + .await + }; + // Process all ready events. + if let Some(message) = poll_result.command_message { + match message { + SessionMessage::Shutdown => { + shutdown = true; + sessions.cleanup(&rt); + if let Err(err) = vpn_service.terminate_shutdown(&rt) { + warn!("Failed to terminate VPN client connection: {}", err); + } + } + SessionMessage::SendVpnKeepalive => { + if let Err(err) = vpn_service.process_keepalive().await { + warn!("Echo request timed out: {}", err); + if let Err(err) = vpn_service.start_disconnection(&rt) { + warn!("Failed to terminate VPN client connection: {}", err); + } + } + } + _ => { + // These messages are handled by Session. + } + } + sessions.process_message(message).await; + } + for listen_addr in poll_result.ready_sockets { + if let Some(socket) = sockets.sockets.get(&listen_addr) { + let mut datagram = match socket.try_recv_from(&mut udp_buffer) { + Ok((bytes, remote_addr)) => UdpDatagram { + remote_addr, + local_addr: listen_addr, + is_nat_port: listen_addr.port() == nat_port, + bytes: &mut udp_buffer[..bytes], + }, + Err(err) => { + match err.kind() { + io::ErrorKind::WouldBlock => continue, + _ => { + warn!("Failed to receive data from ready socket: {}", err); + continue; + } + }; + } + }; + let result = if datagram.is_ikev2() { + sessions + .process_ikev2_message(&datagram, vpn_service.ip_configuration()) + .await + } else { + sessions + .process_esp_packet(&mut datagram, &mut vpn_service) + .await + }; + if let Err(err) = result { + warn!( + "Failed to process message from {}: {}", + datagram.remote_addr, err + ); + } + } else { + warn!( + "Received notification from non-existing listen address {}", + listen_addr + ); + } + } + if vpn_connected && vpn_service.is_connected() { + vpn_connected = false; + sessions.delete_all_sessions(&rt); + } + if let Some(vpn_status) = poll_result.vpn_status { + let can_recv = match vpn_status { + Ok(ready) => ready, + Err(err) => { + warn!("VPN reported an error status: {}", err); + false + } + }; + + let read_bytes = if can_recv { + match vpn_service.read_vpn_packet(&mut vpn_buffer).await { + Ok(bytes) => bytes, + Err(err) => { + warn!("Failed to receive packet from VPN: {}", err); + if let Err(err) = vpn_service.start_disconnection(&rt) { + warn!("Failed to start VPN disconnection: {}", err); + } + 0 + } + } + } else { + 0 + }; + if let Err(err) = sessions + .process_vpn_packet(&mut vpn_buffer, read_bytes) + .await + { + warn!("Failed to process VPN packet: {}", err); + } + } + } + } +} + +struct MultiPoller { + sockets: Vec<(SocketAddr, Arc)>, +} + +struct MultiPollResult { + command_message: Option, + ready_sockets: Vec, + vpn_status: Option>, +} + +impl MultiPoller { + fn new(sockets: &Sockets) -> MultiPoller { + let sockets = sockets + .iter_sockets() + .map(|(listen_addr, socket)| (*listen_addr, socket.clone())) + .collect::>(); + MultiPoller { sockets } + } + + fn ready_list<'a, C, V>( + &self, + mut command_recv: Pin<&'a mut C>, + mut peek_vpn: Pin<&'a mut V>, + ignore_vpn: bool, + ) -> impl Future + use<'a, '_, C, V> + where + C: Future>, + V: Future>, + { + let sockets = self.sockets.clone(); + future::poll_fn(move |cx| { + let mut ready_sockets = Vec::with_capacity(sockets.len()); + ready_sockets.extend(sockets.iter().filter_map(|(listen_addr, socket)| { + match socket.poll_recv_ready(cx) { + Poll::Ready(_) => Some(*listen_addr), + Poll::Pending => None, + } + })); + let vpn_status = if !ignore_vpn { + match peek_vpn.as_mut().poll(cx) { + Poll::Ready(res) => Some(res), + Poll::Pending => None, + } + } else { + // Avoid waking if VPN is already shut down. + None + }; + let command_message = match command_recv.as_mut().poll(cx) { + Poll::Ready(command) => command, + Poll::Pending => None, + }; + if command_message.is_some() || vpn_status.is_some() || !ready_sockets.is_empty() { + Poll::Ready(MultiPollResult { + command_message, + ready_sockets, + vpn_status, + }) + } else { + Poll::Pending + } }) } } @@ -208,65 +399,6 @@ impl Sockets { Ok(Sockets { sockets, nat_port }) } - fn read_datagram(&self) -> impl Future> + use<'_> { - /* - let mut poll_futures = self - .sockets - .iter() - .map(|(listen_addr, socket)| { - let socket = socket.clone(); - // TODO: reuse buffers, read into stack before creating a vec, or create a buffer - // pool? - let mut dest = vec![0u8; MAX_DATAGRAM_SIZE]; - let recv_future = async move { - let (bytes_res, remote_addr) = socket.recv_from(dest.as_mut_slice()).await?; - dest.truncate(bytes_res); - Ok(UdpDatagram { - remote_addr, - local_addr: *listen_addr, - is_nat_port: listen_addr.port() == self.nat_port, - bytes: dest, - }) - }; - pin!(recv_future) - }) - .collect::>(); - */ - // TODO: use a round robin order to read from all sockets simultaneously. - let sockets = self - .sockets - .iter() - .map(|(listen_addr, socket)| (*listen_addr, socket.clone())) - .collect::>(); - future::poll_fn(move |cx| { - sockets - .iter() - .flat_map(|(listen_addr, socket)| { - let recv_future = pin!(async { - let mut dest = vec![0u8; MAX_DATAGRAM_SIZE]; - let (bytes_res, remote_addr) = - socket.recv_from(dest.as_mut_slice()).await?; - dest.truncate(bytes_res); - Ok(UdpDatagram { - remote_addr, - local_addr: *listen_addr, - is_nat_port: listen_addr.port() == self.nat_port, - bytes: dest, - }) - }); - if let Poll::Ready(buf) = recv_future.poll(cx) { - Some(Poll::Ready(buf)) - } else { - None - } - }) - .next() - .unwrap_or(Poll::Pending) - }) - //dest.resize(MAX_DATAGRAM_SIZE, 0); - //dest.truncate(bytes_res); - } - fn iter_sockets(&self) -> collections::hash_map::Iter> { self.sockets.iter() } @@ -293,14 +425,14 @@ impl Sockets { } } -struct UdpDatagram { +struct UdpDatagram<'a> { remote_addr: SocketAddr, local_addr: SocketAddr, is_nat_port: bool, - bytes: Vec, + bytes: &'a mut [u8], } -impl UdpDatagram { +impl UdpDatagram<'_> { fn is_non_esp(&self) -> bool { self.bytes.len() >= 4 && self.bytes[0..4] == [0x00, 0x00, 0x00, 0x00] } @@ -311,21 +443,18 @@ impl UdpDatagram { } enum SessionMessage { - TransmitResponse(session::SessionID, u32, bool), DeleteSession(session::SessionID), DeleteSecurityAssociation(u32), RetransmitRequest(session::SessionID, u32), - VpnPacket(Vec), - VpnDisconnected, CleanupTimer, - Shutdown, UpdateSplitRoutes(Vec, Vec), + SendVpnKeepalive, + Shutdown, } struct Sessions { pki_processing: Arc, sockets: Arc, - vpn_service: FortiService, tunnel_ips: Vec, traffic_selectors: Vec, sessions: HashMap, @@ -333,24 +462,21 @@ struct Sessions { next_sa_index: usize, half_sessions: HashMap<(SocketAddr, u64), (u64, Instant)>, reserved_spi: Option, - tx: mpsc::Sender, - rx: mpsc::Receiver, + command_sender: mpsc::Sender, shutdown: bool, } impl Sessions { fn new( pki_processing: Arc, + command_sender: mpsc::Sender, sockets: Arc, - vpn_service: FortiService, tunnel_ips: Vec, traffic_selectors: Vec, ) -> Sessions { - let (tx, rx) = mpsc::channel(32); Sessions { pki_processing, sockets, - vpn_service, tunnel_ips, traffic_selectors, sessions: HashMap::new(), @@ -358,16 +484,11 @@ impl Sessions { security_associations: HashMap::new(), half_sessions: HashMap::new(), reserved_spi: None, - tx, - rx, + command_sender, shutdown: false, } } - fn create_sender(&self) -> mpsc::Sender { - self.tx.clone() - } - fn get_init_session( &mut self, remote_spi: u64, @@ -401,8 +522,8 @@ impl Sessions { session_id } - fn get(&mut self, id: session::SessionID) -> Option<&mut session::IKEv2Session> { - self.sessions.get_mut(&id) + fn is_empty(&self) -> bool { + self.sessions.is_empty() } fn cleanup(&mut self, rt: &runtime::Handle) { @@ -467,7 +588,7 @@ impl Sessions { continue; } }; - let sender = self.tx.clone(); + let sender = self.command_sender.clone(); let session_id = *session_id; rt.spawn(async move { let _ = sender @@ -536,74 +657,40 @@ impl Sessions { reserved_spi } - async fn process_messages(&mut self) -> Result<(), IKEv2Error> { - //self.vpn_service.start(self.create_sender()).await?; - while let Some(message) = self.rx.recv().await { - match message { - /* - SessionMessage::UdpDatagram(mut datagram) => { - if let Err(err) = self.process_datagram(&mut datagram).await { - warn!( - "Failed to process message from {}: {}", - datagram.remote_addr, err - ); - } - } - */ - SessionMessage::TransmitResponse(session_id, message_id, is_nat) => { - self.transmit_response(session_id, message_id, is_nat).await; - } - SessionMessage::VpnPacket(data) => { - if let Err(err) = self.process_vpn_packet(data).await { - warn!("Failed to process VPN packet: {}", err); - } - } - SessionMessage::DeleteSession(session_id) => self.delete_session(session_id), - SessionMessage::DeleteSecurityAssociation(session_id) => { - self.delete_security_association(session_id) - } - SessionMessage::VpnDisconnected => { - let rt = runtime::Handle::current(); - self.delete_all_sessions(&rt); - } - SessionMessage::CleanupTimer => { - let rt = runtime::Handle::current(); - self.cleanup(&rt); - } - SessionMessage::UpdateSplitRoutes(tunnel_ips, traffic_selectors) => { - self.tunnel_ips = tunnel_ips; - self.traffic_selectors = traffic_selectors; - self.update_all_split_routes().await; - } - SessionMessage::RetransmitRequest(session_id, message_id) => { - self.retransmit_request(session_id, message_id).await; - } - SessionMessage::Shutdown => { - self.shutdown = true; - let rt = runtime::Handle::current(); - self.cleanup(&rt); - } + async fn process_message(&mut self, message: SessionMessage) { + match message { + SessionMessage::DeleteSession(session_id) => self.delete_session(session_id), + SessionMessage::DeleteSecurityAssociation(session_id) => { + self.delete_security_association(session_id) + } + SessionMessage::CleanupTimer => { + let rt = runtime::Handle::current(); + self.cleanup(&rt); } - if self.shutdown && self.sessions.is_empty() { - break; + SessionMessage::UpdateSplitRoutes(tunnel_ips, traffic_selectors) => { + self.tunnel_ips = tunnel_ips; + self.traffic_selectors = traffic_selectors; + self.update_all_split_routes().await; + } + SessionMessage::RetransmitRequest(session_id, message_id) => { + self.retransmit_request(session_id, message_id).await; + } + SessionMessage::SendVpnKeepalive => { + // This messages is handled externally. + } + SessionMessage::Shutdown => { + self.shutdown = true; } - } - self.vpn_service.terminate().await?; - debug!("Shutdown completed"); - Ok(()) - } - - async fn process_datagram(&mut self, datagram: &mut UdpDatagram) -> Result<(), IKEv2Error> { - if datagram.is_ikev2() { - self.process_ikev2_message(datagram).await - } else { - self.process_esp_packet(datagram).await } } - async fn process_ikev2_message(&mut self, datagram: &UdpDatagram) -> Result<(), IKEv2Error> { + async fn process_ikev2_message( + &mut self, + datagram: &UdpDatagram<'_>, + ip_configuration: Option<(IpAddr, &[IpAddr])>, + ) -> Result<(), IKEv2Error> { let is_nat = datagram.is_non_esp(); - let ikev2_request = message::InputMessage::from_datagram(&datagram.bytes, is_nat)?; + let ikev2_request = message::InputMessage::from_datagram(datagram.bytes, is_nat)?; if !ikev2_request.is_valid() { return Err("Invalid message received".into()); } @@ -638,23 +725,20 @@ impl Sessions { } else { session::SessionID::from_message(&ikev2_request)? }; - let ip_configuration = if ikev2_request.read_exchange_type()? - == message::ExchangeType::IKE_AUTH - && !ikev2_request.read_flags()?.has(message::Flags::RESPONSE) - { - self.vpn_service.ip_configuration().await? - } else { - None - }; + let sockets = self.sockets.clone(); let mut reserved_spi = self.reserve_session_ids(); - let session = if let Some(session) = self.get(session_id) { + let session = if let Some(session) = self.sessions.get_mut(&session_id) { session } else { return Err("Session not found".into()); }; - if let Some((client_ip, dns_addrs)) = ip_configuration { - session.update_ip(client_ip, dns_addrs); - } + if ikev2_request.read_exchange_type()? == message::ExchangeType::IKE_AUTH + && !ikev2_request.read_flags()?.has(message::Flags::RESPONSE) + { + if let Some((client_ip, dns_addrs)) = ip_configuration { + session.update_ip(client_ip, dns_addrs.to_vec()); + } + }; if ikev2_request.read_flags()?.has(message::Flags::RESPONSE) { session.process_response(datagram.remote_addr, datagram.local_addr, &ikev2_request)?; @@ -668,14 +752,15 @@ impl Sessions { self.reserved_spi = Some(reserved_spi); // Response retransmissions are initiated by client. if transmit_response { - let _ = self - .tx - .send(SessionMessage::TransmitResponse( - session_id, - ikev2_request.read_message_id(), - is_nat, - )) - .await; + if let Err(err) = session + .send_last_response(&sockets, ikev2_request.read_message_id(), is_nat) + .await + { + warn!( + "Failed to transmit response to session {}: {}", + session_id, err + ); + } } } @@ -719,7 +804,7 @@ impl Sessions { .insert(session_id, *security_association); } session::IKEv2PendingAction::DeleteIKESession(delay) => { - let tx = self.tx.clone(); + let tx = self.command_sender.clone(); let cmd = SessionMessage::DeleteSession(session_id); rt.spawn(async move { debug!("Scheduling to delete IKEv2 session {}", session_id); @@ -745,7 +830,7 @@ impl Sessions { .collect::>(); } session::IKEv2PendingAction::DeleteChildSA(session_id, delay) => { - let tx = self.tx.clone(); + let tx = self.command_sender.clone(); let cmd = SessionMessage::DeleteSecurityAssociation(session_id); rt.spawn(async move { debug!( @@ -789,32 +874,6 @@ impl Sessions { }); } - async fn transmit_response( - &mut self, - session_id: session::SessionID, - message_id: u32, - is_nat: bool, - ) { - let session = if let Some(session) = self.sessions.get_mut(&session_id) { - session - } else { - debug!( - "Failed to transmit response: missing session {}", - session_id - ); - return; - }; - if let Err(err) = session - .send_last_response(&self.sockets, message_id, is_nat) - .await - { - warn!( - "Failed to transmit response to session {}: {}", - session_id, err - ); - } - } - async fn retransmit_request(&mut self, session_id: session::SessionID, message_id: u32) { let session = if let Some(session) = self.sessions.get_mut(&session_id) { session @@ -836,7 +895,13 @@ impl Sessions { warn!("Session {} reached retrasmission limit", session_id); } session::NextRetransmission::Delay(delay) => { - Self::schedule_retransmission(self.tx.clone(), session_id, message_id, delay).await + Self::schedule_retransmission( + self.command_sender.clone(), + session_id, + message_id, + delay, + ) + .await } } } @@ -855,7 +920,11 @@ impl Sessions { }); } - async fn process_esp_packet(&mut self, datagram: &mut UdpDatagram) -> Result<(), IKEv2Error> { + async fn process_esp_packet( + &mut self, + datagram: &mut UdpDatagram<'_>, + vpn_service: &mut FortiService, + ) -> Result<(), IKEv2Error> { if datagram.bytes == [0xff] { debug!("Received ESP NAT keepalive from {}", datagram.remote_addr); return Ok(self @@ -875,7 +944,7 @@ impl Sessions { local_spi.copy_from_slice(&datagram.bytes[0..4]); let local_spi = u32::from_be_bytes(local_spi); if let Some(sa) = self.security_associations.get_mut(&local_spi) { - let decrypted_slice = sa.handle_esp(&mut datagram.bytes)?; + let decrypted_slice = sa.handle_esp(datagram.bytes)?; trace!( "Decrypted ESP packet from {}\n{:?}", datagram.remote_addr, @@ -897,7 +966,7 @@ impl Sessions { } let mut decrypted_data = Vec::with_capacity(MAX_ESP_PACKET_SIZE); decrypted_data.extend_from_slice(decrypted_slice); - self.vpn_service.send_packet(decrypted_data).await + vpn_service.send_packet(decrypted_data).await } else { warn!( "Security Association {:x} from {} not found", @@ -907,13 +976,21 @@ impl Sessions { } } - async fn process_vpn_packet(&mut self, mut data: Vec) -> Result<(), IKEv2Error> { - let hdr = match esp::IpHeader::from_packet(&data) { + async fn process_vpn_packet( + &mut self, + data: &mut [u8], + data_len: usize, + ) -> Result<(), IKEv2Error> { + if data_len == 0 { + return Ok(()); + } + let hdr = match esp::IpHeader::from_packet(&data[..data_len]) { Ok(hdr) => hdr, Err(err) => { warn!( "Failed to read header in IP packet from VPN: {}\n{:?}", - err, data + err, + &data[..data_len] ); return Err("Failed to read header in IP packet from VPN".into()); } @@ -932,19 +1009,17 @@ impl Sessions { } }) { - let msg_len = data.len(); - let encoded_length = sa.encoded_length(data.len()); - if encoded_length > data.capacity() { + let encoded_length = sa.encoded_length(data_len); + if encoded_length > data.len() { // This sometimes happens when FortiVPN returns a zero-padded packet. warn!( - "Vector doesn't have capacity for ESP headers, data length is {}, capacity is {}", + "Slice doesn't have capacity for ESP headers, message length is {}, slice has {}", encoded_length, - data.capacity() + data.len() ); return Err("Vector doesn't have capacity for ESP headers".into()); } - data.resize(encoded_length, 0); - let encrypted_data = sa.handle_vpn(data.as_mut_slice(), msg_len)?; + let encrypted_data = sa.handle_vpn(&mut data[..encoded_length], data_len)?; trace!( "Encrypted VPN packet to {}\n{:?}", sa.remote_addr(), @@ -967,8 +1042,10 @@ impl Sessions { struct FortiService { config: fortivpn::Config, tunnel: Option, - connect_handle: Option>>, - join_set: JoinSet>, + connect_receiver: + Option>>, + terminate_receiver: Option>>, + shutdown: bool, } impl FortiService { @@ -976,11 +1053,16 @@ impl FortiService { FortiService { config, tunnel: None, - connect_handle: None, - join_set: JoinSet::new(), + connect_receiver: None, + terminate_receiver: None, + shutdown: false, } } + fn is_connected(&self) -> bool { + self.tunnel.is_some() + } + async fn connect( config: fortivpn::Config, ) -> Result { @@ -988,197 +1070,112 @@ impl FortiService { fortivpn::FortiVPNTunnel::new(&config, sslvpn_cookie).await } - async fn peek_vpn(forti_client: &mut fortivpn::FortiVPNTunnel) -> Option { - if let Err(err) = forti_client.peek_recv().await { - debug!("Failed to check if VPN has data available: {}", err); - } - Some(FortiServiceCommand::ReceivePacket) + async fn read_vpn_packet(&mut self, buffer: &mut [u8]) -> Result { + let tunnel = if let Some(tunnel) = self.tunnel.as_mut() { + tunnel + } else { + return Err("VPN tunnel is closed".into()); + }; + Ok(tunnel.try_read_packet(buffer, None).await?) } - async fn read_vpn_packet( - forti_client: &mut fortivpn::FortiVPNTunnel, - ) -> Result, fortivpn::FortiError> { - let mut buffer = [0u8; MAX_ESP_PACKET_SIZE]; - match forti_client.try_read_packet(&mut buffer, None).await { - Ok(msg_len) => { - if msg_len > 0 { - let mut packet_buffer = Vec::with_capacity(MAX_ESP_PACKET_SIZE); - packet_buffer.extend_from_slice(&buffer[..msg_len]); - Ok(packet_buffer) - } else { - Ok(vec![]) - } - } - Err(err) => Err(err), + async fn process_keepalive(&mut self) -> Result<(), fortivpn::FortiError> { + if let Some(tunnel) = self.tunnel.as_mut() { + tunnel.process_echo().await + } else { + Ok(()) } } - async fn process_next(&self) { - match self.tunnel.as_ref() { - Some(tunnel) => self.process_next_connected(tunnel).await, - None => self.process_next_connection().await, + async fn next_packet(&mut self) -> Result { + if let Some(tunnel) = self.tunnel.as_mut() { + // VPN is connected, wait for next available packet. + let result = tunnel.peek_recv().await; + if let Err(err) = result { + debug!("Failed to check if VPN has data available: {}", err); + Err(err.into()) + } else { + Ok(true) + } + } else if let Some(receive_result) = self.terminate_receiver.as_mut() { + let receive_result = receive_result.await; + self.terminate_receiver = None; + receive_result.map_err(|err| { + warn!("Failed to receive VPN termination result: {}", err); + "Failed to receive VPN termination result" + })??; + Ok(false) + } else if let Some(receive_result) = self.connect_receiver.as_mut() { + let receive_result = receive_result.await; + self.connect_receiver = None; + let connect_result = receive_result.map_err(|err| { + warn!("Failed to receive VPN connection result: {}", err); + "Failed to receive VPN connection result" + })?; + + self.tunnel = Some(connect_result?); + Ok(false) + } else if !self.shutdown { + let rt = runtime::Handle::current(); + let config = self.config.clone(); + let (tx, rx) = oneshot::channel(); + self.connect_receiver = Some(rx); + rt.spawn(async move { tx.send(Self::connect(config).await) }); + Ok(false) + } else { + Err("VPN service is shut down".into()) } } - async fn process_next_connection(&self) { - if let Some(connect_handle) = self.connect_handle.as_ref() { - //connect_handle - } - let result = Self::connect(self.config.clone()).await; + fn terminate_shutdown(&mut self, rt: &runtime::Handle) -> Result<(), IKEv2Error> { + self.shutdown = true; + self.start_disconnection(rt) } - async fn process_next_connected(&self, tunnel: &fortivpn::FortiVPNTunnel) {} - - /* - async fn run( - config: fortivpn::Config, - tx: mpsc::Sender, - mut rx: mpsc::Receiver, - sessions_tx: mpsc::Sender, - ) -> Result<(), IKEv2Error> { - loop { - connect_handle.abort(); - let mut forti_client = match forti_client.unwrap() { - Ok(forti_client) => forti_client, - Err(err) => { - debug!("Error occurred when connecting to FortiClient: {}", err); - continue; - } - }; - let keepalive_timer = { - let rt = runtime::Handle::current(); - let tx = tx.clone(); - rt.spawn(async move { - let mut interval = fortivpn::echo_send_interval(); - loop { - interval.tick().await; - let _ = tx.send(FortiServiceCommand::SendEcho).await; - } - }) - }; - // Handle connection until it drops. - let mut selector = crate::futures::RoundRobinSelector::new(); - while let Some(command) = selector - .select(rx.recv(), Self::peek_vpn(&mut forti_client)) - .await - { - match command { - FortiServiceCommand::SendPacket(data) => { - if let Err(err) = forti_client.send_packet(&data).await { - warn!("Failed to send packet to VPN: {}", err); - break; - }; - if let Err(err) = forti_client.flush().await { - warn!("Failed to flush packet to VPN: {}", err); - break; - }; - } - FortiServiceCommand::ReceivePacket => { - let data = match Self::read_vpn_packet(&mut forti_client).await { - Ok(data) => data, - Err(err) => { - warn!("Failed to receive packet from VPN: {}", err); - break; - } - }; - if !data.is_empty() { - let _ = sessions_tx.send(SessionMessage::VpnPacket(data)).await; - } - } - FortiServiceCommand::SendEcho => { - if let Err(err) = forti_client.process_echo().await { - warn!("Echo request timed out: {}", err); - break; - } - } - FortiServiceCommand::RequestIpConfiguration(tx) => { - let _ = - tx.send(Some((forti_client.ip_addr(), forti_client.dns().to_vec()))); - } - FortiServiceCommand::Shutdown => { - keepalive_timer.abort(); - if let Err(err) = forti_client.terminate().await { - warn!("Failed to terminate VPN client connection: {}", err); - } - return Ok(()); - } + fn start_disconnection(&mut self, rt: &runtime::Handle) -> Result<(), IKEv2Error> { + if self.terminate_receiver.is_some() { + Err( + "Received additional VPN disconnection request, termination already in progress" + .into(), + ) + } else if let Some(mut tunnel) = self.tunnel.take() { + let (tx, rx) = oneshot::channel(); + self.terminate_receiver = Some(rx); + rt.spawn(async move { + let result = tunnel.terminate().await; + if let Err(ref err) = result { + warn!("Error returned when disconnecting VPN client: {}", err); } - } - keepalive_timer.abort(); - if let Err(err) = forti_client.terminate().await { - warn!("Failed to terminate VPN client connection: {}", err); - } - { - let rt = runtime::Handle::current(); - let sessions_tx = sessions_tx.clone(); - rt.spawn(async move { sessions_tx.send(SessionMessage::VpnDisconnected).await }); - } - } - } - */ - - async fn send_packet(&self, data: Vec) -> Result<(), IKEv2Error> { - Ok(()) - /* - if let Some(tx) = self.command_sender.as_ref() { - let _ = tx.send(FortiServiceCommand::SendPacket(data)).await; + tx.send(result) + }); Ok(()) } else { - Err("VPN client service is not running".into()) + Err("Received VPN disconnection request for a closed tunnel".into()) } - */ } - async fn ip_configuration(&self) -> Result)>, IKEv2Error> { - Ok(None) - /* - if let Some(command_sender) = self.command_sender.as_ref() { - let (tx, rx) = oneshot::channel(); - command_sender - .send(FortiServiceCommand::RequestIpConfiguration(tx)) - .await - .map_err(|_| "VPN client command channel closed")?; - Ok(rx.await.map_err(|_| "IP address receiver closed")?) + async fn send_packet(&mut self, data: Vec) -> Result<(), IKEv2Error> { + if let Some(tunnel) = self.tunnel.as_mut() { + tunnel.send_packet(&data).await.map_err(|err| { + warn!("Failed to send packet to VPN: {}", err); + err + })?; + Ok(tunnel.flush().await.map_err(|err| { + warn!("Failed to flush packet to VPN: {}", err); + err + })?) } else { Err("VPN client service is not running".into()) } - */ } - async fn terminate(&mut self) -> Result<(), IKEv2Error> { - Ok(()) - /* - match self.command_sender { - Some(ref command_sender) => { - if command_sender - .send(FortiServiceCommand::Shutdown) - .await - .is_err() - { - return Err("Command channel closed".into()); - } - } - None => return Err("Shutdown already in progress".into()), - } - self.command_sender = None; - while let Some(res) = self.join_set.join_next().await { - if let Err(err) = res { - warn!("Error returned when stopping VPN client: {}", err); - } - } - Ok(()) - */ + fn ip_configuration(&self) -> Option<(IpAddr, &[IpAddr])> { + self.tunnel + .as_ref() + .map(|tunnel| (tunnel.ip_addr(), tunnel.dns())) } } -enum FortiServiceCommand { - RequestIpConfiguration(oneshot::Sender)>>), - SendPacket(Vec), - ReceivePacket, - SendEcho, - Shutdown, -} - struct SplitRouteRegistry { tunnel_domains: Vec, } diff --git a/src/main.rs b/src/main.rs index e9cb2fa..efc0ee5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,6 @@ use tokio::{signal, sync::mpsc}; use tokio_rustls::rustls; mod fortivpn; -mod futures; mod http; mod logger; mod ppp; @@ -406,24 +405,22 @@ fn serve_ikev2(config: Ikev2Config) -> Result<(), i32> { eprintln!("Failed to start runtime: {}", err); 1 })?; - let server = match ikev2::Server::new(config.ikev2) { + let mut server = match ikev2::Server::new(config.ikev2) { Ok(server) => server, Err(err) => { eprintln!("Failed to create server: {}", err); std::process::exit(1) } }; - let mut server = rt.block_on(server.start(config.fortivpn)).map_err(|err| { + rt.block_on(server.start(config.fortivpn)).map_err(|err| { eprintln!("Failed to run server: {}", err); 1 })?; - let cancel_flag = Arc::new(atomic::AtomicBool::new(false)); let (shutdown_sender, shutdown_receiver) = oneshot::channel(); let cancel_handle = rt.spawn(async move { if let Err(err) = signal::ctrl_c().await { eprintln!("Failed to wait for CTRL+C signal: {}", err); } - cancel_flag.store(true, atomic::Ordering::Relaxed); let _ = shutdown_sender.send(()); });