diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0d8d4ff2f2..3fb1710654 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,7 +73,7 @@ jobs: run: curl -LsSf https://get.nexte.st/latest/linux | tar zxf - -C ${CARGO_HOME:-~/.cargo}/bin - name: Build run: cargo build -p qt -p quilkin -p quilkin-xds --tests - - run: cargo nextest run -p qt -p quilkin -p quilkin-xds quilkin + - run: cargo nextest run --no-tests=pass -p qt -p quilkin -p quilkin-xds quilkin build: name: Build diff --git a/Cargo.lock b/Cargo.lock index 138b75cc67..3be76e4ff4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,18 +142,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-channel" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a" -dependencies = [ - "concurrent-queue", - "event-listener-strategy", - "futures-core", - "pin-project-lite", -] - [[package]] name = "async-stream" version = "0.3.6" @@ -2434,7 +2422,6 @@ dependencies = [ name = "qt" version = "0.1.0" dependencies = [ - "async-channel", "once_cell", "quilkin", "rand", @@ -2457,7 +2444,6 @@ name = "quilkin" version = "0.10.0-dev" dependencies = [ "arc-swap", - "async-channel", "async-stream", "async-trait", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index f871e4820b..f7c2bd8cdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,12 +87,12 @@ quilkin-proto.workspace = true # Crates.io arc-swap.workspace = true -async-channel.workspace = true async-stream.workspace = true base64.workspace = true base64-serde = "0.8.0" bytes = { version = "1.8.0", features = ["serde"] } cached.workspace = true +cfg-if = "1.0" crossbeam-utils = { version = "0.8", optional = true } clap = { version = "4.5.21", features = ["cargo", "derive", "env"] } dashmap = { version = "6.1", features = ["serde"] } @@ -153,7 +153,6 @@ hickory-resolver = { version = "0.24", features = [ async-trait = "0.1.83" strum = "0.26" strum_macros = "0.26" -cfg-if = "1.0.0" libflate = "2.1.0" form_urlencoded = "1.2.1" enum_dispatch = "0.3.13" @@ -194,7 +193,6 @@ edition = "2021" [workspace.dependencies] arc-swap = { version = "1.7.1", features = ["serde"] } -async-channel = "2.3.1" async-stream = "0.3.6" base64 = "0.22.1" cached = { version = "0.54", default-features = false } diff --git a/crates/test/Cargo.toml b/crates/test/Cargo.toml index 7a90f45ec4..85d2200140 100644 --- a/crates/test/Cargo.toml +++ b/crates/test/Cargo.toml @@ -24,7 +24,6 @@ publish = false workspace = true [dependencies] -async-channel.workspace = true once_cell.workspace = true quilkin.workspace = true rand.workspace = true diff --git a/crates/test/tests/proxy.rs b/crates/test/tests/proxy.rs index 227cf6648a..2272af221b 100644 --- a/crates/test/tests/proxy.rs +++ b/crates/test/tests/proxy.rs @@ -1,5 +1,5 @@ use qt::*; -use quilkin::test::TestConfig; +use quilkin::{components::proxy, test::TestConfig}; use tracing::Instrument as _; trace_test!(server, { @@ -87,8 +87,7 @@ trace_test!(uring_receiver, { let (mut packet_rx, endpoint) = sb.server("server"); - let (error_sender, mut error_receiver) = - tokio::sync::mpsc::channel::(20); + let (error_sender, mut error_receiver) = tokio::sync::mpsc::channel::(20); tokio::task::spawn( async move { @@ -105,37 +104,32 @@ trace_test!(uring_receiver, { config .clusters .modify(|clusters| clusters.insert_default([endpoint.into()].into())); - let (tx, rx) = async_channel::unbounded(); - let (_shutdown_tx, shutdown_rx) = - quilkin::make_shutdown_channel(quilkin::ShutdownKind::Testing); let socket = sb.client(); let (ws, addr) = sb.socket(); + let pending_sends = proxy::PendingSends::new(1).unwrap(); + // we'll test a single DownstreamReceiveWorkerConfig - let ready = quilkin::components::proxy::packet_router::DownstreamReceiveWorkerConfig { + proxy::packet_router::DownstreamReceiveWorkerConfig { worker_id: 1, port: addr.port(), - upstream_receiver: rx.clone(), config: config.clone(), error_sender, buffer_pool: quilkin::test::BUFFER_POOL.clone(), - sessions: quilkin::components::proxy::SessionPool::new( + sessions: proxy::SessionPool::new( config, - tx, + vec![pending_sends.0.clone()], BUFFER_POOL.clone(), - shutdown_rx.clone(), ), } - .spawn(shutdown_rx) + .spawn(pending_sends) .await .expect("failed to spawn task"); // Drop the socket, otherwise it can drop(ws); - ready.recv().unwrap(); - let msg = "hello-downstream"; tracing::debug!("sending packet"); socket.send_to(msg.as_bytes(), addr).await.unwrap(); @@ -158,36 +152,33 @@ trace_test!( .clusters .modify(|clusters| clusters.insert_default([endpoint.into()].into())); - let (tx, rx) = async_channel::unbounded(); - let (_shutdown_tx, shutdown_rx) = - quilkin::make_shutdown_channel(quilkin::ShutdownKind::Testing); + let pending_sends: Vec<_> = [ + proxy::PendingSends::new(1).unwrap(), + proxy::PendingSends::new(1).unwrap(), + proxy::PendingSends::new(1).unwrap(), + ] + .into_iter() + .collect(); - let sessions = quilkin::components::proxy::SessionPool::new( + let sessions = proxy::SessionPool::new( config.clone(), - tx, + pending_sends.iter().map(|ps| ps.0.clone()).collect(), BUFFER_POOL.clone(), - shutdown_rx.clone(), ); const WORKER_COUNT: usize = 3; let (socket, addr) = sb.socket(); - let workers = quilkin::components::proxy::packet_router::spawn_receivers( + proxy::packet_router::spawn_receivers( config, socket, - WORKER_COUNT, + pending_sends, &sessions, - rx, BUFFER_POOL.clone(), - shutdown_rx, ) .await .unwrap(); - for wn in workers { - wn.recv().unwrap(); - } - let socket = std::sync::Arc::new(sb.client()); let msg = "recv-from"; diff --git a/src/collections/ttl.rs b/src/collections/ttl.rs index 526b147e91..f2da7a1d0a 100644 --- a/src/collections/ttl.rs +++ b/src/collections/ttl.rs @@ -55,11 +55,13 @@ impl Value { /// Get the expiration time for this value. The returned value is the /// number of seconds relative to some reference point (e.g UNIX_EPOCH), based /// on the clock being used. + #[inline] fn expiration_secs(&self) -> u64 { self.expires_at.load(Ordering::Relaxed) } /// Update the value's expiration time to (now + TTL). + #[inline] fn update_expiration(&self, ttl: Duration) { match self.clock.compute_expiration_secs(ttl) { Ok(new_expiration_time) => { @@ -160,6 +162,7 @@ where /// Returns the current time as the number of seconds relative to some initial /// reference point (e.g UNIX_EPOCH), based on the clock implementation being used. /// In tests, this will be driven by [`tokio::time`] + #[inline] pub(crate) fn now_relative_secs(&self) -> u64 { self.0.clock.now_relative_secs().unwrap_or_default() } @@ -237,6 +240,12 @@ where self.0.inner.remove(&key).is_some() } + /// Removes all entries from the map + #[inline] + pub fn clear(&self) { + self.0.inner.clear(); + } + /// Returns an entry for in-place updates of the specified key-value pair. /// Note: This acquires a write lock on the map's shard that corresponds /// to the entry. diff --git a/src/components/proxy.rs b/src/components/proxy.rs index bbf6a8ffe1..f06af458a6 100644 --- a/src/components/proxy.rs +++ b/src/components/proxy.rs @@ -18,8 +18,79 @@ mod error; pub mod packet_router; mod sessions; -#[cfg(target_os = "linux")] -pub(crate) mod io_uring_shared; +cfg_if::cfg_if! { + if #[cfg(target_os = "linux")] { + pub(crate) mod io_uring_shared; + pub(crate) type PacketSendReceiver = io_uring_shared::EventFd; + pub(crate) type PacketSendSender = io_uring_shared::EventFdWriter; + } else { + pub(crate) type PacketSendReceiver = tokio::sync::watch::Receiver; + pub(crate) type PacketSendSender = tokio::sync::watch::Sender; + } +} + +/// A simple packet queue that signals when a packet is pushed +/// +/// For io_uring this notifies an eventfd that will be processed on the next +/// completion loop +#[derive(Clone)] +pub struct PendingSends { + packets: Arc>>, + notify: PacketSendSender, +} + +impl PendingSends { + pub fn new(capacity: usize) -> std::io::Result<(Self, PacketSendReceiver)> { + #[cfg(target_os = "linux")] + let (notify, rx) = { + let rx = io_uring_shared::EventFd::new()?; + (rx.writer(), rx) + }; + #[cfg(not(target_os = "linux"))] + let (notify, rx) = tokio::sync::watch::channel(true); + + Ok(( + Self { + packets: Arc::new(parking_lot::Mutex::new(Vec::with_capacity(capacity))), + notify, + }, + rx, + )) + } + + #[inline] + pub(crate) fn capacity(&self) -> usize { + self.packets.lock().capacity() + } + + /// Pushes a packet onto the queue to be sent, signalling a sender that + /// it's available + #[inline] + pub(crate) fn push(&self, packet: SendPacket) { + self.packets.lock().push(packet); + #[cfg(target_os = "linux")] + self.notify.write(1); + #[cfg(not(target_os = "linux"))] + let _ = self.notify.send(true); + } + + /// Called to shutdown the consumer side of the sends (ie the io loop that is + /// actually dequing and sending packets) + #[inline] + pub(crate) fn shutdown_receiver(&self) { + #[cfg(target_os = "linux")] + self.notify.write(0xdeadbeef); + #[cfg(not(target_os = "linux"))] + let _ = self.notify.send(false); + } + + /// Swaps the current queue with an empty one so we only lock for a pointer swap + #[inline] + pub fn swap(&self, mut swap: Vec) -> Vec { + swap.clear(); + std::mem::replace(&mut self.packets.lock(), swap) + } +} use super::RunArgs; pub use error::{ErrorMap, PipelineError}; @@ -33,8 +104,11 @@ use std::{ }; pub struct SendPacket { - pub destination: SocketAddr, + /// The destination address of the packet + pub destination: socket2::SockAddr, + /// The packet data being sent pub data: crate::pool::FrozenPoolBuffer, + /// The asn info for the sender, used for metrics pub asn_info: Option, } @@ -208,18 +282,6 @@ impl Proxy { )); } - let id = config.id.load(); - let num_workers = self.num_workers.get(); - - let (upstream_sender, upstream_receiver) = async_channel::bounded(250); - let buffer_pool = Arc::new(crate::pool::BufferPool::new(num_workers, 64 * 1024)); - let sessions = SessionPool::new( - config.clone(), - upstream_sender, - buffer_pool.clone(), - shutdown_rx.clone(), - ); - #[allow(clippy::type_complexity)] const SUBS: &[(&str, &[(&str, Vec)])] = &[ ( @@ -247,6 +309,8 @@ impl Proxy { *lock = Some(check.clone()); } + let id = config.id.load(); + std::thread::Builder::new() .name("proxy-subscription".into()) .spawn({ @@ -291,14 +355,25 @@ impl Proxy { .expect("failed to spawn proxy-subscription thread"); } - let worker_notifications = packet_router::spawn_receivers( + let num_workers = self.num_workers.get(); + let buffer_pool = Arc::new(crate::pool::BufferPool::new(num_workers, 2 * 1024)); + + let mut worker_sends = Vec::with_capacity(num_workers); + let mut session_sends = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let psends = PendingSends::new(15)?; + session_sends.push(psends.0.clone()); + worker_sends.push(psends); + } + + let sessions = SessionPool::new(config.clone(), session_sends, buffer_pool.clone()); + + packet_router::spawn_receivers( config.clone(), self.socket, - num_workers, + worker_sends, &sessions, - upstream_receiver, buffer_pool, - shutdown_rx.clone(), ) .await?; @@ -310,10 +385,6 @@ impl Proxy { crate::net::phoenix::Phoenix::new(crate::codec::qcmp::QcmpMeasurement::new()?), )?; - for notification in worker_notifications { - let _ = notification.recv(); - } - tracing::info!("Quilkin is ready"); if let Some(initialized) = initialized { let _ = initialized.send(()); @@ -324,17 +395,7 @@ impl Proxy { .await .map_err(|error| eyre::eyre!(error))?; - if *shutdown_rx.borrow() == crate::ShutdownKind::Normal { - tracing::info!(sessions=%sessions.sessions().len(), "waiting for active sessions to expire"); - - let interval = std::time::Duration::from_millis(100); - - while sessions.sessions().is_not_empty() { - tokio::time::sleep(interval).await; - tracing::debug!(sessions=%sessions.sessions().len(), "sessions still active"); - } - tracing::info!("all sessions expired"); - } + sessions.shutdown(*shutdown_rx.borrow() == crate::ShutdownKind::Normal); Ok(()) } diff --git a/src/components/proxy/io_uring_shared.rs b/src/components/proxy/io_uring_shared.rs index 3420eaa6ba..40c1c8f68f 100644 --- a/src/components/proxy/io_uring_shared.rs +++ b/src/components/proxy/io_uring_shared.rs @@ -21,10 +21,9 @@ //! enough that it doesn't make sense to share the same code use crate::{ - components::proxy::{self, PipelineError}, + components::proxy::{self, PendingSends, PipelineError, SendPacket}, metrics, - net::maxmind_db::MetricsIpNetEntry, - pool::{FrozenPoolBuffer, PoolBuffer}, + pool::PoolBuffer, time::UtcTimestamp, }; use io_uring::{squeue::Entry, types::Fd}; @@ -38,30 +37,11 @@ use std::{ /// /// We use eventfd to signal to io uring loops from async tasks, it is essentially /// the equivalent of a signalling 64 bit cross-process atomic -pub(crate) struct EventFd { +pub struct EventFd { fd: std::os::fd::OwnedFd, val: u64, } -#[derive(Clone)] -pub(crate) struct EventFdWriter { - fd: i32, -} - -impl EventFdWriter { - #[inline] - pub(crate) fn write(&self, val: u64) { - // SAFETY: we have a valid descriptor, and most of the errors that apply - // to the general write call that eventfd_write wraps are not applicable - // - // Note that while the docs state eventfd_write is glibc, it is implemented - // on musl as well, but really is just a write with 8 bytes - unsafe { - libc::eventfd_write(self.fd, val); - } - } -} - impl EventFd { #[inline] pub(crate) fn new() -> std::io::Result { @@ -102,48 +82,30 @@ impl EventFd { } } -struct RecvPacket { - /// The buffer filled with data during recv_from - buffer: PoolBuffer, - /// The IP of the sender - source: std::net::SocketAddr, -} - -struct SendPacket { - /// The destination address of the packet - destination: SockAddr, - /// The packet data being sent - buffer: FrozenPoolBuffer, - /// The asn info for the sender, used for metrics - asn_info: Option, -} - -/// A simple double buffer for queing packets that need to be sent, each enqueue -/// notifies an eventfd that sends are available #[derive(Clone)] -struct PendingSends { - packets: Arc>>, - notify: EventFdWriter, +pub(crate) struct EventFdWriter { + fd: i32, } -impl PendingSends { - pub fn new(notify: EventFdWriter) -> Self { - Self { - packets: Default::default(), - notify, - } - } - +impl EventFdWriter { #[inline] - pub fn push(&self, packet: SendPacket) { - self.packets.lock().push(packet); - self.notify.write(1); + pub(crate) fn write(&self, val: u64) { + // SAFETY: we have a valid descriptor, and most of the errors that apply + // to the general write call that eventfd_write wraps are not applicable + // + // Note that while the docs state eventfd_write is glibc, it is implemented + // on musl as well, but really is just a write with 8 bytes + unsafe { + libc::eventfd_write(self.fd, val); + } } +} - #[inline] - pub fn swap(&self, swap: Vec) -> Vec { - std::mem::replace(&mut self.packets.lock(), swap) - } +struct RecvPacket { + /// The buffer filled with data during recv_from + buffer: PoolBuffer, + /// The IP of the sender + source: std::net::SocketAddr, } enum LoopPacketInner { @@ -192,8 +154,8 @@ impl LoopPacket { // For sends, the length of the buffer is the actual number of initialized bytes, // and note that iov_base is a *mut even though for sends the buffer is not actually // mutated - self.io_vec.iov_base = send.buffer.as_ptr() as *mut u8 as *mut _; - self.io_vec.iov_len = send.buffer.len(); + self.io_vec.iov_base = send.data.as_ptr() as *mut u8 as *mut _; + self.io_vec.iov_len = send.data.len(); // SAFETY: both pointers are valid at this point, with the same size unsafe { @@ -262,62 +224,8 @@ pub enum PacketProcessorCtx { }, } -pub enum PacketReceiver { - Router(crate::components::proxy::sessions::DownstreamReceiver), - SessionPool(tokio::sync::mpsc::Receiver), -} - -/// Spawns worker tasks -/// -/// One task processes received packets, notifying the io-uring loop when a -/// packet finishes processing, the other receives packets to send and notifies -/// the io-uring loop when there are 1 or more packets available to be sent -fn spawn_workers( - rt: &tokio::runtime::Runtime, - receiver: PacketReceiver, - pending_sends: PendingSends, - mut shutdown_rx: crate::ShutdownRx, - shutdown_event: EventFdWriter, -) { - // Spawn a task that just monitors the shutdown receiver to notify the io-uring loop to exit - rt.spawn(async move { - // The result is uninteresting, either a shutdown has been signalled, or all senders have been dropped - // which equates to the same thing - let _ = shutdown_rx.changed().await; - shutdown_event.write(1); - }); - - match receiver { - PacketReceiver::Router(upstream_receiver) => { - rt.spawn(async move { - while let Ok(packet) = upstream_receiver.recv().await { - let packet = SendPacket { - destination: packet.destination.into(), - buffer: packet.data, - asn_info: packet.asn_info, - }; - pending_sends.push(packet); - } - }); - } - PacketReceiver::SessionPool(mut downstream_receiver) => { - rt.spawn(async move { - while let Some(packet) = downstream_receiver.recv().await { - let packet = SendPacket { - destination: packet.destination.into(), - buffer: packet.data, - asn_info: packet.asn_info, - }; - pending_sends.push(packet); - } - }); - } - } -} - fn process_packet( ctx: &mut PacketProcessorCtx, - packet_processed_event: &EventFdWriter, packet: RecvPacket, last_received_at: &mut Option, ) { @@ -349,8 +257,6 @@ fn process_packet( error_acc, destinations, ); - - packet_processed_event.write(1); } PacketProcessorCtx::SessionPool { pool, port, .. } => { let mut last_received_at = None; @@ -361,8 +267,6 @@ fn process_packet( *port, &mut last_received_at, ); - - packet_processed_event.write(1); } } } @@ -377,10 +281,8 @@ enum Token { Recv { key: usize }, /// Packet sent Send { key: usize }, - /// One or more packets are ready to be sent + /// One or more packets are ready to be sent OR shutdown of the loop is requested PendingsSends, - /// Loop shutdown requested - Shutdown, } struct LoopCtx<'uring> { @@ -508,7 +410,6 @@ impl<'uring> LoopCtx<'uring> { } pub struct IoUringLoop { - runtime: tokio::runtime::Runtime, socket: crate::net::DualStackLocalSocket, concurrent_sends: usize, } @@ -518,14 +419,7 @@ impl IoUringLoop { concurrent_sends: u16, socket: crate::net::DualStackLocalSocket, ) -> Result { - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .max_blocking_threads(1) - .worker_threads(3) - .build()?; - Ok(Self { - runtime, concurrent_sends: concurrent_sends as _, socket, }) @@ -535,42 +429,29 @@ impl IoUringLoop { self, thread_name: String, mut ctx: PacketProcessorCtx, - receiver: PacketReceiver, + pending_sends: (PendingSends, EventFd), buffer_pool: Arc, - shutdown: crate::ShutdownRx, - ) -> Result, PipelineError> { + ) -> Result<(), PipelineError> { let dispatcher = tracing::dispatcher::get_default(|d| d.clone()); - let (tx, rx) = std::sync::mpsc::channel(); - let rt = self.runtime; let socket = self.socket; let concurrent_sends = self.concurrent_sends; let mut ring = io_uring::IoUring::new((concurrent_sends + 3) as _)?; - // Used to notify the uring loop when 1 or more packets have been queued - // up to be sent to a remote address - let mut pending_sends_event = EventFd::new()?; - // Used to notify the uring when a received packet has finished - // processing and we can perform another recv, as we (currently) only - // ever process a single packet at a time - let process_event = EventFd::new()?; - // Used to notify the uring loop to shutdown - let mut shutdown_event = EventFd::new()?; + let mut pending_sends_event = pending_sends.1; + let pending_sends = pending_sends.0; std::thread::Builder::new() .name(thread_name) .spawn(move || { let _guard = tracing::dispatcher::set_default(&dispatcher); - let tokens = slab::Slab::with_capacity(concurrent_sends + 1 + 1 + 1); + let tokens = slab::Slab::with_capacity(concurrent_sends + 1 + 1); let loop_packets = slab::Slab::with_capacity(concurrent_sends + 1); - // Create an eventfd to notify the uring thread (this one) of - // pending sends - let pending_sends = PendingSends::new(pending_sends_event.writer()); // Just double buffer the pending writes for simplicity - let mut double_pending_sends = Vec::new(); + let mut double_pending_sends = Vec::with_capacity(pending_sends.capacity()); // When sending packets, this is the direction used when updating metrics let send_dir = if matches!(ctx, PacketProcessorCtx::Router { .. }) { @@ -579,16 +460,6 @@ impl IoUringLoop { metrics::READ }; - // Spawn the worker tasks that process in an async context unlike - // our io-uring loop below - spawn_workers( - &rt, - receiver, - pending_sends.clone(), - shutdown, - shutdown_event.writer(), - ); - let (submitter, sq, mut cq) = ring.split(); let mut loop_ctx = LoopCtx { @@ -602,16 +473,12 @@ impl IoUringLoop { loop_ctx.enqueue_recv(buffer_pool.clone().alloc()); loop_ctx .push_with_token(pending_sends_event.io_uring_entry(), Token::PendingsSends); - loop_ctx.push_with_token(shutdown_event.io_uring_entry(), Token::Shutdown); // Sync always needs to be called when entries have been pushed // onto the submission queue for the loop to actually function (ie, similar to await on futures) loop_ctx.sync(); - // Notify that we have set everything up - let _ = tx.send(()); let mut last_received_at = None; - let process_event_writer = process_event.writer(); // The core io uring loop 'io: loop { @@ -653,26 +520,26 @@ impl IoUringLoop { } let packet = packet.finalize_recv(ret as usize); - process_packet( - &mut ctx, - &process_event_writer, - packet, - &mut last_received_at, - ); + process_packet(&mut ctx, packet, &mut last_received_at); loop_ctx.enqueue_recv(buffer_pool.clone().alloc()); } Token::PendingsSends => { - double_pending_sends = pending_sends.swap(double_pending_sends); - loop_ctx.push_with_token( - pending_sends_event.io_uring_entry(), - Token::PendingsSends, - ); - - for pending in - double_pending_sends.drain(0..double_pending_sends.len()) - { - loop_ctx.enqueue_send(pending); + if pending_sends_event.val < 0xdeadbeef { + double_pending_sends = pending_sends.swap(double_pending_sends); + loop_ctx.push_with_token( + pending_sends_event.io_uring_entry(), + Token::PendingsSends, + ); + + for pending in + double_pending_sends.drain(0..double_pending_sends.len()) + { + loop_ctx.enqueue_send(pending); + } + } else { + tracing::info!("io-uring loop shutdown requested"); + break 'io; } } Token::Send { key } => { @@ -685,7 +552,7 @@ impl IoUringLoop { metrics::errors_total(send_dir, &source, &asn_info).inc(); metrics::packets_dropped_total(send_dir, &source, &asn_info) .inc(); - } else if ret as usize != packet.buffer.len() { + } else if ret as usize != packet.data.len() { metrics::packets_total(send_dir, &asn_info).inc(); metrics::errors_total( send_dir, @@ -698,10 +565,6 @@ impl IoUringLoop { metrics::bytes_total(send_dir, &asn_info).inc_by(ret as u64); } } - Token::Shutdown => { - tracing::info!("io-uring loop shutdown requested"); - break 'io; - } } } @@ -709,7 +572,7 @@ impl IoUringLoop { } })?; - Ok(rx) + Ok(()) } } diff --git a/src/components/proxy/packet_router.rs b/src/components/proxy/packet_router.rs index 0915cc0321..5d54c99e23 100644 --- a/src/components/proxy/packet_router.rs +++ b/src/components/proxy/packet_router.rs @@ -14,10 +14,7 @@ * limitations under the License. */ -use super::{ - sessions::{DownstreamReceiver, SessionKey}, - PipelineError, SessionPool, -}; +use super::{sessions::SessionKey, PipelineError, SessionPool}; use crate::{ filters::{Filter as _, ReadContext}, metrics, @@ -44,8 +41,6 @@ pub(crate) struct DownstreamPacket { pub struct DownstreamReceiveWorkerConfig { /// ID of the worker. pub worker_id: usize, - /// Socket with reused port from which the worker receives packets. - pub upstream_receiver: DownstreamReceiver, pub port: u16, pub config: Arc, pub sessions: Arc, @@ -137,21 +132,17 @@ impl DownstreamReceiveWorkerConfig { pub async fn spawn_receivers( config: Arc, socket: socket2::Socket, - num_workers: usize, + worker_sends: Vec<(super::PendingSends, super::PacketSendReceiver)>, sessions: &Arc, - upstream_receiver: DownstreamReceiver, buffer_pool: Arc, - shutdown: crate::ShutdownRx, -) -> crate::Result>> { +) -> crate::Result<()> { let (error_sender, mut error_receiver) = mpsc::channel(128); let port = crate::net::socket_port(&socket); - let mut worker_notifications = Vec::with_capacity(num_workers); - for worker_id in 0..num_workers { + for (worker_id, ws) in worker_sends.into_iter().enumerate() { let worker = DownstreamReceiveWorkerConfig { worker_id, - upstream_receiver: upstream_receiver.clone(), port, config: config.clone(), sessions: sessions.clone(), @@ -159,7 +150,7 @@ pub async fn spawn_receivers( buffer_pool: buffer_pool.clone(), }; - worker_notifications.push(worker.spawn(shutdown.clone()).await?); + worker.spawn(ws).await?; } drop(error_sender); @@ -197,5 +188,5 @@ pub async fn spawn_receivers( } }); - Ok(worker_notifications) + Ok(()) } diff --git a/src/components/proxy/packet_router/io_uring.rs b/src/components/proxy/packet_router/io_uring.rs index 2b535f41fd..a3bc554506 100644 --- a/src/components/proxy/packet_router/io_uring.rs +++ b/src/components/proxy/packet_router/io_uring.rs @@ -14,18 +14,18 @@ * limitations under the License. */ +use crate::components::proxy; use eyre::Context as _; impl super::DownstreamReceiveWorkerConfig { pub async fn spawn( self, - shutdown: crate::ShutdownRx, - ) -> eyre::Result> { + pending_sends: (proxy::PendingSends, proxy::PacketSendReceiver), + ) -> eyre::Result<()> { use crate::components::proxy::io_uring_shared; let Self { worker_id, - upstream_receiver, port, config, sessions, @@ -47,9 +47,8 @@ impl super::DownstreamReceiveWorkerConfig { worker_id, destinations: Vec::with_capacity(1), }, - io_uring_shared::PacketReceiver::Router(upstream_receiver), + pending_sends, buffer_pool, - shutdown, ) .context("failed to spawn io-uring loop") } diff --git a/src/components/proxy/packet_router/reference.rs b/src/components/proxy/packet_router/reference.rs index 30b5360964..694d5eae66 100644 --- a/src/components/proxy/packet_router/reference.rs +++ b/src/components/proxy/packet_router/reference.rs @@ -16,14 +16,15 @@ //! The reference implementation is used for non-Linux targets +use crate::components::proxy; + impl super::DownstreamReceiveWorkerConfig { pub async fn spawn( self, - _shutdown: crate::ShutdownRx, - ) -> eyre::Result> { + pending_sends: (proxy::PendingSends, proxy::PacketSendReceiver), + ) -> eyre::Result<()> { let Self { worker_id, - upstream_receiver, port, config, sessions, @@ -31,10 +32,9 @@ impl super::DownstreamReceiveWorkerConfig { buffer_pool, } = self; - let (tx, rx) = std::sync::mpsc::channel(); - let thread_span = uring_span!(tracing::debug_span!("receiver", id = worker_id).or_current()); + let (tx, mut rx) = tokio::sync::oneshot::channel(); let worker = uring_spawn!(thread_span, async move { let mut last_received_at = None; @@ -46,56 +46,49 @@ impl super::DownstreamReceiveWorkerConfig { let send_socket = socket.clone(); let inner_task = async move { - let _ = tx.send(()); + let (pending_sends, mut sends_rx) = pending_sends; + let mut sends_double_buffer = Vec::with_capacity(pending_sends.capacity()); - loop { - tokio::select! { - result = upstream_receiver.recv() => { - match result { - Err(error) => { - tracing::trace!(%error, "error receiving packet"); - crate::metrics::errors_total( - crate::metrics::WRITE, - &error.to_string(), - &crate::metrics::EMPTY, - ) - .inc(); - } - Ok(crate::components::proxy::SendPacket { - destination, - asn_info, - data, - }) => { - let (result, _) = send_socket.send_to(data, destination).await; - let asn_info = asn_info.as_ref().into(); - match result { - Ok(size) => { - crate::metrics::packets_total(crate::metrics::WRITE, &asn_info) - .inc(); - crate::metrics::bytes_total(crate::metrics::WRITE, &asn_info) - .inc_by(size as u64); - } - Err(error) => { - let source = error.to_string(); - crate::metrics::errors_total( - crate::metrics::WRITE, - &source, - &asn_info, - ) - .inc(); - crate::metrics::packets_dropped_total( - crate::metrics::WRITE, - &source, - &asn_info, - ) - .inc(); - } - } - } + while sends_rx.changed().await.is_ok() { + if !*sends_rx.borrow() { + tracing::trace!("io loop shutdown requested"); + break; + } + + sends_double_buffer = pending_sends.swap(sends_double_buffer); + + for packet in sends_double_buffer.drain(..sends_double_buffer.len()) { + let (result, _) = send_socket + .send_to(packet.data, packet.destination.as_socket().unwrap()) + .await; + let asn_info = packet.asn_info.as_ref().into(); + match result { + Ok(size) => { + crate::metrics::packets_total(crate::metrics::WRITE, &asn_info) + .inc(); + crate::metrics::bytes_total(crate::metrics::WRITE, &asn_info) + .inc_by(size as u64); + } + Err(error) => { + let source = error.to_string(); + crate::metrics::errors_total( + crate::metrics::WRITE, + &source, + &asn_info, + ) + .inc(); + crate::metrics::packets_dropped_total( + crate::metrics::WRITE, + &source, + &asn_info, + ) + .inc(); } } } } + + let _ = tx.send(()); }; cfg_if::cfg_if! { @@ -115,34 +108,42 @@ impl super::DownstreamReceiveWorkerConfig { // packet, which is the maximum value of 16 a bit integer. let buffer = buffer_pool.clone().alloc(); - let (result, contents) = socket.recv_from(buffer).await; - let received_at = crate::time::UtcTimestamp::now(); + tokio::select! { + received = socket.recv_from(buffer) => { + let received_at = crate::time::UtcTimestamp::now(); + let (result, buffer) = received; - match result { - Ok((_size, mut source)) => { - source.set_ip(source.ip().to_canonical()); - let packet = super::DownstreamPacket { contents, source }; + match result { + Ok((_size, mut source)) => { + source.set_ip(source.ip().to_canonical()); + let packet = super::DownstreamPacket { contents: buffer, source }; - if let Some(last_received_at) = last_received_at { - crate::metrics::packet_jitter( - crate::metrics::READ, - &crate::metrics::EMPTY, - ) - .set((received_at - last_received_at).nanos()); + if let Some(last_received_at) = last_received_at { + crate::metrics::packet_jitter( + crate::metrics::READ, + &crate::metrics::EMPTY, + ) + .set((received_at - last_received_at).nanos()); + } + last_received_at = Some(received_at); + + Self::process_task( + packet, + worker_id, + &config, + &sessions, + &mut error_acc, + &mut destinations, + ); + } + Err(error) => { + tracing::error!(%error, "error receiving packet"); + return; + } } - last_received_at = Some(received_at); - - Self::process_task( - packet, - worker_id, - &config, - &sessions, - &mut error_acc, - &mut destinations, - ); } - Err(error) => { - tracing::error!(%error, "error receiving packet"); + _ = &mut rx => { + tracing::debug!("Closing downstream socket loop, shutdown requested"); return; } } @@ -151,6 +152,6 @@ impl super::DownstreamReceiveWorkerConfig { use eyre::WrapErr as _; worker.recv().context("failed to spawn receiver task")?; - Ok(rx) + Ok(()) } } diff --git a/src/components/proxy/sessions.rs b/src/components/proxy/sessions.rs index 0fcd6a2563..0ee3a57ddd 100644 --- a/src/components/proxy/sessions.rs +++ b/src/components/proxy/sessions.rs @@ -18,38 +18,38 @@ use std::{ collections::{HashMap, HashSet}, fmt, net::SocketAddr, - sync::Arc, + sync::{atomic, Arc}, time::Duration, }; -use tokio::{sync::mpsc, time::Instant}; +use tokio::time::Instant; use crate::{ - components::proxy::{PipelineError, SendPacket}, + components::proxy::SendPacket, config::Config, filters::Filter, metrics, net::maxmind_db::{IpNetEntry, MetricsIpNetEntry}, pool::{BufferPool, FrozenPoolBuffer, PoolBuffer}, time::UtcTimestamp, - Loggable, ShutdownRx, + Loggable, }; use parking_lot::RwLock; +use super::PendingSends; + pub(crate) mod inner_metrics; pub type SessionMap = crate::collections::ttl::TtlMap; -#[cfg(target_os = "linux")] -mod io_uring; -#[cfg(not(target_os = "linux"))] -mod reference; - -type UpstreamSender = mpsc::Sender; - -type DownstreamSender = async_channel::Sender; -pub type DownstreamReceiver = async_channel::Receiver; +cfg_if::cfg_if! { + if #[cfg(target_os = "linux")] { + mod io_uring; + } else { + mod reference; + } +} #[derive(PartialEq, Eq, Hash)] pub enum SessionError { @@ -90,13 +90,13 @@ impl fmt::Debug for SessionError { /// Traffic from different gameservers is then demuxed using their address to /// send back to the original client. pub struct SessionPool { - ports_to_sockets: RwLock>, + ports_to_sockets: RwLock>, storage: Arc>, session_map: SessionMap, - downstream_sender: DownstreamSender, buffer_pool: Arc, - shutdown_rx: ShutdownRx, config: Arc, + downstream_sends: Vec, + downstream_index: atomic::AtomicUsize, } /// The wrapper struct responsible for holding all of the socket related mappings. @@ -114,21 +114,20 @@ impl SessionPool { /// to release their sockets back to the parent. pub fn new( config: Arc, - downstream_sender: DownstreamSender, + downstream_sends: Vec, buffer_pool: Arc, - shutdown_rx: ShutdownRx, ) -> Arc { const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); Arc::new(Self { config, - downstream_sender, - shutdown_rx, ports_to_sockets: <_>::default(), storage: <_>::default(), session_map: SessionMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL), buffer_pool, + downstream_sends, + downstream_index: atomic::AtomicUsize::new(0), }) } @@ -136,7 +135,7 @@ impl SessionPool { fn create_new_session_from_new_socket<'pool>( self: &'pool Arc, key: SessionKey, - ) -> Result<(Option, UpstreamSender), super::PipelineError> { + ) -> Result<(Option, PendingSends), super::PipelineError> { tracing::trace!(source=%key.source, dest=%key.dest, "creating new socket for session"); let raw_socket = crate::net::raw_socket_with_reuse(0)?; let port = raw_socket @@ -144,19 +143,15 @@ impl SessionPool { .as_socket() .ok_or(SessionError::SocketAddressUnavailable)? .port(); - let (downstream_sender, downstream_receiver) = mpsc::channel::(15); - let initialised = self - .clone() - .spawn_session(raw_socket, port, downstream_receiver)?; - initialised - .recv() - .map_err(|_err| PipelineError::ChannelClosed)?; + let (pending_sends, srecv) = super::PendingSends::new(15)?; + self.clone() + .spawn_session(raw_socket, port, (pending_sends.clone(), srecv))?; self.ports_to_sockets .write() - .insert(port, downstream_sender.clone()); - self.create_session_from_existing_socket(key, downstream_sender, port) + .insert(port, pending_sends.clone()); + self.create_session_from_existing_socket(key, pending_sends, port) } pub(crate) fn process_received_upstream_packet( @@ -192,7 +187,6 @@ impl SessionPool { let _timer = metrics::processing_time(metrics::WRITE).start_timer(); Self::process_recv_packet( self.config.clone(), - &self.downstream_sender, recv_addr, downstream_addr, asn_info, @@ -200,13 +194,25 @@ impl SessionPool { ) }; - if let Err((asn_info, error)) = result { - error.log(); - let label = format!("proxy::Session::process_recv_packet: {error}"); - let asn_metric_info = asn_info.as_ref().into(); + match result { + Ok(packet) => { + let index = self + .downstream_index + .fetch_add(1, atomic::Ordering::Relaxed) + % self.downstream_sends.len(); + // SAFETY: we've ensured it's within bounds via the % + unsafe { + self.downstream_sends.get_unchecked(index).push(packet); + } + } + Err((asn_info, error)) => { + error.log(); + let label = format!("proxy::Session::process_recv_packet: {error}"); + let asn_metric_info = asn_info.as_ref().into(); - metrics::packets_dropped_total(metrics::WRITE, &label, &asn_metric_info).inc(); - metrics::errors_total(metrics::WRITE, &label, &asn_metric_info).inc(); + metrics::packets_dropped_total(metrics::WRITE, &label, &asn_metric_info).inc(); + metrics::errors_total(metrics::WRITE, &label, &asn_metric_info).inc(); + } } } @@ -217,14 +223,14 @@ impl SessionPool { pub fn get<'pool>( self: &'pool Arc, key @ SessionKey { dest, .. }: SessionKey, - ) -> Result<(Option, UpstreamSender), super::PipelineError> { + ) -> Result<(Option, PendingSends), super::PipelineError> { tracing::trace!(source=%key.source, dest=%key.dest, "SessionPool::get"); // If we already have a session for the key pairing, return that session. if let Some(entry) = self.session_map.get(&key) { tracing::trace!("returning existing session"); return Ok(( entry.asn_info.as_ref().map(MetricsIpNetEntry::from), - entry.upstream_sender.clone(), + entry.pending_sends.clone(), )); } @@ -278,9 +284,9 @@ impl SessionPool { fn create_session_from_existing_socket<'session>( self: &'session Arc, key: SessionKey, - upstream_sender: UpstreamSender, + pending_sends: PendingSends, socket_port: u16, - ) -> Result<(Option, UpstreamSender), super::PipelineError> { + ) -> Result<(Option, PendingSends), super::PipelineError> { tracing::trace!(source=%key.source, dest=%key.dest, "reusing socket for session"); let asn_info = { let mut storage = self.storage.write(); @@ -313,7 +319,7 @@ impl SessionPool { let session = Session::new( key, - upstream_sender.clone(), + pending_sends.clone(), socket_port, self.clone(), asn_info, @@ -321,18 +327,17 @@ impl SessionPool { tracing::trace!("inserting session into map"); self.session_map.insert(key, session); tracing::trace!("session inserted"); - Ok((asn_metrics_info, upstream_sender)) + Ok((asn_metrics_info, pending_sends)) } /// process_recv_packet processes a packet that is received by this session. fn process_recv_packet( config: Arc, - downstream_sender: &DownstreamSender, source: SocketAddr, dest: SocketAddr, asn_info: Option, packet: PoolBuffer, - ) -> Result<(), (Option, Error)> { + ) -> Result, Error)> { tracing::trace!(%source, %dest, length = packet.len(), "received packet from upstream"); let mut context = crate::filters::WriteContext::new(source.into(), dest.into(), packet); @@ -341,21 +346,11 @@ impl SessionPool { return Err((asn_info, err.into())); } - let packet = context.contents.freeze(); - tracing::trace!(%source, %dest, length = packet.len(), "sending packet downstream"); - downstream_sender - .try_send(SendPacket { - data: packet, - destination: dest, - asn_info, - }) - .map_err(|error| match error { - async_channel::TrySendError::Closed(packet) => { - (packet.asn_info, Error::ChannelClosed) - } - async_channel::TrySendError::Full(packet) => (packet.asn_info, Error::ChannelFull), - })?; - Ok(()) + Ok(SendPacket { + data: context.contents.freeze(), + destination: dest.into(), + asn_info, + }) } /// Returns a map of active sessions. @@ -364,25 +359,30 @@ impl SessionPool { } /// Sends packet data to the appropiate session based on its `key`. + #[inline] pub fn send( self: &Arc, key: SessionKey, packet: FrozenPoolBuffer, ) -> Result<(), super::PipelineError> { - use tokio::sync::mpsc::error::TrySendError; + self.send_inner(key, packet)?; + Ok(()) + } + #[inline] + fn send_inner( + self: &Arc, + key: SessionKey, + packet: FrozenPoolBuffer, + ) -> Result { let (asn_info, sender) = self.get(key)?; - sender - .try_send(crate::components::proxy::SendPacket { - data: packet, - asn_info, - destination: key.dest, - }) - .map_err(|error| match error { - TrySendError::Closed(_) => super::PipelineError::ChannelClosed, - TrySendError::Full(_) => super::PipelineError::ChannelFull, - }) + sender.push(SendPacket { + destination: key.dest.into(), + data: packet, + asn_info, + }); + Ok(sender) } /// Returns whether the pool contains any sockets allocated to a destination. @@ -405,7 +405,7 @@ impl SessionPool { } /// Handles the logic of releasing a socket back into the pool. - async fn release_socket( + fn release_socket( self: Arc, SessionKey { ref source, @@ -440,11 +440,28 @@ impl SessionPool { storage.destination_to_sources.remove(&(*dest, port)); tracing::trace!("socket released"); } + + /// Closes all active sessions, and all downstream listeners + pub(crate) fn shutdown(self: Arc, wait: bool) { + // Disable downstream listeners first so sessions aren't spawned while + // we are trying to reap the active sessions + for downstream_listener in &self.downstream_sends { + downstream_listener.shutdown_receiver(); + } + + if wait && !self.session_map.is_empty() { + tracing::info!(sessions=%self.session_map.len(), "waiting for active sessions to expire"); + self.session_map.clear(); + } + } } impl Drop for SessionPool { fn drop(&mut self) { - drop(std::mem::take(&mut self.session_map)); + let map = std::mem::take(&mut self.session_map); + std::thread::spawn(move || { + drop(map); + }); } } @@ -456,8 +473,8 @@ pub struct Session { key: SessionKey, /// The socket port of the session. socket_port: u16, - /// The socket of the session. - upstream_sender: UpstreamSender, + /// The queue of packets being sent to the upstream (server) + pending_sends: PendingSends, /// The GeoIP information of the source. asn_info: Option, /// The socket pool of the session. @@ -467,14 +484,14 @@ pub struct Session { impl Session { pub fn new( key: SessionKey, - upstream_sender: UpstreamSender, + pending_sends: PendingSends, socket_port: u16, pool: Arc, asn_info: Option, ) -> Self { let s = Self { key, - upstream_sender, + pending_sends, pool, socket_port, asn_info, @@ -503,17 +520,18 @@ impl Session { inner_metrics::active_sessions(self.asn_info.as_ref()) } - fn async_drop(&mut self) -> impl std::future::Future { + fn release(&mut self) { self.active_session_metric().dec(); inner_metrics::duration_secs().observe(self.created_at.elapsed().as_secs() as f64); tracing::debug!(source = %self.key.source, dest_address = %self.key.dest, "Session closed"); - SessionPool::release_socket(self.pool.clone(), self.key, self.socket_port) + self.pending_sends.shutdown_receiver(); + SessionPool::release_socket(self.pool.clone(), self.key, self.socket_port); } } impl Drop for Session { fn drop(&mut self) { - tokio::spawn(self.async_drop()); + self.release() } } @@ -532,10 +550,6 @@ impl From<(SocketAddr, SocketAddr)> for SessionKey { #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("downstream channel closed")] - ChannelClosed, - #[error("downstream channel full")] - ChannelFull, #[error("filter {0}")] Filter(#[from] crate::filters::FilterError), } @@ -550,30 +564,24 @@ impl Loggable for Error { #[cfg(test)] mod tests { use super::*; - use crate::{ - test::{alloc_buffer, available_addr, AddressType, TestHelper}, - ShutdownTx, - }; + use crate::test::{alloc_buffer, available_addr, AddressType, TestHelper}; use std::sync::Arc; - async fn new_pool() -> (Arc, ShutdownTx, DownstreamReceiver) { - let (tx, rx) = crate::make_shutdown_channel(crate::ShutdownKind::Testing); - let (sender, receiver) = async_channel::unbounded(); + async fn new_pool() -> (Arc, PendingSends) { + let (pending_sends, _srecv) = PendingSends::new(1).unwrap(); ( SessionPool::new( Arc::new(Config::default_agent()), - sender, + vec![pending_sends.clone()], Arc::new(BufferPool::default()), - rx, ), - tx, - receiver, + pending_sends, ) } #[tokio::test] async fn insert_and_release_single_socket() { - let (pool, _sender, _receiver) = new_pool().await; + let (pool, _receiver) = new_pool().await; let key = ( (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), @@ -589,7 +597,7 @@ mod tests { #[tokio::test] async fn insert_and_release_multiple_sockets() { - let (pool, _sender, _receiver) = new_pool().await; + let (pool, _receiver) = new_pool().await; let key1 = ( (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), @@ -614,7 +622,7 @@ mod tests { #[tokio::test] async fn same_address_uses_different_sockets() { - let (pool, _sender, _receiver) = new_pool().await; + let (pool, _receiver) = new_pool().await; let key1 = ( (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), @@ -639,7 +647,7 @@ mod tests { #[tokio::test] async fn different_addresses_uses_same_socket() { - let (pool, _sender, _receiver) = new_pool().await; + let (pool, _receiver) = new_pool().await; let key1 = ( (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), @@ -662,7 +670,7 @@ mod tests { #[tokio::test] async fn spawn_safe_same_destination() { - let (pool, _sender, _receiver) = new_pool().await; + let (pool, _receiver) = new_pool().await; let key1 = ( (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), @@ -687,7 +695,7 @@ mod tests { #[tokio::test] async fn spawn_safe_different_destination() { - let (pool, _sender, _receiver) = new_pool().await; + let (pool, _receiver) = new_pool().await; let key1 = ( (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), @@ -721,18 +729,14 @@ mod tests { let socket = tokio::net::UdpSocket::bind(source).await.unwrap(); let mut source = socket.local_addr().unwrap(); crate::test::map_addr_to_localhost(&mut source); - let (pool, _sender, receiver) = new_pool().await; + let (pool, _pending_sends) = new_pool().await; let key: SessionKey = (source, dest).into(); let msg = b"helloworld"; - pool.send(key, alloc_buffer(msg).freeze()).unwrap(); - - let packet = tokio::time::timeout(std::time::Duration::from_secs(1), receiver.recv()) - .await - .unwrap() - .unwrap(); + let pending = pool.send_inner(key, alloc_buffer(msg).freeze()).unwrap(); + let pending = pending.swap(Vec::new()); - assert_eq!(msg, &*packet.data); + assert_eq!(msg, &*pending[0].data); } } diff --git a/src/components/proxy/sessions/io_uring.rs b/src/components/proxy/sessions/io_uring.rs index ce709f8e41..d345689f67 100644 --- a/src/components/proxy/sessions/io_uring.rs +++ b/src/components/proxy/sessions/io_uring.rs @@ -14,6 +14,7 @@ * limitations under the License. */ +use crate::components::proxy; use std::sync::Arc; static SESSION_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); @@ -23,9 +24,9 @@ impl super::SessionPool { self: Arc, raw_socket: socket2::Socket, port: u16, - downstream_receiver: tokio::sync::mpsc::Receiver, - ) -> Result, crate::components::proxy::PipelineError> { - use crate::components::proxy::io_uring_shared; + pending_sends: (proxy::PendingSends, proxy::io_uring_shared::EventFd), + ) -> Result<(), proxy::PipelineError> { + use proxy::io_uring_shared; let pool = self; let id = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -36,14 +37,12 @@ impl super::SessionPool { crate::net::DualStackLocalSocket::from_raw(raw_socket), )?; let buffer_pool = pool.buffer_pool.clone(); - let shutdown = pool.shutdown_rx.clone(); io_loop.spawn( format!("session-{id}"), io_uring_shared::PacketProcessorCtx::SessionPool { pool, port }, - io_uring_shared::PacketReceiver::SessionPool(downstream_receiver), + pending_sends, buffer_pool, - shutdown, ) } } diff --git a/src/components/proxy/sessions/reference.rs b/src/components/proxy/sessions/reference.rs index 067fee4d23..bad4d2ebbc 100644 --- a/src/components/proxy/sessions/reference.rs +++ b/src/components/proxy/sessions/reference.rs @@ -14,20 +14,21 @@ * limitations under the License. */ +use crate::components::proxy; + impl super::SessionPool { pub(super) fn spawn_session( self: std::sync::Arc, raw_socket: socket2::Socket, port: u16, - mut downstream_receiver: tokio::sync::mpsc::Receiver, - ) -> Result, crate::components::proxy::PipelineError> { + pending_sends: (proxy::PendingSends, proxy::PacketSendReceiver), + ) -> Result<(), proxy::PipelineError> { let pool = self; - let rx = uring_spawn!( + uring_spawn!( uring_span!(tracing::debug_span!("session pool")), async move { let mut last_received_at = None; - let mut shutdown_rx = pool.shutdown_rx.clone(); let socket = std::sync::Arc::new(crate::net::DualStackLocalSocket::from_raw(raw_socket)); @@ -35,54 +36,48 @@ impl super::SessionPool { let (tx, mut rx) = tokio::sync::oneshot::channel(); uring_inner_spawn!(async move { - loop { - match downstream_receiver.recv().await { - None => { - crate::metrics::errors_total( - crate::metrics::WRITE, - "downstream channel closed", - &crate::metrics::EMPTY, - ) - .inc(); - break; - } - Some(crate::components::proxy::SendPacket { - destination, - data, - asn_info, - }) => { - tracing::trace!(%destination, length = data.len(), "sending packet upstream"); - let (result, _) = socket2.send_to(data, destination).await; - let asn_info = asn_info.as_ref().into(); - match result { - Ok(size) => { - crate::metrics::packets_total( - crate::metrics::READ, - &asn_info, - ) + let (pending_sends, mut sends_rx) = pending_sends; + let mut sends_double_buffer = Vec::with_capacity(pending_sends.capacity()); + + while sends_rx.changed().await.is_ok() { + if !*sends_rx.borrow() { + tracing::trace!("io loop shutdown requested"); + break; + } + + sends_double_buffer = pending_sends.swap(sends_double_buffer); + + for packet in sends_double_buffer.drain(..sends_double_buffer.len()) { + let destination = packet.destination.as_socket().unwrap(); + tracing::trace!( + %destination, + length = packet.data.len(), + "sending packet upstream" + ); + let (result, _) = socket2.send_to(packet.data, destination).await; + let asn_info = packet.asn_info.as_ref().into(); + match result { + Ok(size) => { + crate::metrics::packets_total(crate::metrics::READ, &asn_info) .inc(); - crate::metrics::bytes_total( - crate::metrics::READ, - &asn_info, - ) + crate::metrics::bytes_total(crate::metrics::READ, &asn_info) .inc_by(size as u64); - } - Err(error) => { - tracing::trace!(%error, "sending packet upstream failed"); - let source = error.to_string(); - crate::metrics::errors_total( - crate::metrics::READ, - &source, - &asn_info, - ) - .inc(); - crate::metrics::packets_dropped_total( - crate::metrics::READ, - &source, - &asn_info, - ) - .inc(); - } + } + Err(error) => { + tracing::trace!(%error, "sending packet upstream failed"); + let source = error.to_string(); + crate::metrics::errors_total( + crate::metrics::READ, + &source, + &asn_info, + ) + .inc(); + crate::metrics::packets_dropped_total( + crate::metrics::READ, + &source, + &asn_info, + ) + .inc(); } } } @@ -104,10 +99,6 @@ impl super::SessionPool { Ok((_size, recv_addr)) => pool.process_received_upstream_packet(buf, recv_addr, port, &mut last_received_at), } } - _ = shutdown_rx.changed() => { - tracing::debug!("Closing upstream socket loop"); - return; - } _ = &mut rx => { tracing::debug!("Closing upstream socket loop, downstream closed"); return; @@ -117,6 +108,6 @@ impl super::SessionPool { } ); - Ok(rx) + Ok(()) } }