diff --git a/benches/throughput.rs b/benches/throughput.rs index 2cd558d259..066dc7ca2a 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -56,6 +56,9 @@ static THROUGHPUT_SERVER_INIT: Lazy<()> = Lazy::new(|| { static FEEDBACK_LOOP: Lazy<()> = Lazy::new(|| { std::thread::spawn(|| { let socket = UdpSocket::bind(FEEDBACK_LOOP_ADDR).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); loop { let mut packet = [0; MESSAGE_SIZE]; @@ -74,6 +77,9 @@ fn throughput_benchmark(c: &mut Criterion) { // Sleep to give the servers some time to warm-up. std::thread::sleep(std::time::Duration::from_millis(500)); let socket = UdpSocket::bind(BENCH_LOOP_ADDR).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); let mut packet = [0; MESSAGE_SIZE]; let mut group = c.benchmark_group("throughput"); @@ -125,6 +131,9 @@ fn write_feedback(addr: SocketAddr) -> mpsc::Sender> { let (write_tx, write_rx) = mpsc::channel::>(); std::thread::spawn(move || { let socket = UdpSocket::bind(addr).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); let mut packet = [0; MESSAGE_SIZE]; let (_, source) = socket.recv_from(&mut packet).unwrap(); while let Ok(packet) = write_rx.recv() { @@ -142,6 +151,9 @@ fn readwrite_benchmark(c: &mut Criterion) { let (read_tx, read_rx) = mpsc::channel::>(); std::thread::spawn(move || { let socket = UdpSocket::bind(READ_LOOP_ADDR).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); let mut packet = [0; MESSAGE_SIZE]; loop { let (length, _) = socket.recv_from(&mut packet).unwrap(); @@ -164,9 +176,12 @@ fn readwrite_benchmark(c: &mut Criterion) { Lazy::force(&WRITE_SERVER_INIT); // Sleep to give the servers some time to warm-up. - std::thread::sleep(std::time::Duration::from_millis(500)); + std::thread::sleep(std::time::Duration::from_millis(150)); let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); // prime the direct write connection socket.send_to(PACKETS[0], direct_write_addr).unwrap(); diff --git a/src/cli.rs b/src/cli.rs index ead9543ac2..4c23fb5dce 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -436,7 +436,7 @@ mod tests { assert_eq!( "hello", - timeout(Duration::from_secs(5), rx.recv()) + timeout(Duration::from_millis(500), rx.recv()) .await .expect("should have received a packet") .unwrap() @@ -449,7 +449,7 @@ mod tests { let msg = b"hello\xFF\xFF\xFF"; socket.send_to(msg, &proxy_address).await.unwrap(); - let result = timeout(Duration::from_secs(3), rx.recv()).await; + let result = timeout(Duration::from_millis(500), rx.recv()).await; assert!(result.is_err(), "should not have received a packet"); tracing::info!(?token, "didn't receive bad packet"); } diff --git a/src/cli/proxy.rs b/src/cli/proxy.rs index bfefe31a8f..2e2e5cd084 100644 --- a/src/cli/proxy.rs +++ b/src/cli/proxy.rs @@ -120,11 +120,8 @@ impl Proxy { tracing::info!(port = self.port, proxy_id = &*id, "Starting"); let runtime_config = mode.unwrap_proxy(); - let sessions = SessionPool::new( - config.clone(), - DualStackLocalSocket::new(self.port)?, - shutdown_rx.clone(), - ); + let shared_socket = Arc::new(DualStackLocalSocket::new(self.port)?); + let sessions = SessionPool::new(config.clone(), shared_socket.clone(), shutdown_rx.clone()); let _xds_stream = if !self.management_server.is_empty() { { @@ -153,7 +150,7 @@ impl Proxy { None }; - self.run_recv_from(&config, sessions.clone())?; + self.run_recv_from(&config, &sessions, shared_socket)?; crate::protocol::spawn(self.qcmp_port).await?; tracing::info!("Quilkin is ready"); @@ -177,18 +174,29 @@ impl Proxy { /// This function also spawns the set of worker tasks responsible for consuming packets /// off the aforementioned queue and processing them through the filter chain and session /// pipeline. - fn run_recv_from(&self, config: &Arc, sessions: Arc) -> Result<()> { + fn run_recv_from( + &self, + config: &Arc, + sessions: &Arc, + shared_socket: Arc, + ) -> Result<()> { // The number of worker tasks to spawn. Each task gets a dedicated queue to // consume packets off. let num_workers = num_cpus::get(); // Contains config for each worker task. let mut workers = Vec::with_capacity(num_workers); - for worker_id in 0..num_workers { - let socket = Arc::new(DualStackLocalSocket::new(self.port)?); + workers.push(crate::proxy::DownstreamReceiveWorkerConfig { + worker_id: 0, + socket: shared_socket, + config: config.clone(), + sessions: sessions.clone(), + }); + + for worker_id in 1..num_workers { workers.push(crate::proxy::DownstreamReceiveWorkerConfig { worker_id, - socket: socket.clone(), + socket: Arc::new(DualStackLocalSocket::new(self.port)?), config: config.clone(), sessions: sessions.clone(), }) @@ -242,6 +250,7 @@ mod tests { t.run_server(config, proxy, None); + tracing::trace!(%local_addr, "sending hello"); let msg = "hello"; endpoint1 .socket @@ -250,14 +259,14 @@ mod tests { .unwrap(); assert_eq!( msg, - timeout(Duration::from_secs(1), endpoint1.packet_rx) + timeout(Duration::from_secs(100), endpoint1.packet_rx) .await .expect("should get a packet") .unwrap() ); assert_eq!( msg, - timeout(Duration::from_secs(1), endpoint2.packet_rx) + timeout(Duration::from_secs(100), endpoint2.packet_rx) .await .expect("should get a packet") .unwrap() @@ -373,12 +382,14 @@ mod tests { config: config.clone(), sessions: SessionPool::new( config, - DualStackLocalSocket::new( - crate::test_utils::available_addr(&AddressType::Random) - .await - .port(), - ) - .unwrap(), + Arc::new( + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + ), tokio::sync::watch::channel(()).1, ), } @@ -418,18 +429,23 @@ mod tests { ) }); - let sessions = SessionPool::new( - config.clone(), + let shared_socket = Arc::new( DualStackLocalSocket::new( crate::test_utils::available_addr(&AddressType::Random) .await .port(), ) .unwrap(), + ); + let sessions = SessionPool::new( + config.clone(), + shared_socket.clone(), tokio::sync::watch::channel(()).1, ); - proxy.run_recv_from(&config, sessions).unwrap(); + proxy + .run_recv_from(&config, &sessions, shared_socket) + .unwrap(); let socket = create_socket().await; socket.send_to(msg.as_bytes(), &local_addr).await.unwrap(); diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index a581d65a48..3194d944a9 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -17,12 +17,14 @@ use std::{ collections::{HashMap, HashSet}, net::SocketAddr, - pin::Pin, sync::Arc, time::Duration, }; -use tokio::{sync::watch, time::Instant}; +use tokio::{ + sync::{watch, RwLock}, + time::Instant, +}; use crate::{ config::Config, @@ -31,8 +33,6 @@ use crate::{ utils::{net::DualStackLocalSocket, Loggable}, }; -use parking_lot::RwLock; - use dashmap::DashMap; pub(crate) mod metrics; @@ -52,7 +52,7 @@ type SessionRef<'pool> = #[derive(Debug)] pub struct SessionPool { ports_to_sockets: DashMap>, - storage: RwLock, + storage: Arc>, session_map: SessionMap, downstream_socket: Arc, shutdown_rx: watch::Receiver<()>, @@ -74,7 +74,7 @@ impl SessionPool { /// to release their sockets back to the parent. pub fn new( config: Arc, - downstream_socket: DualStackLocalSocket, + downstream_socket: Arc, shutdown_rx: watch::Receiver<()>, ) -> Arc { const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); @@ -82,7 +82,7 @@ impl SessionPool { Arc::new(Self { config, - downstream_socket: Arc::new(downstream_socket), + downstream_socket, shutdown_rx, ports_to_sockets: <_>::default(), storage: <_>::default(), @@ -120,7 +120,7 @@ impl SessionPool { crate::utils::net::to_canonical(&mut recv_addr); tracing::trace!(%recv_addr, %size, "received packet"); let (downstream_addr, asn_info): (SocketAddr, Option) = { - let storage = pool.storage.read(); + let storage = pool.storage.read().await; let Some(downstream_addr) = storage.destination_to_sources.get(&(recv_addr, port)) else { tracing::warn!(address=%recv_addr, "received traffic from a server that has no downstream"); continue; @@ -170,54 +170,39 @@ impl SessionPool { }); self.create_session_from_existing_socket(key, socket, port, asn_info) + .await } /// Returns a reference to an existing session mapped to `key`, otherwise /// creates a new session either from a fresh socket, or if there are sockets /// allocated that are not reserved by an existing destination, using the /// existing socket. - // This uses dynamic dispatch because we're using `parking_lot`, and we - // to prove that we're not holding a lock across an await point. We're - // using `parking_lot` because there's no async drop, so we can't lock - // on drop currently. - pub fn get<'pool>( + pub async fn get<'pool>( self: &'pool Arc, key @ SessionKey { dest, .. }: SessionKey, asn_info: Option, - ) -> Pin< - Box< - dyn std::future::Future, super::PipelineError>> - + Send - + 'pool, - >, - > { + ) -> Result, super::PipelineError> { // If we already have a session for the key pairing, return that session. if let Some(entry) = self.session_map.get(&key) { - return Box::pin(std::future::ready(Ok(entry))); + return Ok(entry); } // If there's a socket_set available, it means there are sockets // allocated to the address that we want to avoid. - let storage = self.storage.read(); + let storage = self.storage.read().await; let Some(socket_set) = storage.destination_to_sockets.get(&dest) else { drop(storage); return if self.ports_to_sockets.is_empty() { // Initial case where we have no allocated or reserved sockets. - Box::pin(self.create_new_session_from_new_socket(key, asn_info)) + self.create_new_session_from_new_socket(key, asn_info).await } else { // Where we have no allocated sockets for a destination, assign // the first available one. let entry = self.ports_to_sockets.iter().next().unwrap(); let port = *entry.key(); - Box::pin(std::future::ready( - self.create_session_from_existing_socket( - key, - entry.value().clone(), - port, - asn_info, - ), - )) + self.create_session_from_existing_socket(key, entry.value().clone(), port, asn_info) + .await }; }; @@ -229,33 +214,33 @@ impl SessionPool { drop(storage); self.storage .write() + .await .destination_to_sockets .get_mut(&dest) .unwrap() .insert(*entry.key()); - Box::pin(std::future::ready( - self.create_session_from_existing_socket( - key, - entry.value().clone(), - *entry.key(), - asn_info, - ), - )) + self.create_session_from_existing_socket( + key, + entry.value().clone(), + *entry.key(), + asn_info, + ) + .await } else { drop(storage); - Box::pin(self.create_new_session_from_new_socket(key, asn_info)) + self.create_new_session_from_new_socket(key, asn_info).await } } /// Using an existing socket, reserves the socket for a new session. - fn create_session_from_existing_socket<'session>( + async fn create_session_from_existing_socket<'session>( self: &'session Arc, key: SessionKey, upstream_socket: Arc, socket_port: u16, asn_info: Option, ) -> Result, super::PipelineError> { - let mut storage = self.storage.write(); + let mut storage = self.storage.write().await; storage .destination_to_sockets .entry(key.dest) @@ -276,12 +261,15 @@ impl SessionPool { .insert(key.source, asn_info.clone()); } - self.session_map.insert( - key, - Session::new(key, upstream_socket, socket_port, self.clone(), asn_info)?, - ); - - Ok(self.session_map.get(&key).unwrap()) + drop(storage); + let session = Session::new(key, upstream_socket, socket_port, self.clone(), asn_info)?; + Ok(match self.session_map.entry(key) { + crate::ttl_map::Entry::Occupied(mut entry) => { + entry.insert(session); + entry.into_ref() + } + crate::ttl_map::Entry::Vacant(v) => v.insert(session).downgrade(), + }) } /// process_recv_packet processes a packet that is received by this session. @@ -328,8 +316,8 @@ impl SessionPool { /// Returns whether the pool contains any sockets allocated to a destination. #[cfg(test)] - fn has_no_allocated_sockets(&self) -> bool { - let storage = self.storage.read(); + async fn has_no_allocated_sockets(&self) -> bool { + let storage = self.storage.read().await; let is_empty = storage.destination_to_sockets.is_empty(); // These should always be the same. debug_assert!(!(is_empty ^ storage.sockets_to_destination.is_empty())); @@ -338,21 +326,24 @@ impl SessionPool { /// Forces removal of session to make testing quicker. #[cfg(test)] - fn drop_session(&self, key: SessionKey, session: SessionRef) -> bool { + async fn drop_session(&self, key: SessionKey, session: SessionRef<'_>) -> bool { drop(session); - self.session_map.remove(key) + let is_removed = self.session_map.remove(key); + // Sleep because there's no async drop. + tokio::time::sleep(Duration::from_millis(100)).await; + is_removed } /// Handles the logic of releasing a socket back into the pool. - fn release_socket( - &self, + async fn release_socket( + self: Arc, SessionKey { ref source, ref dest, }: SessionKey, port: u16, ) { - let mut storage = self.storage.write(); + let mut storage = self.storage.write().await; let socket_set = storage.destination_to_sockets.get_mut(dest).unwrap(); assert!(socket_set.remove(&port)); @@ -403,8 +394,6 @@ impl Session { pool: Arc, asn_info: Option, ) -> Result { - tracing::debug!(source = %key.source, dest = %key.dest, "Session created"); - let s = Self { key, socket, @@ -428,24 +417,30 @@ impl Session { self::metrics::total_sessions().inc(); s.active_session_metric().inc(); + tracing::debug!(source = %key.source, dest = %key.dest, "Session created"); Ok(s) } pub async fn send(&self, packet: &[u8]) -> std::io::Result { + tracing::trace!(dest=%self.key.dest, "sending packet upstream"); self.socket.send_to(packet, self.key.dest).await } fn active_session_metric(&self) -> prometheus::IntGauge { metrics::active_sessions(self.asn_info.as_ref()) } -} -impl Drop for Session { - fn drop(&mut self) { + fn async_drop(&mut self) -> impl std::future::Future { self.active_session_metric().dec(); metrics::duration_secs().observe(self.created_at.elapsed().as_secs() as f64); tracing::debug!(source = %self.key.source, dest_address = %self.key.dest, "Session closed"); - self.pool.release_socket(self.key, self.socket_port); + SessionPool::release_socket(self.pool.clone(), self.key, self.socket_port) + } +} + +impl Drop for Session { + fn drop(&mut self) { + tokio::spawn(self.async_drop()); } } @@ -494,12 +489,14 @@ mod tests { ( SessionPool::new( Arc::new(config.into().unwrap_or_default()), - DualStackLocalSocket::new( - crate::test_utils::available_addr(&AddressType::Random) - .await - .port(), - ) - .unwrap(), + Arc::new( + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + ), rx, ), tx, @@ -517,9 +514,9 @@ mod tests { let session = pool.get(key, None).await.unwrap(); - assert!(pool.drop_session(key, session)); + assert!(pool.drop_session(key, session).await); - assert!(pool.has_no_allocated_sockets()); + assert!(pool.has_no_allocated_sockets().await); } #[tokio::test] @@ -539,11 +536,11 @@ mod tests { let session1 = pool.get(key1, None).await.unwrap(); let session2 = pool.get(key2, None).await.unwrap(); - assert!(pool.drop_session(key1, session1)); - assert!(!pool.has_no_allocated_sockets()); - assert!(pool.drop_session(key2, session2)); + assert!(pool.drop_session(key1, session1).await); + assert!(!pool.has_no_allocated_sockets().await); + assert!(pool.drop_session(key2, session2).await); - assert!(pool.has_no_allocated_sockets()); + assert!(pool.has_no_allocated_sockets().await); drop(pool); } diff --git a/src/test_utils.rs b/src/test_utils.rs index e09d847ddc..633e5c2360 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -250,6 +250,7 @@ impl TestHelper { recvd = socket.recv_from(&mut buf) => { let (size, sender) = recvd.unwrap(); let packet = &buf[..size]; + tracing::trace!(%sender, %size, "echo server received and returning packet"); tap(sender, packet, local_addr); socket.send_to(packet, sender).await.unwrap(); }, diff --git a/src/ttl_map.rs b/src/ttl_map.rs index e8d703ae3d..2d609916b4 100644 --- a/src/ttl_map.rs +++ b/src/ttl_map.rs @@ -308,6 +308,19 @@ impl<'a, K, V> OccupiedEntry<'a, K, Value> where K: Eq + Hash, { + /// Returns a reference to the entry's value. + /// The value will be reset to expire at the configured TTL after the time of retrieval. + pub fn into_ref(self) -> Ref<'a, K, Value> { + match self.inner { + DashMapEntry::Occupied(entry) => { + let value = entry.into_ref(); + value.value().update_expiration(self.ttl); + value.downgrade() + } + _ => unreachable!("BUG: entry type should be occupied"), + } + } + /// Returns a reference to the entry's value. /// The value will be reset to expire at the configured TTL after the time of retrieval. pub fn get(&self) -> &Value { diff --git a/src/utils/net.rs b/src/utils/net.rs index 860b993a8e..8d41afe8da 100644 --- a/src/utils/net.rs +++ b/src/utils/net.rs @@ -223,4 +223,3 @@ pub fn to_canonical(addr: &mut SocketAddr) { addr.set_ip(ip); } - diff --git a/tests/filters.rs b/tests/filters.rs index d83e359e18..a57f317a92 100644 --- a/tests/filters.rs +++ b/tests/filters.rs @@ -125,6 +125,7 @@ async fn debug_filter() { // create an echo server as an endpoint. let echo = t.run_echo_server(&AddressType::Random).await; + tracing::trace!(%echo, "running echo server"); // create server configuration let server_port = 12247; let server_config = std::sync::Arc::new(quilkin::Config::default()); diff --git a/tests/local_rate_limit.rs b/tests/local_rate_limit.rs index 7b93007643..55068c5ba3 100644 --- a/tests/local_rate_limit.rs +++ b/tests/local_rate_limit.rs @@ -35,7 +35,8 @@ period: 1 "; let echo = t.run_echo_server(&AddressType::Random).await; - let server_addr = available_addr(&AddressType::Random).await; + let mut server_addr = available_addr(&AddressType::Random).await; + quilkin::test_utils::map_addr_to_localhost(&mut server_addr); let server_proxy = quilkin::cli::Proxy { port: server_addr.port(), ..<_>::default() @@ -53,20 +54,21 @@ period: 1 .map(std::sync::Arc::new) .unwrap(), ); + tracing::trace!("spawning server"); t.run_server(server_config, server_proxy, None); - tokio::time::sleep(Duration::from_millis(50)).await; let msg = "hello"; let (mut rx, socket) = t.open_socket_and_recv_multiple_packets().await; for _ in 0..3 { + tracing::trace!(%server_addr, %msg, "sending"); socket.send_to(msg.as_bytes(), &server_addr).await.unwrap(); } for _ in 0..2 { assert_eq!( msg, - timeout(Duration::from_secs(5), rx.recv()) + timeout(Duration::from_millis(500), rx.recv()) .await .unwrap() .unwrap() @@ -76,5 +78,7 @@ period: 1 // Allow enough time to have received any response. tokio::time::sleep(Duration::from_millis(100)).await; // Check that we do not get any response. - assert!(timeout(Duration::from_secs(1), rx.recv()).await.is_err()); + assert!(timeout(Duration::from_millis(500), rx.recv()) + .await + .is_err()); } diff --git a/tests/token_router.rs b/tests/token_router.rs index 8d52a3f7ea..92d03a2502 100644 --- a/tests/token_router.rs +++ b/tests/token_router.rs @@ -30,7 +30,6 @@ use quilkin::{ /// since they work in concert together. #[tokio::test] async fn token_router() { - quilkin::test_utils::enable_log("quilkin=trace"); let mut t = TestHelper::default(); let mut echo = t.run_echo_server(&AddressType::Ipv6).await; quilkin::test_utils::map_to_localhost(&mut echo).await;