From 439df952b71be8f789a94ad20371e6b52003836e Mon Sep 17 00:00:00 2001 From: XAMPPRocky <4464295+XAMPPRocky@users.noreply.github.com> Date: Mon, 16 Oct 2023 14:10:25 +0200 Subject: [PATCH] Refactor sessions to use socket pool (#815) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors how we handle upstream connections to the gameservers. When profiling Quilkin I noticed that a lot of time (~10–15%) was being spent dropping the upstream socket through its Arc implementation, which happened whenever a session was dropped. As I was thinking about how to solve this problem I also realised there was a second issue: there is a limit on how many connections Quilkin can hold at once, roughly 16,383, because beyond that we're likely to start encountering port exhaustion from the operating system, since each session holds a unique socket. This brought me to the solution in this commit: while we need to give each connection to a given gameserver a unique port, we don't need ports to be unique across gameservers. So I refactored how we create sessions to use what I've called a "SessionPool". This pools the sockets for sessions into a map that is keyed by their destination. With this implementation we now have a limit of ~16,000 connections per gameserver, which is far more than any gameserver could reasonably need. 
--- benches/throughput.rs | 17 +- src/cli.rs | 6 +- src/cli/proxy.rs | 80 +++- src/filters/chain.rs | 2 - src/filters/compress.rs | 4 - src/filters/debug.rs | 8 +- src/filters/firewall.rs | 16 +- src/filters/match.rs | 1 - src/filters/registry.rs | 2 +- src/filters/write.rs | 15 +- src/proxy.rs | 75 +--- src/proxy/sessions.rs | 775 +++++++++++++++++++++++++------------- src/test_utils.rs | 21 +- src/ttl_map.rs | 47 +++ src/utils/net.rs | 37 +- src/xds.rs | 10 +- tests/capture.rs | 7 +- tests/compress.rs | 5 +- tests/filter_order.rs | 5 +- tests/filters.rs | 1 + tests/firewall.rs | 12 +- tests/local_rate_limit.rs | 11 +- tests/match.rs | 8 +- tests/metrics.rs | 5 +- tests/no_filter.rs | 6 +- tests/token_router.rs | 11 +- 26 files changed, 764 insertions(+), 423 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index 2cd558d259..066dc7ca2a 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -56,6 +56,9 @@ static THROUGHPUT_SERVER_INIT: Lazy<()> = Lazy::new(|| { static FEEDBACK_LOOP: Lazy<()> = Lazy::new(|| { std::thread::spawn(|| { let socket = UdpSocket::bind(FEEDBACK_LOOP_ADDR).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); loop { let mut packet = [0; MESSAGE_SIZE]; @@ -74,6 +77,9 @@ fn throughput_benchmark(c: &mut Criterion) { // Sleep to give the servers some time to warm-up. 
std::thread::sleep(std::time::Duration::from_millis(500)); let socket = UdpSocket::bind(BENCH_LOOP_ADDR).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); let mut packet = [0; MESSAGE_SIZE]; let mut group = c.benchmark_group("throughput"); @@ -125,6 +131,9 @@ fn write_feedback(addr: SocketAddr) -> mpsc::Sender> { let (write_tx, write_rx) = mpsc::channel::>(); std::thread::spawn(move || { let socket = UdpSocket::bind(addr).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); let mut packet = [0; MESSAGE_SIZE]; let (_, source) = socket.recv_from(&mut packet).unwrap(); while let Ok(packet) = write_rx.recv() { @@ -142,6 +151,9 @@ fn readwrite_benchmark(c: &mut Criterion) { let (read_tx, read_rx) = mpsc::channel::>(); std::thread::spawn(move || { let socket = UdpSocket::bind(READ_LOOP_ADDR).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); let mut packet = [0; MESSAGE_SIZE]; loop { let (length, _) = socket.recv_from(&mut packet).unwrap(); @@ -164,9 +176,12 @@ fn readwrite_benchmark(c: &mut Criterion) { Lazy::force(&WRITE_SERVER_INIT); // Sleep to give the servers some time to warm-up. 
- std::thread::sleep(std::time::Duration::from_millis(500)); + std::thread::sleep(std::time::Duration::from_millis(150)); let socket = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)).unwrap(); + socket + .set_read_timeout(Some(std::time::Duration::from_millis(500))) + .unwrap(); // prime the direct write connection socket.send_to(PACKETS[0], direct_write_addr).unwrap(); diff --git a/src/cli.rs b/src/cli.rs index 055e8bf628..4c23fb5dce 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -418,7 +418,7 @@ mod tests { std::fs::write(endpoints_file.path(), { config.clusters.write().insert_default( [Endpoint::with_metadata( - (std::net::Ipv4Addr::LOCALHOST, server_port).into(), + (std::net::Ipv6Addr::LOCALHOST, server_port).into(), crate::endpoint::Metadata { tokens: vec![token.clone()].into_iter().collect(), }, @@ -436,7 +436,7 @@ mod tests { assert_eq!( "hello", - timeout(Duration::from_secs(5), rx.recv()) + timeout(Duration::from_millis(500), rx.recv()) .await .expect("should have received a packet") .unwrap() @@ -449,7 +449,7 @@ mod tests { let msg = b"hello\xFF\xFF\xFF"; socket.send_to(msg, &proxy_address).await.unwrap(); - let result = timeout(Duration::from_secs(3), rx.recv()).await; + let result = timeout(Duration::from_millis(500), rx.recv()).await; assert!(result.is_err(), "should not have received a packet"); tracing::info!(?token, "didn't receive bad packet"); } diff --git a/src/cli/proxy.rs b/src/cli/proxy.rs index 4dfe702e8f..7c9072dd93 100644 --- a/src/cli/proxy.rs +++ b/src/cli/proxy.rs @@ -26,7 +26,7 @@ use std::{ use tonic::transport::Endpoint; use super::Admin; -use crate::{proxy::SessionMap, xds::ResourceType, Config, Result}; +use crate::{proxy::SessionPool, xds::ResourceType, Config, Result}; #[cfg(doc)] use crate::filters::FilterFactory; @@ -81,9 +81,6 @@ impl Proxy { mode: Admin, mut shutdown_rx: tokio::sync::watch::Receiver<()>, ) -> crate::Result<()> { - const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); - const SESSION_EXPIRY_POLL_INTERVAL: 
Duration = Duration::from_secs(60); - let _mmdb_task = self.mmdb.clone().map(|source| { tokio::spawn(async move { use crate::config::BACKOFF_INITIAL_DELAY_MILLISECONDS; @@ -122,8 +119,9 @@ impl Proxy { let id = config.id.load(); tracing::info!(port = self.port, proxy_id = &*id, "Starting"); - let sessions = SessionMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL); let runtime_config = mode.unwrap_proxy(); + let shared_socket = Arc::new(DualStackLocalSocket::new(self.port)?); + let sessions = SessionPool::new(config.clone(), shared_socket.clone(), shutdown_rx.clone()); let _xds_stream = if !self.management_server.is_empty() { { @@ -152,7 +150,7 @@ impl Proxy { None }; - self.run_recv_from(&config, sessions.clone())?; + self.run_recv_from(&config, &sessions, shared_socket)?; crate::protocol::spawn(self.qcmp_port).await?; tracing::info!("Quilkin is ready"); @@ -161,10 +159,10 @@ impl Proxy { .await .map_err(|error| eyre::eyre!(error))?; - tracing::info!(sessions=%sessions.len(), "waiting for active sessions to expire"); - while sessions.is_not_empty() { + tracing::info!(sessions=%sessions.sessions().len(), "waiting for active sessions to expire"); + while sessions.sessions().is_not_empty() { tokio::time::sleep(Duration::from_secs(1)).await; - tracing::debug!(sessions=%sessions.len(), "sessions still active"); + tracing::debug!(sessions=%sessions.sessions().len(), "sessions still active"); } tracing::info!("all sessions expired"); @@ -176,18 +174,29 @@ impl Proxy { /// This function also spawns the set of worker tasks responsible for consuming packets /// off the aforementioned queue and processing them through the filter chain and session /// pipeline. - fn run_recv_from(&self, config: &Arc, sessions: SessionMap) -> Result<()> { + fn run_recv_from( + &self, + config: &Arc, + sessions: &Arc, + shared_socket: Arc, + ) -> Result<()> { // The number of worker tasks to spawn. Each task gets a dedicated queue to // consume packets off. 
let num_workers = num_cpus::get(); // Contains config for each worker task. let mut workers = Vec::with_capacity(num_workers); - for worker_id in 0..num_workers { - let socket = Arc::new(DualStackLocalSocket::new(self.port)?); + workers.push(crate::proxy::DownstreamReceiveWorkerConfig { + worker_id: 0, + socket: shared_socket, + config: config.clone(), + sessions: sessions.clone(), + }); + + for worker_id in 1..num_workers { workers.push(crate::proxy::DownstreamReceiveWorkerConfig { worker_id, - socket: socket.clone(), + socket: Arc::new(DualStackLocalSocket::new(self.port)?), config: config.clone(), sessions: sessions.clone(), }) @@ -241,6 +250,7 @@ mod tests { t.run_server(config, proxy, None); + tracing::trace!(%local_addr, "sending hello"); let msg = "hello"; endpoint1 .socket @@ -249,14 +259,14 @@ mod tests { .unwrap(); assert_eq!( msg, - timeout(Duration::from_secs(1), endpoint1.packet_rx) + timeout(Duration::from_millis(100), endpoint1.packet_rx) .await .expect("should get a packet") .unwrap() ); assert_eq!( msg, - timeout(Duration::from_secs(1), endpoint2.packet_rx) + timeout(Duration::from_millis(100), endpoint2.packet_rx) .await .expect("should get a packet") .unwrap() @@ -268,7 +278,8 @@ mod tests { let mut t = TestHelper::default(); let endpoint = t.open_socket_and_recv_single_packet().await; - let local_addr = available_addr(&AddressType::Random).await; + let mut local_addr = available_addr(&AddressType::Ipv6).await; + crate::test_utils::map_addr_to_localhost(&mut local_addr); let proxy = crate::cli::Proxy { port: local_addr.port(), ..<_>::default() @@ -277,14 +288,16 @@ mod tests { config.clusters.modify(|clusters| { clusters.insert_default( [Endpoint::new( - endpoint.socket.local_ipv4_addr().unwrap().into(), + endpoint.socket.local_ipv6_addr().unwrap().into(), )] .into(), ); }); t.run_server(config, proxy, None); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; let msg = "hello"; + tracing::debug!(%local_addr, "sending packet"); 
endpoint .socket .send_to(msg.as_bytes(), &local_addr) @@ -366,8 +379,19 @@ mod tests { crate::proxy::DownstreamReceiveWorkerConfig { worker_id: 1, socket: socket.clone(), - config, - sessions: <_>::default(), + config: config.clone(), + sessions: SessionPool::new( + config, + Arc::new( + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + ), + tokio::sync::watch::channel(()).1, + ), } .spawn(); @@ -405,7 +429,23 @@ mod tests { ) }); - proxy.run_recv_from(&config, <_>::default()).unwrap(); + let shared_socket = Arc::new( + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + ); + let sessions = SessionPool::new( + config.clone(), + shared_socket.clone(), + tokio::sync::watch::channel(()).1, + ); + + proxy + .run_recv_from(&config, &sessions, shared_socket) + .unwrap(); let socket = create_socket().await; socket.send_to(msg.as_bytes(), &local_addr).await.unwrap(); diff --git a/src/filters/chain.rs b/src/filters/chain.rs index 919f7843e7..552cdc09a9 100644 --- a/src/filters/chain.rs +++ b/src/filters/chain.rs @@ -369,7 +369,6 @@ mod tests { ); let mut context = WriteContext::new( - endpoints_fixture[0].clone(), endpoints_fixture[0].address.clone(), "127.0.0.1:70".parse().unwrap(), b"hello".to_vec(), @@ -417,7 +416,6 @@ mod tests { ); let mut context = WriteContext::new( - endpoints_fixture[0].clone(), endpoints_fixture[0].address.clone(), "127.0.0.1:70".parse().unwrap(), b"hello".to_vec(), diff --git a/src/filters/compress.rs b/src/filters/compress.rs index fc84e36d0e..a5950c4b9a 100644 --- a/src/filters/compress.rs +++ b/src/filters/compress.rs @@ -196,7 +196,6 @@ mod tests { // write decompress let mut write_context = WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), read_context.contents.clone(), @@ -223,7 +222,6 @@ mod tests { 
assert!(compression .write(&mut WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), b"hello".to_vec(), @@ -270,7 +268,6 @@ mod tests { assert_eq!(b"hello", &*read_context.contents); let mut write_context = WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), b"hello".to_vec(), @@ -329,7 +326,6 @@ mod tests { let expected = contents_fixture(); // write compress let mut write_context = WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), expected.clone(), diff --git a/src/filters/debug.rs b/src/filters/debug.rs index 7290ecd045..a9dc324f6b 100644 --- a/src/filters/debug.rs +++ b/src/filters/debug.rs @@ -50,8 +50,12 @@ impl Filter for Debug { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] async fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> { - info!(id = ?self.config.id, endpoint = ?ctx.endpoint.address, source = ?&ctx.source, - dest = ?&ctx.dest, contents = ?String::from_utf8_lossy(&ctx.contents), "Write filter event"); + info!( + id = ?self.config.id, + source = ?&ctx.source, + dest = ?&ctx.dest, + contents = ?String::from_utf8_lossy(&ctx.contents), "Write filter event" + ); Ok(()) } } diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index 996803aa33..7698e49792 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -166,23 +166,13 @@ mod tests { }], }; - let endpoint = Endpoint::new((Ipv4Addr::LOCALHOST, 80).into()); let local_addr: crate::endpoint::EndpointAddress = (Ipv4Addr::LOCALHOST, 8081).into(); - let mut ctx = WriteContext::new( - endpoint.clone(), - ([192, 168, 75, 20], 80).into(), - local_addr.clone(), - vec![], - ); + let mut ctx = + WriteContext::new(([192, 168, 75, 20], 80).into(), local_addr.clone(), vec![]); 
assert!(firewall.write(&mut ctx).await.is_ok()); - let mut ctx = WriteContext::new( - endpoint, - ([192, 168, 77, 20], 80).into(), - local_addr, - vec![], - ); + let mut ctx = WriteContext::new(([192, 168, 77, 20], 80).into(), local_addr, vec![]); assert!(firewall.write(&mut ctx).await.is_err()); } } diff --git a/src/filters/match.rs b/src/filters/match.rs index 9522b4a9ec..af4810c56b 100644 --- a/src/filters/match.rs +++ b/src/filters/match.rs @@ -188,7 +188,6 @@ mod tests { // no config, so should make no change. filter .write(&mut WriteContext::new( - endpoint.clone(), endpoint.address, "127.0.0.1:70".parse().unwrap(), contents.clone(), diff --git a/src/filters/registry.rs b/src/filters/registry.rs index 42ef796135..f2fb7cf98a 100644 --- a/src/filters/registry.rs +++ b/src/filters/registry.rs @@ -113,7 +113,7 @@ mod tests { .await .is_ok()); assert!(filter - .write(&mut WriteContext::new(endpoint, addr.clone(), addr, vec![],)) + .write(&mut WriteContext::new(addr.clone(), addr, vec![],)) .await .is_ok()); } diff --git a/src/filters/write.rs b/src/filters/write.rs index 744d8f9df3..48d06800a7 100644 --- a/src/filters/write.rs +++ b/src/filters/write.rs @@ -16,10 +16,7 @@ use std::collections::HashMap; -use crate::{ - endpoint::{Endpoint, EndpointAddress}, - metadata::DynamicMetadata, -}; +use crate::{endpoint::EndpointAddress, metadata::DynamicMetadata}; #[cfg(doc)] use crate::filters::Filter; @@ -27,8 +24,6 @@ use crate::filters::Filter; /// The input arguments to [`Filter::write`]. #[non_exhaustive] pub struct WriteContext { - /// The upstream endpoint that we're expecting packets from. - pub endpoint: Endpoint, /// The source of the received packet. pub source: EndpointAddress, /// The destination of the received packet. 
@@ -41,14 +36,8 @@ pub struct WriteContext { impl WriteContext { /// Creates a new [`WriteContext`] - pub fn new( - endpoint: Endpoint, - source: EndpointAddress, - dest: EndpointAddress, - contents: Vec, - ) -> Self { + pub fn new(source: EndpointAddress, dest: EndpointAddress, contents: Vec) -> Self { Self { - endpoint, source, dest, contents, diff --git a/src/proxy.rs b/src/proxy.rs index f243a63d19..73edd51362 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -16,10 +16,9 @@ use std::{net::SocketAddr, sync::Arc}; -pub use sessions::{Session, SessionKey, SessionMap}; +pub use sessions::{Session, SessionKey, SessionPool}; use crate::{ - endpoint::{Endpoint, EndpointAddress}, filters::{Filter, ReadContext}, utils::net::DualStackLocalSocket, Config, @@ -44,7 +43,7 @@ pub(crate) struct DownstreamReceiveWorkerConfig { /// Socket with reused port from which the worker receives packets. pub socket: Arc, pub config: Arc, - pub sessions: SessionMap, + pub sessions: Arc, } impl DownstreamReceiveWorkerConfig { @@ -71,7 +70,8 @@ impl DownstreamReceiveWorkerConfig { tokio::select! 
{ result = socket.recv_from(&mut buf) => { match result { - Ok((size, source)) => { + Ok((size, mut source)) => { + crate::utils::net::to_canonical(&mut source); let packet = DownstreamPacket { received_at: chrono::Utc::now().timestamp_nanos_opt().unwrap(), asn_info: crate::maxmind_db::MaxmindDb::lookup(source.ip()), @@ -88,7 +88,7 @@ impl DownstreamReceiveWorkerConfig { } last_received_at = Some(packet.received_at); - Self::spawn_process_task(packet, source, worker_id, &socket, &config, &sessions) + Self::spawn_process_task(packet, source, worker_id, &config, &sessions) } Err(error) => { tracing::error!(%error, "error receiving packet"); @@ -106,9 +106,8 @@ impl DownstreamReceiveWorkerConfig { packet: DownstreamPacket, source: std::net::SocketAddr, worker_id: usize, - socket: &Arc, config: &Arc, - sessions: &SessionMap, + sessions: &Arc, ) { tracing::trace!( id = worker_id, @@ -121,16 +120,13 @@ impl DownstreamReceiveWorkerConfig { tokio::spawn({ let config = config.clone(); let sessions = sessions.clone(); - let socket = socket.clone(); async move { let timer = crate::metrics::processing_time(crate::metrics::READ).start_timer(); let asn_info = packet.asn_info.clone(); let asn_info = asn_info.as_ref(); - match Self::process_downstream_received_packet(packet, config, socket, sessions) - .await - { + match Self::process_downstream_received_packet(packet, config, sessions).await { Ok(size) => { crate::metrics::packets_total(crate::metrics::READ, asn_info).inc(); crate::metrics::bytes_total(crate::metrics::READ, asn_info) @@ -157,8 +153,7 @@ impl DownstreamReceiveWorkerConfig { async fn process_downstream_received_packet( packet: DownstreamPacket, config: Arc, - downstream_socket: Arc, - sessions: SessionMap, + sessions: Arc, ) -> Result { let endpoints: Vec<_> = config.clusters.read().endpoints().collect(); if endpoints.is_empty() { @@ -171,56 +166,18 @@ impl DownstreamReceiveWorkerConfig { let mut bytes_written = 0; for endpoint in context.endpoints.iter() { - 
bytes_written += Self::session_send_packet( - &context.contents, - &context.source, - endpoint, - &downstream_socket, - &config, - &sessions, - packet.asn_info.clone(), - ) - .await?; + let session_key = SessionKey { + source: packet.source, + dest: endpoint.address.to_socket_addr().await?, + }; + + bytes_written += sessions + .send(session_key, packet.asn_info.clone(), &context.contents) + .await?; } Ok(bytes_written) } - - /// Send a packet received from `recv_addr` to an endpoint. - #[tracing::instrument(level="trace", skip_all, fields(source = %recv_addr, dest = %endpoint.address))] - async fn session_send_packet( - packet: &[u8], - recv_addr: &EndpointAddress, - endpoint: &Endpoint, - downstream_socket: &Arc, - config: &Arc, - sessions: &SessionMap, - asn_info: Option, - ) -> Result { - let session_key = SessionKey { - source: recv_addr.clone(), - dest: endpoint.address.clone(), - }; - - let send_future = match sessions.get(&session_key) { - Some(entry) => entry.send(packet), - None => { - let session = Session::new( - config.clone(), - session_key.source.clone(), - downstream_socket.clone(), - endpoint.clone(), - asn_info, - )?; - - let future = session.send(packet); - sessions.insert(session_key, session); - future - } - }; - - send_future.await - } } #[derive(thiserror::Error, Debug)] diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index bfc9b6d869..3194d944a9 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -14,168 +14,123 @@ * limitations under the License. 
*/ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + net::SocketAddr, + sync::Arc, + time::Duration, +}; use tokio::{ - net::UdpSocket, - select, - sync::{watch, OnceCell}, + sync::{watch, RwLock}, time::Instant, }; use crate::{ - endpoint::{Endpoint, EndpointAddress}, - filters::{Filter, WriteContext}, + config::Config, + filters::Filter, maxmind_db::IpNetEntry, utils::{net::DualStackLocalSocket, Loggable}, }; +use dashmap::DashMap; + pub(crate) mod metrics; pub type SessionMap = crate::ttl_map::TtlMap; -/// Session encapsulates a UDP stream session -pub struct Session { - config: Arc, - /// created_at is time at which the session was created - created_at: Instant, - /// socket that sends and receives from and to the endpoint address - upstream_socket: Arc>>, - /// dest is where to send data to - dest: Endpoint, - /// address of original sender - source: EndpointAddress, - /// a channel to broadcast on if we are shutting down this Session - shutdown_tx: watch::Sender<()>, - /// The ASN information. - asn_info: Option, +type SessionRef<'pool> = + dashmap::mapref::one::Ref<'pool, SessionKey, crate::ttl_map::Value>; + +/// A data structure that is responsible for holding sessions, and pooling +/// sockets between them. This means that we only provide new unique sockets +/// to new connections to the same gameserver, and we share sockets across +/// multiple gameservers. +/// +/// Traffic from different gameservers is then demuxed using their address to +/// send back to the original client. +#[derive(Debug)] +pub struct SessionPool { + ports_to_sockets: DashMap>, + storage: Arc>, + session_map: SessionMap, + downstream_socket: Arc, + shutdown_rx: watch::Receiver<()>, + config: Arc, } -// A (source, destination) address pair that uniquely identifies a session. 
-#[derive(Clone, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] -pub struct SessionKey { - pub source: EndpointAddress, - pub dest: EndpointAddress, -} - -impl From<(EndpointAddress, EndpointAddress)> for SessionKey { - fn from((source, dest): (EndpointAddress, EndpointAddress)) -> Self { - SessionKey { source, dest } - } +/// The wrapper struct responsible for holding all of the socket related mappings. +#[derive(Default, Debug)] +struct SocketStorage { + destination_to_sockets: HashMap>, + destination_to_sources: HashMap<(SocketAddr, u16), SocketAddr>, + sources_to_asn_info: HashMap, + sockets_to_destination: HashMap>, } -/// ReceivedPacketContext contains state needed to process a received packet. -struct ReceivedPacketContext<'a> { - packet: &'a [u8], - config: Arc, - endpoint: &'a Endpoint, - source: EndpointAddress, - dest: EndpointAddress, -} - -impl Session { - /// internal constructor for a Session from SessionArgs - #[tracing::instrument(skip_all)] +impl SessionPool { + /// Constructs a new session pool, it's created with an `Arc` as that's + /// required for the pool to provide a reference to the children to be able + /// to release their sockets back to the parent. 
pub fn new( - config: Arc, - source: EndpointAddress, + config: Arc, downstream_socket: Arc, - dest: Endpoint, - asn_info: Option, - ) -> Result { - let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); - - let s = Session { - config: config.clone(), - upstream_socket: Arc::new(OnceCell::new()), - source: source.clone(), - dest, - created_at: Instant::now(), - shutdown_tx, - asn_info, - }; + shutdown_rx: watch::Receiver<()>, + ) -> Arc { + const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); + const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); - tracing::debug!(source = %s.source, dest = ?s.dest, "Session created"); - - if let Some(asn) = &s.asn_info { - tracing::debug!( - number = asn.r#as, - organization = asn.as_name, - country_code = asn.as_cc, - prefix = asn.prefix, - prefix_entity = asn.prefix_entity, - prefix_name = asn.prefix_name, - "maxmind information" - ); - } - - self::metrics::total_sessions().inc(); - s.active_session_metric().inc(); - s.run(downstream_socket, shutdown_rx); - Ok(s) - } - - fn upstream_socket( - &self, - ) -> impl std::future::Future, super::PipelineError>> { - let upstream_socket = self.upstream_socket.clone(); - let address = self.dest.address.clone(); - - async move { - let connect_addr = address.to_socket_addr().await?; - let bind_addr: SocketAddr = match connect_addr { - SocketAddr::V4(_) => (std::net::Ipv4Addr::UNSPECIFIED, 0).into(), - SocketAddr::V6(_) => (std::net::Ipv6Addr::UNSPECIFIED, 0).into(), - }; - - upstream_socket - .get_or_try_init(|| async { - let upstream_socket = UdpSocket::bind(bind_addr).await?; - upstream_socket.connect(connect_addr).await?; - Ok(Arc::new(upstream_socket)) - }) - .await - .cloned() - } + Arc::new(Self { + config, + downstream_socket, + shutdown_rx, + ports_to_sockets: <_>::default(), + storage: <_>::default(), + session_map: SessionMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL), + }) } - /// run starts processing receiving upstream udp 
packets - /// and sending them back downstream - fn run( - &self, - downstream_socket: Arc, - mut shutdown_rx: watch::Receiver<()>, - ) { - let source = self.source.clone(); - let config = self.config.clone(); - let endpoint = self.dest.clone(); - let upstream_socket = self.upstream_socket(); - let asn_info = self.asn_info.clone(); + /// Allocates a new upstream socket from a new socket from the system. + async fn create_new_session_from_new_socket<'pool>( + self: &'pool Arc, + key: SessionKey, + asn_info: Option, + ) -> Result, super::PipelineError> { + let socket = DualStackLocalSocket::new(0).map(Arc::new)?; + let port = socket.local_ipv4_addr().unwrap().port(); + self.ports_to_sockets.insert(port, socket.clone()); + let upstream_socket = socket.clone(); + let pool = self.clone(); tokio::spawn(async move { let mut buf: Vec = vec![0; 65535]; let mut last_received_at = None; - let upstream_socket = match upstream_socket.await { - Ok(socket) => socket, - Err(error) => { - tracing::error!(%error, "upstream socket failed to initialise"); - return; - } - }; + let mut shutdown_rx = pool.shutdown_rx.clone(); loop { - tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet"); - let asn_info = asn_info.as_ref(); - - select! { + tokio::select! 
{ received = upstream_socket.recv_from(&mut buf) => { match received { Err(error) => { - crate::metrics::errors_total(crate::metrics::WRITE, &error.to_string(), asn_info).inc(); - tracing::error!(%error, %source, dest = ?endpoint, "Error receiving packet"); + tracing::trace!(%error, "error receiving packet"); + crate::metrics::errors_total(crate::metrics::WRITE, &error.to_string(), None).inc(); }, - Ok((size, recv_addr)) => { + Ok((size, mut recv_addr)) => { let received_at = chrono::Utc::now().timestamp_nanos_opt().unwrap(); + crate::utils::net::to_canonical(&mut recv_addr); + tracing::trace!(%recv_addr, %size, "received packet"); + let (downstream_addr, asn_info): (SocketAddr, Option) = { + let storage = pool.storage.read().await; + let Some(downstream_addr) = storage.destination_to_sources.get(&(recv_addr, port)) else { + tracing::warn!(address=%recv_addr, "received traffic from a server that has no downstream"); + continue; + }; + let asn_info = storage.sources_to_asn_info.get(downstream_addr); + + (*downstream_addr, asn_info.cloned()) + }; + + let asn_info = asn_info.as_ref(); if let Some(last_received_at) = last_received_at { crate::metrics::packet_jitter(crate::metrics::WRITE, asn_info).set(received_at - last_received_at); } @@ -185,15 +140,13 @@ impl Session { crate::metrics::bytes_total(crate::metrics::WRITE, asn_info).inc_by(size as u64); let timer = crate::metrics::processing_time(crate::metrics::WRITE).start_timer(); - let result = Session::process_recv_packet( - &downstream_socket, - ReceivedPacketContext { - config: config.clone(), - packet: &buf[..size], - endpoint: &endpoint, - source: recv_addr.into(), - dest: source.clone(), - }).await; + let result = Self::process_recv_packet( + pool.config.clone(), + &pool.downstream_socket, + recv_addr, + downstream_addr, + &buf[..size], + ).await; timer.stop_and_record(); if let Err(error) = result { error.log(); @@ -209,83 +162,303 @@ impl Session { }; } _ = shutdown_rx.changed() => { - 
tracing::debug!(%source, dest = ?endpoint, "Closing Session"); + tracing::debug!("Closing upstream socket loop"); return; } }; } }); + + self.create_session_from_existing_socket(key, socket, port, asn_info) + .await } - fn active_session_metric(&self) -> prometheus::IntGauge { - metrics::active_sessions(self.asn_info.as_ref()) + /// Returns a reference to an existing session mapped to `key`, otherwise + /// creates a new session either from a fresh socket, or if there are sockets + /// allocated that are not reserved by an existing destination, using the + /// existing socket. + pub async fn get<'pool>( + self: &'pool Arc, + key @ SessionKey { dest, .. }: SessionKey, + asn_info: Option, + ) -> Result, super::PipelineError> { + // If we already have a session for the key pairing, return that session. + if let Some(entry) = self.session_map.get(&key) { + return Ok(entry); + } + + // If there's a socket_set available, it means there are sockets + // allocated to the address that we want to avoid. + let storage = self.storage.read().await; + let Some(socket_set) = storage.destination_to_sockets.get(&dest) else { + drop(storage); + return if self.ports_to_sockets.is_empty() { + // Initial case where we have no allocated or reserved sockets. + self.create_new_session_from_new_socket(key, asn_info).await + } else { + // Where we have no allocated sockets for a destination, assign + // the first available one. 
+ let entry = self.ports_to_sockets.iter().next().unwrap(); + let port = *entry.key(); + + self.create_session_from_existing_socket(key, entry.value().clone(), port, asn_info) + .await + }; + }; + + if let Some(entry) = self + .ports_to_sockets + .iter() + .find(|entry| !socket_set.contains(entry.key())) + { + drop(storage); + self.storage + .write() + .await + .destination_to_sockets + .get_mut(&dest) + .unwrap() + .insert(*entry.key()); + self.create_session_from_existing_socket( + key, + entry.value().clone(), + *entry.key(), + asn_info, + ) + .await + } else { + drop(storage); + self.create_new_session_from_new_socket(key, asn_info).await + } + } + + /// Using an existing socket, reserves the socket for a new session. + async fn create_session_from_existing_socket<'session>( + self: &'session Arc, + key: SessionKey, + upstream_socket: Arc, + socket_port: u16, + asn_info: Option, + ) -> Result, super::PipelineError> { + let mut storage = self.storage.write().await; + storage + .destination_to_sockets + .entry(key.dest) + .or_default() + .insert(socket_port); + storage + .sockets_to_destination + .entry(socket_port) + .or_default() + .insert(key.dest); + storage + .destination_to_sources + .insert((key.dest, socket_port), key.source); + + if let Some(asn_info) = &asn_info { + storage + .sources_to_asn_info + .insert(key.source, asn_info.clone()); + } + + drop(storage); + let session = Session::new(key, upstream_socket, socket_port, self.clone(), asn_info)?; + Ok(match self.session_map.entry(key) { + crate::ttl_map::Entry::Occupied(mut entry) => { + entry.insert(session); + entry.into_ref() + } + crate::ttl_map::Entry::Vacant(v) => v.insert(session).downgrade(), + }) } /// process_recv_packet processes a packet that is received by this session. 
async fn process_recv_packet( + config: Arc, downstream_socket: &Arc, - packet_ctx: ReceivedPacketContext<'_>, + source: SocketAddr, + dest: SocketAddr, + packet: &[u8], ) -> Result { - let ReceivedPacketContext { - packet, - config, - endpoint, - source: from, - dest, - } = packet_ctx; - - tracing::trace!(%from, dest = %endpoint.address, contents = %crate::utils::base64_encode(packet), "received packet from upstream"); + tracing::trace!(%source, %dest, contents = %crate::utils::base64_encode(packet), "received packet from upstream"); - let mut context = WriteContext::new( - endpoint.clone(), - from.clone(), - dest.clone(), - packet.to_vec(), - ); + let mut context = + crate::filters::WriteContext::new(source.into(), dest.into(), packet.to_vec()); config.filters.load().write(&mut context).await?; - let addr = dest.to_socket_addr().await.map_err(Error::ToSocketAddr)?; let packet = context.contents.as_ref(); - tracing::trace!(%from, dest = %addr, contents = %crate::utils::base64_encode(packet), "sending packet downstream"); + tracing::trace!(%source, %dest, contents = %crate::utils::base64_encode(packet), "sending packet downstream"); downstream_socket - .send_to(packet, &addr) + .send_to(packet, &dest) .await .map_err(Error::SendTo) } - /// Sends a packet to the Session's dest. - pub fn send<'buf>( - &self, - buf: &'buf [u8], - ) -> impl std::future::Future> + 'buf { - tracing::trace!( - dest_address = %self.dest.address, - contents = %crate::utils::base64_encode(buf), - "sending packet upstream"); - - let socket = self.upstream_socket(); - async move { socket.await?.send(buf).await.map_err(From::from) } + /// Returns a map of active sessions. + pub fn sessions(&self) -> &SessionMap { + &self.session_map + } + + /// Sends packet data to the appropriate session based on its `key`. + pub async fn send( + self: &Arc, + key: SessionKey, + asn_info: Option, + packet: &[u8], + ) -> Result { + self.get(key, asn_info) + .await? 
+ .send(packet) + .await + .map_err(From::from) + } + + /// Returns whether the pool contains any sockets allocated to a destination. + #[cfg(test)] + async fn has_no_allocated_sockets(&self) -> bool { + let storage = self.storage.read().await; + let is_empty = storage.destination_to_sockets.is_empty(); + // These should always be the same. + debug_assert!(!(is_empty ^ storage.sockets_to_destination.is_empty())); + is_empty + } + + /// Forces removal of session to make testing quicker. + #[cfg(test)] + async fn drop_session(&self, key: SessionKey, session: SessionRef<'_>) -> bool { + drop(session); + let is_removed = self.session_map.remove(key); + // Sleep because there's no async drop. + tokio::time::sleep(Duration::from_millis(100)).await; + is_removed + } + + /// Handles the logic of releasing a socket back into the pool. + async fn release_socket( + self: Arc, + SessionKey { + ref source, + ref dest, + }: SessionKey, + port: u16, + ) { + let mut storage = self.storage.write().await; + let socket_set = storage.destination_to_sockets.get_mut(dest).unwrap(); + + assert!(socket_set.remove(&port)); + + if socket_set.is_empty() { + storage.destination_to_sockets.remove(dest).unwrap(); + } + + let dest_set = storage.sockets_to_destination.get_mut(&port).unwrap(); + + assert!(dest_set.remove(dest)); + + if dest_set.is_empty() { + storage.sockets_to_destination.remove(&port).unwrap(); + } + + // Not asserted because the source might not have GeoIP info. + storage.sources_to_asn_info.remove(source); + assert!(storage + .destination_to_sources + .remove(&(*dest, port)) + .is_some()); } } -impl Drop for Session { - fn drop(&mut self) { +/// Session encapsulates a UDP stream session +#[derive(Debug)] +pub struct Session { + /// created_at is time at which the session was created + created_at: Instant, + /// The source and destination pair. + key: SessionKey, + /// The socket port of the session. + socket_port: u16, + /// The socket of the session. 
+ socket: Arc, + /// The GeoIP information of the source. + asn_info: Option, + /// The socket pool of the session. + pool: Arc, +} + +impl Session { + pub fn new( + key: SessionKey, + socket: Arc, + socket_port: u16, + pool: Arc, + asn_info: Option, + ) -> Result { + let s = Self { + key, + socket, + pool, + socket_port, + asn_info, + created_at: Instant::now(), + }; + + if let Some(asn) = &s.asn_info { + tracing::debug!( + number = asn.r#as, + organization = asn.as_name, + country_code = asn.as_cc, + prefix = asn.prefix, + prefix_entity = asn.prefix_entity, + prefix_name = asn.prefix_name, + "maxmind information" + ); + } + + self::metrics::total_sessions().inc(); + s.active_session_metric().inc(); + tracing::debug!(source = %key.source, dest = %key.dest, "Session created"); + Ok(s) + } + + pub async fn send(&self, packet: &[u8]) -> std::io::Result { + tracing::trace!(dest=%self.key.dest, "sending packet upstream"); + self.socket.send_to(packet, self.key.dest).await + } + + fn active_session_metric(&self) -> prometheus::IntGauge { + metrics::active_sessions(self.asn_info.as_ref()) + } + + fn async_drop(&mut self) -> impl std::future::Future { self.active_session_metric().dec(); metrics::duration_secs().observe(self.created_at.elapsed().as_secs() as f64); + tracing::debug!(source = %self.key.source, dest_address = %self.key.dest, "Session closed"); + SessionPool::release_socket(self.pool.clone(), self.key, self.socket_port) + } +} - if let Err(error) = self.shutdown_tx.send(()) { - tracing::warn!(%error, "Error sending session shutdown signal"); - } +impl Drop for Session { + fn drop(&mut self) { + tokio::spawn(self.async_drop()); + } +} + +// A (source, destination) address pair that uniquely identifies a session. 
+#[derive(Clone, Copy, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] +pub struct SessionKey { + pub source: SocketAddr, + pub dest: SocketAddr, +} - tracing::debug!(source = %self.source, dest_address = %self.dest.address, "Session closed"); +impl From<(SocketAddr, SocketAddr)> for SessionKey { + fn from((source, dest): (SocketAddr, SocketAddr)) -> Self { + Self { source, dest } } } #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("failed to convert endpoint to socket address: {0}")] - ToSocketAddr(std::io::Error), #[error("failed to send packet downstream: {0}")] SendTo(std::io::Error), #[error("filter {0}")] @@ -295,7 +468,7 @@ pub enum Error { impl Loggable for Error { fn log(&self) { match self { - Self::ToSocketAddr(error) | Self::SendTo(error) => { + Self::SendTo(error) => { tracing::error!(kind=%error.kind(), "{}", self) } Self::Filter(_) => { @@ -307,100 +480,186 @@ impl Loggable for Error { #[cfg(test)] mod tests { - use std::{str::from_utf8, sync::Arc, time::Duration}; + use super::*; + use crate::test_utils::{available_addr, AddressType, TestHelper}; + use std::sync::Arc; + + async fn new_pool(config: impl Into>) -> (Arc, watch::Sender<()>) { + let (tx, rx) = watch::channel(()); + ( + SessionPool::new( + Arc::new(config.into().unwrap_or_default()), + Arc::new( + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + ), + rx, + ), + tx, + ) + } - use tokio::time::timeout; + #[tokio::test] + async fn insert_and_release_single_socket() { + let (pool, _sender) = new_pool(None).await; + let key = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); - use crate::{ - endpoint::{Endpoint, EndpointAddress}, - proxy::sessions::ReceivedPacketContext, - test_utils::{create_socket, new_test_config, AddressType, TestHelper}, - }; + let session = pool.get(key, None).await.unwrap(); - use super::*; + 
assert!(pool.drop_session(key, session).await); + + assert!(pool.has_no_allocated_sockets().await); + } #[tokio::test] - async fn session_send_and_receive() { - let mut t = TestHelper::default(); - let addr = t.run_echo_server(&AddressType::Random).await; - let endpoint = Endpoint::new(addr.clone()); - let socket = Arc::new(create_socket().await); - let msg = "hello"; + async fn insert_and_release_multiple_sockets() { + let (pool, _sender) = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); - let sess = - Session::new(<_>::default(), addr.clone(), socket.clone(), endpoint, None).unwrap(); + let session1 = pool.get(key1, None).await.unwrap(); + let session2 = pool.get(key2, None).await.unwrap(); - sess.upstream_socket() - .await - .unwrap() - .send(msg.as_bytes()) - .await - .unwrap(); + assert!(pool.drop_session(key1, session1).await); + assert!(!pool.has_no_allocated_sockets().await); + assert!(pool.drop_session(key2, session2).await); - let mut buf = vec![0; 1024]; - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .unwrap() - .unwrap(); - let packet = &buf[..size]; - assert_eq!(msg, from_utf8(packet).unwrap()); - assert_eq!(addr.port(), recv_addr.port()); + assert!(pool.has_no_allocated_sockets().await); + drop(pool); } #[tokio::test] - async fn process_recv_packet() { - crate::test_utils::load_test_filters(); - - let socket = Arc::new(create_socket().await); - let endpoint = Endpoint::new("127.0.1.1:80".parse().unwrap()); - let dest: EndpointAddress = socket.local_ipv4_addr().unwrap().into(); - - // first test with no filtering - let msg = "hello"; - Session::process_recv_packet( - &socket, - ReceivedPacketContext { - config: <_>::default(), - packet: msg.as_bytes(), - endpoint: 
&endpoint, - source: endpoint.address.clone(), - dest: dest.clone(), - }, + async fn same_address_uses_different_sockets() { + let (pool, _sender) = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), ) - .await - .unwrap(); + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); - let mut buf = vec![0; 1024]; - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .expect("Should receive a packet") - .unwrap(); - assert_eq!(msg, from_utf8(&buf[..size]).unwrap()); - assert_eq!(dest.port(), recv_addr.port()); - - // add filter - let config = Arc::new(new_test_config()); - Session::process_recv_packet( - &socket, - ReceivedPacketContext { - config, - packet: msg.as_bytes(), - endpoint: &endpoint, - source: endpoint.address.clone(), - dest: dest.clone(), - }, + let socket1 = pool.get(key1, None).await.unwrap(); + let socket2 = pool.get(key2, None).await.unwrap(); + + assert_ne!(socket1.socket_port, socket2.socket_port); + } + + #[tokio::test] + async fn different_addresses_uses_same_socket() { + let (pool, _sender) = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8081u16).into(), + ) + .into(); + + let socket1 = pool.get(key1, None).await.unwrap(); + let socket2 = pool.get(key2, None).await.unwrap(); + + assert_eq!(socket1.socket_port, socket2.socket_port); + } + + #[tokio::test] + async fn spawn_safe_same_destination() { + let (pool, _sender) = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 
8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + + let socket1 = pool.get(key1, None).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = socket1; + }); + + let _socket2 = pool.get(key2, None).await.unwrap(); + + task.await.unwrap(); + } + + #[tokio::test] + async fn spawn_safe_different_destination() { + let (pool, _sender) = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8081u16).into(), + ) + .into(); + + let socket1 = pool.get(key1, None).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = socket1; + }); + + let _socket2 = pool.get(key2, None).await.unwrap(); + + task.await.unwrap(); + } + + #[tokio::test] + async fn send_and_recv() { + let mut t = TestHelper::default(); + let dest = t.run_echo_server(&AddressType::Ipv6).await; + let mut dest = dest.to_socket_addr().await.unwrap(); + crate::test_utils::map_addr_to_localhost(&mut dest); + let source = available_addr(&AddressType::Ipv6).await; + let socket = tokio::net::UdpSocket::bind(source).await.unwrap(); + let mut source = socket.local_addr().unwrap(); + crate::test_utils::map_addr_to_localhost(&mut source); + let (pool, _sender) = new_pool(None).await; + + let key: SessionKey = (source, dest).into(); + let msg = b"helloworld"; + + pool.send(key, None, msg).await.unwrap(); + + let mut buf = [0u8; 1024]; + let (size, _) = tokio::time::timeout( + std::time::Duration::from_secs(1), + socket.recv_from(&mut buf), ) .await + .unwrap() .unwrap(); - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .expect("Should receive a packet") - .unwrap(); - assert_eq!( - format!("{}:our:{}:{}", msg, endpoint.address, dest), - from_utf8(&buf[..size]).unwrap() - ); - assert_eq!(dest.port(), recv_addr.port()); 
+ assert_eq!(msg, &buf[..size]); } } diff --git a/src/test_utils.rs b/src/test_utils.rs index a780ce4521..633e5c2360 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -182,7 +182,7 @@ impl TestHelper { &mut self, ) -> (mpsc::Receiver, Arc) { let socket = Arc::new( - DualStackLocalSocket::new_with_address((Ipv4Addr::UNSPECIFIED, 0).into()).unwrap(), + DualStackLocalSocket::new_with_address((Ipv4Addr::LOCALHOST, 0).into()).unwrap(), ); let packet_rx = self.recv_multiple_packets(&socket).await; (packet_rx, socket) @@ -239,7 +239,8 @@ impl TestHelper { { let socket = create_socket().await; // sometimes give ipv6, sometimes ipv4. - let addr = get_address(address_type, &socket); + let mut addr = get_address(address_type, &socket); + crate::test_utils::map_addr_to_localhost(&mut addr); let mut shutdown = self.get_shutdown_subscriber().await; let local_addr = addr; tokio::spawn(async move { @@ -249,6 +250,7 @@ impl TestHelper { recvd = socket.recv_from(&mut buf) => { let (size, sender) = recvd.unwrap(); let packet = &buf[..size]; + tracing::trace!(%sender, %size, "echo server received and returning packet"); tap(sender, packet, local_addr); socket.send_to(packet, sender).await.unwrap(); }, @@ -322,7 +324,6 @@ where let endpoint = "127.0.0.1:90".parse::().unwrap(); let contents = "hello".to_string().into_bytes(); let mut context = WriteContext::new( - endpoint.clone(), endpoint.address, "127.0.0.1:70".parse().unwrap(), contents.clone(), @@ -332,6 +333,19 @@ where assert_eq!(contents, &*context.contents); } +pub async fn map_to_localhost(address: &mut EndpointAddress) { + let mut socket_addr = address.to_socket_addr().await.unwrap(); + map_addr_to_localhost(&mut socket_addr); + *address = socket_addr.into(); +} + +pub fn map_addr_to_localhost(address: &mut SocketAddr) { + match address { + SocketAddr::V4(addr) => addr.set_ip(std::net::Ipv4Addr::LOCALHOST), + SocketAddr::V6(addr) => addr.set_ip(std::net::Ipv6Addr::LOCALHOST), + } +} + /// Opens a new socket bound to 
an ephemeral port pub async fn create_socket() -> DualStackLocalSocket { DualStackLocalSocket::new(0).unwrap() @@ -359,6 +373,7 @@ pub fn ep(id: u8) -> Endpoint { } } +#[track_caller] pub fn new_test_config() -> crate::Config { crate::Config { filters: crate::config::Slot::new( diff --git a/src/ttl_map.rs b/src/ttl_map.rs index 7e8e2e12c8..2d609916b4 100644 --- a/src/ttl_map.rs +++ b/src/ttl_map.rs @@ -73,6 +73,15 @@ impl Value { } } +impl std::fmt::Debug for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Value") + .field("value", &self.value) + .field("expires_at", &self.expires_at) + .finish() + } +} + impl std::ops::Deref for Value { type Target = V; @@ -89,6 +98,18 @@ struct Map { shutdown_tx: Option>, } +impl std::fmt::Debug + for Map +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Map") + .field("inner", &self.inner) + .field("ttl", &self.ttl) + .field("shutdown_tx", &self.shutdown_tx) + .finish() + } +} + impl Drop for Map { fn drop(&mut self) { if let Some(shutdown_tx) = self.shutdown_tx.take() { @@ -211,6 +232,11 @@ where .map(|value| value.value) } + /// Removes a key-value pair from the map. + pub fn remove(&self, key: K) -> bool { + self.0.inner.remove(&key).is_some() + } + /// Returns an entry for in-place updates of the specified key-value pair. /// Note: This acquires a write lock on the map's shard that corresponds /// to the entry. @@ -231,6 +257,14 @@ where } } +impl std::fmt::Debug + for TtlMap +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TtlMap").field("inner", &self.0).finish() + } +} + impl Clone for TtlMap { fn clone(&self) -> Self { Self(self.0.clone()) @@ -274,6 +308,19 @@ impl<'a, K, V> OccupiedEntry<'a, K, Value> where K: Eq + Hash, { + /// Returns a reference to the entry's value. + /// The value will be reset to expire at the configured TTL after the time of retrieval. 
+ pub fn into_ref(self) -> Ref<'a, K, Value> { + match self.inner { + DashMapEntry::Occupied(entry) => { + let value = entry.into_ref(); + value.value().update_expiration(self.ttl); + value.downgrade() + } + _ => unreachable!("BUG: entry type should be occupied"), + } + } + /// Returns a reference to the entry's value. /// The value will be reset to expire at the configured TTL after the time of retrieval. pub fn get(&self) -> &Value { diff --git a/src/utils/net.rs b/src/utils/net.rs index 0789b2a71e..8d41afe8da 100644 --- a/src/utils/net.rs +++ b/src/utils/net.rs @@ -22,14 +22,12 @@ use std::{ use socket2::{Protocol, Socket, Type}; use tokio::{net::ToSocketAddrs, net::UdpSocket}; -use crate::Result; - /// returns a UdpSocket with address and port reuse, on Ipv6Addr::UNSPECIFIED -fn socket_with_reuse(port: u16) -> Result { +fn socket_with_reuse(port: u16) -> std::io::Result { socket_with_reuse_and_address((Ipv6Addr::UNSPECIFIED, port).into()) } -fn socket_with_reuse_and_address(addr: SocketAddr) -> Result { +fn socket_with_reuse_and_address(addr: SocketAddr) -> std::io::Result { let domain = match addr { SocketAddr::V4(_) => socket2::Domain::IPV4, SocketAddr::V6(_) => socket2::Domain::IPV6, @@ -43,7 +41,7 @@ fn socket_with_reuse_and_address(addr: SocketAddr) -> Result { sock.set_only_v6(false)?; } sock.bind(&addr.into())?; - UdpSocket::from_std(sock.into()).map_err(|error| eyre::eyre!(error)) + UdpSocket::from_std(sock.into()) } #[cfg(not(target_family = "windows"))] @@ -60,19 +58,26 @@ fn enable_reuse(sock: &Socket) -> io::Result<()> { /// An ipv6 socket that can accept and send data from either a local ipv4 address or ipv6 address /// with port reuse enabled and only_v6 set to false. 
+#[derive(Debug)] pub struct DualStackLocalSocket { socket: UdpSocket, } impl DualStackLocalSocket { - pub fn new(port: u16) -> Result { + pub fn new(port: u16) -> std::io::Result { Ok(Self { socket: socket_with_reuse(port)?, }) } + pub fn bind_local(port: u16) -> std::io::Result { + Ok(Self { + socket: socket_with_reuse_and_address((Ipv6Addr::LOCALHOST, port).into())?, + }) + } + /// Primarily used for testing of ipv4 vs ipv6 addresses. - pub(crate) fn new_with_address(addr: SocketAddr) -> Result { + pub(crate) fn new_with_address(addr: SocketAddr) -> std::io::Result { Ok(Self { socket: socket_with_reuse_and_address(addr)?, }) @@ -200,3 +205,21 @@ mod tests { ); } } + +/// Converts a socket address to its canonical version. +/// This is just a copy of the method available in std but that is currently +/// nightly only. +pub fn to_canonical(addr: &mut SocketAddr) { + let ip = match addr.ip() { + std::net::IpAddr::V6(ip) => { + if let Some(mapped) = ip.to_ipv4_mapped() { + std::net::IpAddr::V4(mapped) + } else { + std::net::IpAddr::V6(ip) + } + } + addr => addr, + }; + + addr.set_ip(ip); +} diff --git a/src/xds.rs b/src/xds.rs index 93a883df1e..7d86d7c6ab 100644 --- a/src/xds.rs +++ b/src/xds.rs @@ -157,8 +157,9 @@ mod tests { let mut helper = crate::test_utils::TestHelper::default(); let token = "mytoken"; let address = { - let mut addr = Endpoint::new(helper.run_echo_server(&AddressType::Random).await); + let mut addr = Endpoint::new(helper.run_echo_server(&AddressType::Ipv6).await); addr.metadata.known.tokens.insert(token.into()); + crate::test_utils::map_to_localhost(&mut addr.address).await; addr }; let clusters = crate::cluster::ClusterMap::default(); @@ -275,10 +276,7 @@ mod tests { client .socket - .send_to( - &packet, - (std::net::Ipv4Addr::UNSPECIFIED, client_addr.port()), - ) + .send_to(&packet, (std::net::Ipv6Addr::LOCALHOST, client_addr.port())) .await .unwrap(); let response = tokio::time::timeout(std::time::Duration::from_secs(1), client.packet_rx) 
@@ -315,7 +313,7 @@ mod tests { // Each time, we create a new upstream endpoint and send a cluster update for it. let concat_bytes = vec![("b", "c,"), ("d", "e")]; for (b1, b2) in concat_bytes.into_iter() { - let socket = std::net::UdpSocket::bind((std::net::Ipv4Addr::UNSPECIFIED, 0)).unwrap(); + let socket = std::net::UdpSocket::bind((std::net::Ipv6Addr::LOCALHOST, 0)).unwrap(); let local_addr: crate::endpoint::EndpointAddress = socket.local_addr().unwrap().into(); config.clusters.modify(|clusters| { diff --git a/tests/capture.rs b/tests/capture.rs index 96379e940f..626166df7d 100644 --- a/tests/capture.rs +++ b/tests/capture.rs @@ -31,7 +31,8 @@ use quilkin::{ #[tokio::test] async fn token_router() { let mut t = TestHelper::default(); - let echo = t.run_echo_server(&AddressType::Random).await; + let mut echo = t.run_echo_server(&AddressType::Random).await; + quilkin::test_utils::map_to_localhost(&mut echo).await; let server_port = 12348; let server_proxy = quilkin::cli::Proxy { port: server_port, @@ -87,7 +88,7 @@ async fn token_router() { assert_eq!( "helloabc", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() @@ -97,6 +98,6 @@ async fn token_router() { let msg = b"helloxyz"; socket.send_to(msg, &local_addr).await.unwrap(); - let result = timeout(Duration::from_secs(3), recv_chan.recv()).await; + let result = timeout(Duration::from_millis(500), recv_chan.recv()).await; assert!(result.is_err(), "should not have received a packet"); } diff --git a/tests/compress.rs b/tests/compress.rs index d537e0364e..6c9d05da62 100644 --- a/tests/compress.rs +++ b/tests/compress.rs @@ -29,7 +29,8 @@ async fn client_and_server() { let echo = t.run_echo_server(&AddressType::Random).await; // create server configuration as - let server_addr = available_addr(&AddressType::Random).await; + let mut server_addr = available_addr(&AddressType::Random).await; + 
quilkin::test_utils::map_addr_to_localhost(&mut server_addr); let yaml = " on_read: DECOMPRESS on_write: COMPRESS @@ -84,7 +85,7 @@ on_write: DECOMPRESS let (mut rx, tx) = t.open_socket_and_recv_multiple_packets().await; tx.send_to(b"hello", &client_addr).await.unwrap(); - let expected = timeout(Duration::from_secs(5), rx.recv()) + let expected = timeout(Duration::from_millis(500), rx.recv()) .await .expect("should have received a packet") .unwrap(); diff --git a/tests/filter_order.rs b/tests/filter_order.rs index 23e99a389d..58401267b1 100644 --- a/tests/filter_order.rs +++ b/tests/filter_order.rs @@ -44,7 +44,7 @@ on_read: COMPRESS on_write: DECOMPRESS "; - let echo = t + let mut echo = t .run_echo_server_with_tap(&AddressType::Random, move |_, bytes, _| { assert!( from_utf8(bytes).is_err(), @@ -53,6 +53,7 @@ on_write: DECOMPRESS }) .await; + quilkin::test_utils::map_to_localhost(&mut echo).await; let server_port = 12346; let server_proxy = quilkin::cli::Proxy { port: server_port, @@ -94,7 +95,7 @@ on_write: DECOMPRESS assert_eq!( "helloxyzabc", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() diff --git a/tests/filters.rs b/tests/filters.rs index d83e359e18..a57f317a92 100644 --- a/tests/filters.rs +++ b/tests/filters.rs @@ -125,6 +125,7 @@ async fn debug_filter() { // create an echo server as an endpoint. 
let echo = t.run_echo_server(&AddressType::Random).await; + tracing::trace!(%echo, "running echo server"); // create server configuration let server_port = 12247; let server_config = std::sync::Arc::new(quilkin::Config::default()); diff --git a/tests/firewall.rs b/tests/firewall.rs index bdc8a19cb9..10ac11f882 100644 --- a/tests/firewall.rs +++ b/tests/firewall.rs @@ -51,7 +51,7 @@ on_write: assert_eq!( "hello", - timeout(Duration::from_secs(5), rx.recv()) + timeout(Duration::from_millis(500), rx.recv()) .await .expect("should have received a packet") .unwrap() @@ -81,7 +81,7 @@ on_write: assert_eq!( "hello", - timeout(Duration::from_secs(5), rx.recv()) + timeout(Duration::from_millis(500), rx.recv()) .await .expect("should have received a packet") .unwrap() @@ -109,7 +109,7 @@ on_write: "; let mut rx = test(&mut t, port, yaml, &address_type).await; - let result = timeout(Duration::from_secs(3), rx.recv()).await; + let result = timeout(Duration::from_millis(500), rx.recv()).await; assert!(result.is_err(), "should not have received a packet"); } @@ -134,7 +134,7 @@ on_write: "; let mut rx = test(&mut t, port, yaml, &address_type).await; - let result = timeout(Duration::from_secs(3), rx.recv()).await; + let result = timeout(Duration::from_millis(500), rx.recv()).await; assert!(result.is_err(), "should not have received a packet"); } @@ -159,7 +159,7 @@ on_write: "; let mut rx = test(&mut t, port, yaml, &address_type).await; - let result = timeout(Duration::from_secs(3), rx.recv()).await; + let result = timeout(Duration::from_millis(500), rx.recv()).await; assert!(result.is_err(), "should not have received a packet"); } @@ -184,7 +184,7 @@ on_write: "; let mut rx = test(&mut t, port, yaml, &address_type).await; - let result = timeout(Duration::from_secs(3), rx.recv()).await; + let result = timeout(Duration::from_millis(500), rx.recv()).await; assert!(result.is_err(), "should not have received a packet"); } diff --git a/tests/local_rate_limit.rs 
b/tests/local_rate_limit.rs index fb2cf0f8e8..55068c5ba3 100644 --- a/tests/local_rate_limit.rs +++ b/tests/local_rate_limit.rs @@ -35,7 +35,8 @@ period: 1 "; let echo = t.run_echo_server(&AddressType::Random).await; - let server_addr = available_addr(&AddressType::Random).await; + let mut server_addr = available_addr(&AddressType::Random).await; + quilkin::test_utils::map_addr_to_localhost(&mut server_addr); let server_proxy = quilkin::cli::Proxy { port: server_addr.port(), ..<_>::default() @@ -53,19 +54,21 @@ period: 1 .map(std::sync::Arc::new) .unwrap(), ); + tracing::trace!("spawning server"); t.run_server(server_config, server_proxy, None); let msg = "hello"; let (mut rx, socket) = t.open_socket_and_recv_multiple_packets().await; for _ in 0..3 { + tracing::trace!(%server_addr, %msg, "sending"); socket.send_to(msg.as_bytes(), &server_addr).await.unwrap(); } for _ in 0..2 { assert_eq!( msg, - timeout(Duration::from_secs(5), rx.recv()) + timeout(Duration::from_millis(500), rx.recv()) .await .unwrap() .unwrap() @@ -75,5 +78,7 @@ period: 1 // Allow enough time to have received any response. tokio::time::sleep(Duration::from_millis(100)).await; // Check that we do not get any response. 
- assert!(timeout(Duration::from_secs(1), rx.recv()).await.is_err()); + assert!(timeout(Duration::from_millis(500), rx.recv()) + .await + .is_err()); } diff --git a/tests/match.rs b/tests/match.rs index 6462b9693c..265920c654 100644 --- a/tests/match.rs +++ b/tests/match.rs @@ -95,7 +95,7 @@ on_read: assert_eq!( "helloxyz", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() @@ -107,7 +107,7 @@ on_read: assert_eq!( "helloabc", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() @@ -119,7 +119,7 @@ on_read: assert_eq!( "hellodef", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() @@ -131,7 +131,7 @@ on_read: assert_eq!( "hellodef", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() diff --git a/tests/metrics.rs b/tests/metrics.rs index a04b3c6a69..708b44f921 100644 --- a/tests/metrics.rs +++ b/tests/metrics.rs @@ -32,7 +32,8 @@ async fn metrics_server() { .port(); // create server configuration - let server_addr = quilkin::test_utils::available_addr(&AddressType::Random).await; + let mut server_addr = quilkin::test_utils::available_addr(&AddressType::Random).await; + quilkin::test_utils::map_addr_to_localhost(&mut server_addr); let server_proxy = quilkin::cli::Proxy { port: server_addr.port(), ..<_>::default() @@ -44,7 +45,7 @@ async fn metrics_server() { t.run_server( server_config, server_proxy, - Some(Some((std::net::Ipv6Addr::UNSPECIFIED, metrics_port).into())), + Some(Some((std::net::Ipv4Addr::UNSPECIFIED, metrics_port).into())), ); // create a local client diff --git a/tests/no_filter.rs b/tests/no_filter.rs 
index 022eefbf41..a034d7a5bf 100644 --- a/tests/no_filter.rs +++ b/tests/no_filter.rs @@ -53,19 +53,19 @@ async fn echo() { let (mut recv_chan, socket) = t.open_socket_and_recv_multiple_packets().await; socket.send_to(b"hello", &local_addr).await.unwrap(); - let value = timeout(Duration::from_secs(5), recv_chan.recv()) + let value = timeout(Duration::from_millis(500), recv_chan.recv()) .await .unwrap() .unwrap(); assert_eq!("hello", value); - let value = timeout(Duration::from_secs(5), recv_chan.recv()) + let value = timeout(Duration::from_millis(500), recv_chan.recv()) .await .unwrap() .unwrap(); assert_eq!("hello", value); // should only be two returned items - assert!(timeout(Duration::from_secs(2), recv_chan.recv()) + assert!(timeout(Duration::from_millis(500), recv_chan.recv()) .await .is_err()); } diff --git a/tests/token_router.rs b/tests/token_router.rs index 2cb834202a..92d03a2502 100644 --- a/tests/token_router.rs +++ b/tests/token_router.rs @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{Ipv6Addr, SocketAddr}; use tokio::time::{timeout, Duration}; @@ -31,7 +31,8 @@ use quilkin::{ #[tokio::test] async fn token_router() { let mut t = TestHelper::default(); - let echo = t.run_echo_server(&AddressType::Random).await; + let mut echo = t.run_echo_server(&AddressType::Ipv6).await; + quilkin::test_utils::map_to_localhost(&mut echo).await; let capture_yaml = " suffix: @@ -82,13 +83,13 @@ quilkin.dev: // valid packet let (mut recv_chan, socket) = t.open_socket_and_recv_multiple_packets().await; - let local_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, server_port)); + let local_addr = SocketAddr::from((Ipv6Addr::LOCALHOST, server_port)); let msg = b"helloabc"; socket.send_to(msg, &local_addr).await.unwrap(); assert_eq!( "hello", - timeout(Duration::from_secs(5), recv_chan.recv()) + timeout(Duration::from_millis(500), recv_chan.recv()) .await .expect("should have received a packet") .unwrap() @@ -98,6 +99,6 @@ quilkin.dev: let msg = b"helloxyz"; socket.send_to(msg, &local_addr).await.unwrap(); - let result = timeout(Duration::from_secs(3), recv_chan.recv()).await; + let result = timeout(Duration::from_millis(500), recv_chan.recv()).await; assert!(result.is_err(), "should not have received a packet"); }