From b8320dc4540239c6aae0a825ed1a17dfb6b6a765 Mon Sep 17 00:00:00 2001 From: Erin Power Date: Tue, 10 Oct 2023 20:55:04 +0200 Subject: [PATCH] Refactor sessions to use socket pool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors how we handle upstream connections to the gameservers. When profiling quilkin I noticed that there was a lot of time (~10–15%) being spent dropping the upstream socket through its Arc implementation that happened whenever a session was dropped. As I was thinking about how to solve this problem I also realised there was a second issue, which is that there is a limitation on how many connections Quilkin can hold at once, roughly ~16,383. Because after that we're likely to start encountering port exhaustion from the operating system, since each session is a unique socket. This brought me to the solution in this commit, which is that while we need to give each connection to the gameserver a unique port, we don't need to give a unique port across gameservers. So I refactored how we create sessions to use what I've called a "SessionPool". This pools the sockets for sessions into a map that is keyed by their destination. With this implementation this means that we now have a limit of ~16,000 connections per gameserver, which is far more than any gameserver could reasonably need. --- src/cli/proxy.rs | 45 ++- src/filters/chain.rs | 2 - src/filters/compress.rs | 4 - src/filters/debug.rs | 8 +- src/filters/firewall.rs | 16 +- src/filters/match.rs | 1 - src/filters/registry.rs | 2 +- src/filters/write.rs | 15 +- src/proxy.rs | 72 +--- src/proxy/sessions.rs | 753 ++++++++++++++++++++++++++-------------- src/test_utils.rs | 2 +- src/ttl_map.rs | 34 ++ src/utils/net.rs | 1 + 13 files changed, 588 insertions(+), 367 deletions(-) diff --git a/src/cli/proxy.rs b/src/cli/proxy.rs index 4dfe702e8f..212656feac 100644 --- a/src/cli/proxy.rs +++ b/src/cli/proxy.rs @@ -26,7 +26,7 @@ use std::{ use tonic::transport::Endpoint; use super::Admin; -use crate::{proxy::SessionMap, xds::ResourceType, Config, Result}; +use crate::{proxy::SessionPool, xds::ResourceType, Config, Result}; #[cfg(doc)] use crate::filters::FilterFactory; @@ -81,9 +81,6 @@ impl Proxy { mode: Admin, mut shutdown_rx: tokio::sync::watch::Receiver<()>, ) -> crate::Result<()> { - const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); - const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); - let _mmdb_task = self.mmdb.clone().map(|source| { tokio::spawn(async move { use crate::config::BACKOFF_INITIAL_DELAY_MILLISECONDS; @@ -122,8 +119,12 @@ impl Proxy { let id = config.id.load(); tracing::info!(port = self.port, proxy_id = &*id, "Starting"); - let sessions = SessionMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL); let runtime_config = mode.unwrap_proxy(); + let sessions = SessionPool::new( + config.clone(), + DualStackLocalSocket::new(self.port)?, + shutdown_rx.clone(), + ); let _xds_stream = if !self.management_server.is_empty() { { @@ -161,10 +162,10 @@ impl Proxy { .await .map_err(|error| eyre::eyre!(error))?; - tracing::info!(sessions=%sessions.len(), "waiting for active sessions to expire"); - while sessions.is_not_empty() { + tracing::info!(sessions=%sessions.sessions().len(), "waiting for active sessions to expire"); + while sessions.sessions().is_not_empty() { tokio::time::sleep(Duration::from_secs(1)).await; - tracing::debug!(sessions=%sessions.len(), "sessions still active"); + tracing::debug!(sessions=%sessions.sessions().len(), "sessions still active"); } tracing::info!("all sessions expired"); @@ -176,7 +177,7 @@ impl Proxy { /// This function also spawns the set of worker tasks responsible for consuming packets /// off the aforementioned queue and processing them through the filter chain and session /// pipeline. - fn run_recv_from(&self, config: &Arc, sessions: SessionMap) -> Result<()> { + fn run_recv_from(&self, config: &Arc, sessions: Arc) -> Result<()> { // The number of worker tasks to spawn. Each task gets a dedicated queue to // consume packets off. let num_workers = num_cpus::get(); @@ -366,8 +367,17 @@ mod tests { crate::proxy::DownstreamReceiveWorkerConfig { worker_id: 1, socket: socket.clone(), - config, - sessions: <_>::default(), + config: config.clone(), + sessions: SessionPool::new( + config, + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + tokio::sync::watch::channel(()).1, + ), } .spawn(); @@ -405,7 +415,18 @@ mod tests { ) }); - proxy.run_recv_from(&config, <_>::default()).unwrap(); + let sessions = SessionPool::new( + config.clone(), + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + tokio::sync::watch::channel(()).1, + ); + + proxy.run_recv_from(&config, sessions).unwrap(); let socket = create_socket().await; socket.send_to(msg.as_bytes(), &local_addr).await.unwrap(); diff --git a/src/filters/chain.rs b/src/filters/chain.rs index 919f7843e7..552cdc09a9 100644 --- a/src/filters/chain.rs +++ b/src/filters/chain.rs @@ -369,7 +369,6 @@ mod tests { ); let mut context = WriteContext::new( - endpoints_fixture[0].clone(), endpoints_fixture[0].address.clone(), "127.0.0.1:70".parse().unwrap(), b"hello".to_vec(), @@ -417,7 +416,6 @@ mod tests { ); let mut context = WriteContext::new( - endpoints_fixture[0].clone(), endpoints_fixture[0].address.clone(), "127.0.0.1:70".parse().unwrap(), b"hello".to_vec(), diff --git a/src/filters/compress.rs b/src/filters/compress.rs index fc84e36d0e..a5950c4b9a 100644 --- a/src/filters/compress.rs +++ b/src/filters/compress.rs @@ -196,7 +196,6 @@ mod tests { // write decompress let mut write_context = WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), read_context.contents.clone(), @@ -223,7 +222,6 @@ mod tests { assert!(compression .write(&mut WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), b"hello".to_vec(), @@ -270,7 +268,6 @@ mod tests { assert_eq!(b"hello", &*read_context.contents); let mut write_context = WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), b"hello".to_vec(), @@ -329,7 +326,6 @@ mod tests { let expected = contents_fixture(); // write compress let mut write_context = WriteContext::new( - Endpoint::new("127.0.0.1:80".parse().unwrap()), "127.0.0.1:8080".parse().unwrap(), "127.0.0.1:8081".parse().unwrap(), expected.clone(), diff --git a/src/filters/debug.rs b/src/filters/debug.rs index 7290ecd045..a9dc324f6b 100644 --- a/src/filters/debug.rs +++ b/src/filters/debug.rs @@ -50,8 +50,12 @@ impl Filter for Debug { #[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))] async fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> { - info!(id = ?self.config.id, endpoint = ?ctx.endpoint.address, source = ?&ctx.source, - dest = ?&ctx.dest, contents = ?String::from_utf8_lossy(&ctx.contents), "Write filter event"); + info!( + id = ?self.config.id, + source = ?&ctx.source, + dest = ?&ctx.dest, + contents = ?String::from_utf8_lossy(&ctx.contents), "Write filter event" + ); Ok(()) } } diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index 996803aa33..7698e49792 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -166,23 +166,13 @@ mod tests { }], }; - let endpoint = Endpoint::new((Ipv4Addr::LOCALHOST, 80).into()); let local_addr: crate::endpoint::EndpointAddress = (Ipv4Addr::LOCALHOST, 8081).into(); - let mut ctx = WriteContext::new( - endpoint.clone(), - ([192, 168, 75, 20], 80).into(), - local_addr.clone(), - vec![], - ); + let mut ctx = + WriteContext::new(([192, 168, 75, 20], 80).into(), local_addr.clone(), vec![]); assert!(firewall.write(&mut ctx).await.is_ok()); - let mut ctx = WriteContext::new( - endpoint, - ([192, 168, 77, 20], 80).into(), - local_addr, - vec![], - ); + let mut ctx = WriteContext::new(([192, 168, 77, 20], 80).into(), local_addr, vec![]); assert!(firewall.write(&mut ctx).await.is_err()); } } diff --git a/src/filters/match.rs b/src/filters/match.rs index 9522b4a9ec..af4810c56b 100644 --- a/src/filters/match.rs +++ b/src/filters/match.rs @@ -188,7 +188,6 @@ mod tests { // no config, so should make no change. filter .write(&mut WriteContext::new( - endpoint.clone(), endpoint.address, "127.0.0.1:70".parse().unwrap(), contents.clone(), diff --git a/src/filters/registry.rs b/src/filters/registry.rs index 42ef796135..f2fb7cf98a 100644 --- a/src/filters/registry.rs +++ b/src/filters/registry.rs @@ -113,7 +113,7 @@ mod tests { .await .is_ok()); assert!(filter - .write(&mut WriteContext::new(endpoint, addr.clone(), addr, vec![],)) + .write(&mut WriteContext::new(addr.clone(), addr, vec![],)) .await .is_ok()); } diff --git a/src/filters/write.rs b/src/filters/write.rs index 744d8f9df3..48d06800a7 100644 --- a/src/filters/write.rs +++ b/src/filters/write.rs @@ -16,10 +16,7 @@ use std::collections::HashMap; -use crate::{ - endpoint::{Endpoint, EndpointAddress}, - metadata::DynamicMetadata, -}; +use crate::{endpoint::EndpointAddress, metadata::DynamicMetadata}; #[cfg(doc)] use crate::filters::Filter; @@ -27,8 +24,6 @@ use crate::filters::Filter; /// The input arguments to [`Filter::write`]. #[non_exhaustive] pub struct WriteContext { - /// The upstream endpoint that we're expecting packets from. - pub endpoint: Endpoint, /// The source of the received packet. pub source: EndpointAddress, /// The destination of the received packet. @@ -41,14 +36,8 @@ pub struct WriteContext { impl WriteContext { /// Creates a new [`WriteContext`] - pub fn new( - endpoint: Endpoint, - source: EndpointAddress, - dest: EndpointAddress, - contents: Vec, - ) -> Self { + pub fn new(source: EndpointAddress, dest: EndpointAddress, contents: Vec) -> Self { Self { - endpoint, source, dest, contents, diff --git a/src/proxy.rs b/src/proxy.rs index f243a63d19..25895721bb 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -16,10 +16,9 @@ use std::{net::SocketAddr, sync::Arc}; -pub use sessions::{Session, SessionKey, SessionMap}; +pub use sessions::{Session, SessionKey, SessionPool}; use crate::{ - endpoint::{Endpoint, EndpointAddress}, filters::{Filter, ReadContext}, utils::net::DualStackLocalSocket, Config, @@ -44,7 +43,7 @@ pub(crate) struct DownstreamReceiveWorkerConfig { /// Socket with reused port from which the worker receives packets. pub socket: Arc, pub config: Arc, - pub sessions: SessionMap, + pub sessions: Arc, } impl DownstreamReceiveWorkerConfig { @@ -88,7 +87,7 @@ impl DownstreamReceiveWorkerConfig { } last_received_at = Some(packet.received_at); - Self::spawn_process_task(packet, source, worker_id, &socket, &config, &sessions) + Self::spawn_process_task(packet, source, worker_id, &config, &sessions) } Err(error) => { tracing::error!(%error, "error receiving packet"); @@ -106,9 +105,8 @@ impl DownstreamReceiveWorkerConfig { packet: DownstreamPacket, source: std::net::SocketAddr, worker_id: usize, - socket: &Arc, config: &Arc, - sessions: &SessionMap, + sessions: &Arc, ) { tracing::trace!( id = worker_id, @@ -121,16 +119,13 @@ impl DownstreamReceiveWorkerConfig { tokio::spawn({ let config = config.clone(); let sessions = sessions.clone(); - let socket = socket.clone(); async move { let timer = crate::metrics::processing_time(crate::metrics::READ).start_timer(); let asn_info = packet.asn_info.clone(); let asn_info = asn_info.as_ref(); - match Self::process_downstream_received_packet(packet, config, socket, sessions) - .await - { + match Self::process_downstream_received_packet(packet, config, sessions).await { Ok(size) => { crate::metrics::packets_total(crate::metrics::READ, asn_info).inc(); crate::metrics::bytes_total(crate::metrics::READ, asn_info) @@ -157,8 +152,7 @@ impl DownstreamReceiveWorkerConfig { async fn process_downstream_received_packet( packet: DownstreamPacket, config: Arc, - downstream_socket: Arc, - sessions: SessionMap, + sessions: Arc, ) -> Result { let endpoints: Vec<_> = config.clusters.read().endpoints().collect(); if endpoints.is_empty() { @@ -171,56 +165,18 @@ impl DownstreamReceiveWorkerConfig { let mut bytes_written = 0; for endpoint in context.endpoints.iter() { - bytes_written += Self::session_send_packet( - &context.contents, - &context.source, - endpoint, - &downstream_socket, - &config, - &sessions, - packet.asn_info.clone(), - ) - .await?; + let session_key = SessionKey { + source: packet.source, + dest: endpoint.address.to_socket_addr().await?, + }; + + bytes_written += sessions + .send(session_key, packet.asn_info.clone(), &context.contents) + .await?; } Ok(bytes_written) } - - /// Send a packet received from `recv_addr` to an endpoint. - #[tracing::instrument(level="trace", skip_all, fields(source = %recv_addr, dest = %endpoint.address))] - async fn session_send_packet( - packet: &[u8], - recv_addr: &EndpointAddress, - endpoint: &Endpoint, - downstream_socket: &Arc, - config: &Arc, - sessions: &SessionMap, - asn_info: Option, - ) -> Result { - let session_key = SessionKey { - source: recv_addr.clone(), - dest: endpoint.address.clone(), - }; - - let send_future = match sessions.get(&session_key) { - Some(entry) => entry.send(packet), - None => { - let session = Session::new( - config.clone(), - session_key.source.clone(), - downstream_socket.clone(), - endpoint.clone(), - asn_info, - )?; - - let future = session.send(packet); - sessions.insert(session_key, session); - future - } - }; - - send_future.await - } } #[derive(thiserror::Error, Debug)] diff --git a/src/proxy/sessions.rs b/src/proxy/sessions.rs index bfc9b6d869..a4ed52f867 100644 --- a/src/proxy/sessions.rs +++ b/src/proxy/sessions.rs @@ -14,168 +14,127 @@ * limitations under the License. */ -use std::{net::SocketAddr, sync::Arc}; - -use tokio::{ - net::UdpSocket, - select, - sync::{watch, OnceCell}, - time::Instant, +use std::{ + collections::{HashMap, HashSet}, + net::SocketAddr, + pin::Pin, + sync::Arc, + time::Duration, }; +use tokio::{net::UdpSocket, sync::watch, time::Instant}; + use crate::{ - endpoint::{Endpoint, EndpointAddress}, - filters::{Filter, WriteContext}, + config::Config, + filters::Filter, maxmind_db::IpNetEntry, utils::{net::DualStackLocalSocket, Loggable}, }; +use parking_lot::RwLock; + +use dashmap::DashMap; + +const _RM: u16 = 16383; + pub(crate) mod metrics; pub type SessionMap = crate::ttl_map::TtlMap; -/// Session encapsulates a UDP stream session -pub struct Session { - config: Arc, - /// created_at is time at which the session was created - created_at: Instant, - /// socket that sends and receives from and to the endpoint address - upstream_socket: Arc>>, - /// dest is where to send data to - dest: Endpoint, - /// address of original sender - source: EndpointAddress, - /// a channel to broadcast on if we are shutting down this Session - shutdown_tx: watch::Sender<()>, - /// The ASN information. - asn_info: Option, +type SessionRef<'pool> = + dashmap::mapref::one::Ref<'pool, SessionKey, crate::ttl_map::Value>; + +/// A data structure that is responsible for holding sessions, and pooling +/// sockets between them. This means that we only provide new unique sockets +/// to new connections to the same gameserver, and we share sockets across +/// multiple gameservers. +/// +/// Traffic from different gameservers is then demuxed using their address to +/// send back to the original client. +#[derive(Debug)] +pub struct SessionPool { + ports_to_sockets: DashMap>, + storage: RwLock, + session_map: SessionMap, + downstream_socket: Arc, + shutdown_rx: watch::Receiver<()>, + config: Arc, } -// A (source, destination) address pair that uniquely identifies a session. -#[derive(Clone, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] -pub struct SessionKey { - pub source: EndpointAddress, - pub dest: EndpointAddress, +/// The wrapper struct responsible for holding all of the socket related mappings. +#[derive(Default, Debug)] +struct SocketStorage { + destination_to_sockets: HashMap>, + destination_to_sources: HashMap<(SocketAddr, u16), SocketAddr>, + sources_to_asn_info: HashMap, + sockets_to_destination: HashMap>, } -impl From<(EndpointAddress, EndpointAddress)> for SessionKey { - fn from((source, dest): (EndpointAddress, EndpointAddress)) -> Self { - SessionKey { source, dest } +impl SessionPool { + /// Constructs a new session pool, it's created with an `Arc` as that's + /// required for the pool to provide a reference to the children to be able + /// to release their sockets back to the parent. + pub fn new( + config: Arc, + downstream_socket: DualStackLocalSocket, + shutdown_rx: watch::Receiver<()>, + ) -> Arc { + const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60); + const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); + + Arc::new(Self { + config, + downstream_socket: Arc::new(downstream_socket), + shutdown_rx, + ports_to_sockets: <_>::default(), + storage: <_>::default(), + session_map: SessionMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL), + }) } -} -/// ReceivedPacketContext contains state needed to process a received packet. -struct ReceivedPacketContext<'a> { - packet: &'a [u8], - config: Arc, - endpoint: &'a Endpoint, - source: EndpointAddress, - dest: EndpointAddress, -} - -impl Session { - /// internal constructor for a Session from SessionArgs - #[tracing::instrument(skip_all)] - pub fn new( - config: Arc, - source: EndpointAddress, - downstream_socket: Arc, - dest: Endpoint, + /// Allocates a new upstream socket from a new socket from the system. + async fn create_new_session_from_new_socket<'pool>( + self: &'pool Arc, + key @ SessionKey { dest, .. }: SessionKey, asn_info: Option, - ) -> Result { - let (shutdown_tx, shutdown_rx) = watch::channel::<()>(()); - - let s = Session { - config: config.clone(), - upstream_socket: Arc::new(OnceCell::new()), - source: source.clone(), - dest, - created_at: Instant::now(), - shutdown_tx, - asn_info, + ) -> Result, super::PipelineError> { + let addr: std::net::IpAddr = match dest { + SocketAddr::V4(_) => std::net::Ipv4Addr::UNSPECIFIED.into(), + SocketAddr::V6(_) => std::net::Ipv6Addr::UNSPECIFIED.into(), }; - tracing::debug!(source = %s.source, dest = ?s.dest, "Session created"); - - if let Some(asn) = &s.asn_info { - tracing::debug!( - number = asn.r#as, - organization = asn.as_name, - country_code = asn.as_cc, - prefix = asn.prefix, - prefix_entity = asn.prefix_entity, - prefix_name = asn.prefix_name, - "maxmind information" - ); - } - - self::metrics::total_sessions().inc(); - s.active_session_metric().inc(); - s.run(downstream_socket, shutdown_rx); - Ok(s) - } - - fn upstream_socket( - &self, - ) -> impl std::future::Future, super::PipelineError>> { - let upstream_socket = self.upstream_socket.clone(); - let address = self.dest.address.clone(); - - async move { - let connect_addr = address.to_socket_addr().await?; - let bind_addr: SocketAddr = match connect_addr { - SocketAddr::V4(_) => (std::net::Ipv4Addr::UNSPECIFIED, 0).into(), - SocketAddr::V6(_) => (std::net::Ipv6Addr::UNSPECIFIED, 0).into(), - }; - - upstream_socket - .get_or_try_init(|| async { - let upstream_socket = UdpSocket::bind(bind_addr).await?; - upstream_socket.connect(connect_addr).await?; - Ok(Arc::new(upstream_socket)) - }) - .await - .cloned() - } - } - - /// run starts processing receiving upstream udp packets - /// and sending them back downstream - fn run( - &self, - downstream_socket: Arc, - mut shutdown_rx: watch::Receiver<()>, - ) { - let source = self.source.clone(); - let config = self.config.clone(); - let endpoint = self.dest.clone(); - let upstream_socket = self.upstream_socket(); - let asn_info = self.asn_info.clone(); + let socket = UdpSocket::bind((addr, 0u16)).await.map(Arc::new)?; + let port = socket.local_addr().unwrap().port(); + self.ports_to_sockets.insert(port, socket.clone()); + let upstream_socket = socket.clone(); + let pool = self.clone(); tokio::spawn(async move { let mut buf: Vec = vec![0; 65535]; let mut last_received_at = None; - let upstream_socket = match upstream_socket.await { - Ok(socket) => socket, - Err(error) => { - tracing::error!(%error, "upstream socket failed to initialise"); - return; - } - }; + let mut shutdown_rx = pool.shutdown_rx.clone(); loop { - tracing::debug!(source = %source, dest = ?endpoint, "Awaiting incoming packet"); - let asn_info = asn_info.as_ref(); - - select! { + tokio::select! { received = upstream_socket.recv_from(&mut buf) => { match received { Err(error) => { - crate::metrics::errors_total(crate::metrics::WRITE, &error.to_string(), asn_info).inc(); - tracing::error!(%error, %source, dest = ?endpoint, "Error receiving packet"); + crate::metrics::errors_total(crate::metrics::WRITE, &error.to_string(), None).inc(); }, Ok((size, recv_addr)) => { let received_at = chrono::Utc::now().timestamp_nanos_opt().unwrap(); + let (downstream_addr, asn_info) = { + let storage = pool.storage.read(); + let Some(downstream_addr) = storage.destination_to_sources.get(&(recv_addr, port)) else { + tracing::warn!(address=%recv_addr, "received traffic from a server that has no downstream"); + continue; + }; + let asn_info = storage.sources_to_asn_info.get(downstream_addr); + + (*downstream_addr, asn_info.cloned()) + }; + + let asn_info = asn_info.as_ref(); if let Some(last_received_at) = last_received_at { crate::metrics::packet_jitter(crate::metrics::WRITE, asn_info).set(received_at - last_received_at); } @@ -185,15 +144,13 @@ impl Session { crate::metrics::bytes_total(crate::metrics::WRITE, asn_info).inc_by(size as u64); let timer = crate::metrics::processing_time(crate::metrics::WRITE).start_timer(); - let result = Session::process_recv_packet( - &downstream_socket, - ReceivedPacketContext { - config: config.clone(), - packet: &buf[..size], - endpoint: &endpoint, - source: recv_addr.into(), - dest: source.clone(), - }).await; + let result = Self::process_recv_packet( + pool.config.clone(), + &pool.downstream_socket, + recv_addr, + downstream_addr, + &buf[..size], + ).await; timer.stop_and_record(); if let Err(error) = result { error.log(); @@ -209,63 +166,281 @@ impl Session { }; } _ = shutdown_rx.changed() => { - tracing::debug!(%source, dest = ?endpoint, "Closing Session"); + tracing::debug!("Closing upstream socket loop"); return; } }; } }); + + self.create_session_from_existing_socket(key, socket, port, asn_info) } - fn active_session_metric(&self) -> prometheus::IntGauge { - metrics::active_sessions(self.asn_info.as_ref()) + /// Returns a reference to an existing session mapped to `key`, otherwise + /// creates a new session either from a fresh socket, or if there are sockets + /// allocated that are not reserved by an existing destination, using the + /// existing socket. + // This uses dynamic dispatch because we're using `parking_lot`, and we + // to prove that we're not holding a lock across an await point. We're + // using `parking_lot` because there's no async drop, so we can't lock + // on drop currently. + pub fn get<'pool>( + self: &'pool Arc, + key @ SessionKey { dest, .. }: SessionKey, + asn_info: Option, + ) -> Pin< + Box< + dyn std::future::Future, super::PipelineError>> + + Send + + 'pool, + >, + > { + // If we already have a session for the key pairing, return that session. + if let Some(entry) = self.session_map.get(&key) { + return Box::pin(std::future::ready(Ok(entry))); + } + + // If there's a socket_set available, it means there are sockets + // allocated to the address that we want to avoid. + let storage = self.storage.read(); + let Some(socket_set) = storage.destination_to_sockets.get(&dest) else { + drop(storage); + return if self.ports_to_sockets.is_empty() { + // Initial case where we have no allocated or reserved sockets. + Box::pin(self.create_new_session_from_new_socket(key, asn_info)) + } else { + // Where we have no allocated sockets for a destination, assign + // the first available one. + let entry = self.ports_to_sockets.iter().next().unwrap(); + let port = *entry.key(); + + Box::pin(std::future::ready( + self.create_session_from_existing_socket( + key, + entry.value().clone(), + port, + asn_info, + ), + )) + }; + }; + + if let Some(entry) = self + .ports_to_sockets + .iter() + .find(|entry| !socket_set.contains(entry.key())) + { + drop(storage); + self.storage + .write() + .destination_to_sockets + .get_mut(&dest) + .unwrap() + .insert(*entry.key()); + Box::pin(std::future::ready( + self.create_session_from_existing_socket( + key, + entry.value().clone(), + *entry.key(), + asn_info, + ), + )) + } else { + drop(storage); + Box::pin(self.create_new_session_from_new_socket(key, asn_info)) + } + } + + /// Using an existing socket, reserves the socket for a new session. + fn create_session_from_existing_socket<'session>( + self: &'session Arc, + key: SessionKey, + upstream_socket: Arc, + socket_port: u16, + asn_info: Option, + ) -> Result, super::PipelineError> { + let mut storage = self.storage.write(); + storage + .destination_to_sockets + .entry(key.dest) + .or_default() + .insert(socket_port); + storage + .sockets_to_destination + .entry(socket_port) + .or_default() + .insert(key.dest); + storage + .destination_to_sources + .insert((key.dest, socket_port), key.source); + + if let Some(asn_info) = &asn_info { + storage + .sources_to_asn_info + .insert(key.source, asn_info.clone()); + } + + self.session_map.insert( + key, + Session::new(key, upstream_socket, socket_port, self.clone(), asn_info)?, + ); + + Ok(self.session_map.get(&key).unwrap()) } /// process_recv_packet processes a packet that is received by this session. async fn process_recv_packet( + config: Arc, downstream_socket: &Arc, - packet_ctx: ReceivedPacketContext<'_>, + source: SocketAddr, + dest: SocketAddr, + packet: &[u8], ) -> Result { - let ReceivedPacketContext { - packet, - config, - endpoint, - source: from, - dest, - } = packet_ctx; - - tracing::trace!(%from, dest = %endpoint.address, contents = %crate::utils::base64_encode(packet), "received packet from upstream"); - - let mut context = WriteContext::new( - endpoint.clone(), - from.clone(), - dest.clone(), - packet.to_vec(), - ); + tracing::trace!(%source, %dest, contents = %crate::utils::base64_encode(packet), "received packet from upstream"); + + let mut context = + crate::filters::WriteContext::new(source.into(), dest.into(), packet.to_vec()); config.filters.load().write(&mut context).await?; - let addr = dest.to_socket_addr().await.map_err(Error::ToSocketAddr)?; let packet = context.contents.as_ref(); - tracing::trace!(%from, dest = %addr, contents = %crate::utils::base64_encode(packet), "sending packet downstream"); + tracing::trace!(%source, %dest, contents = %crate::utils::base64_encode(packet), "received packet from upstream"); downstream_socket - .send_to(packet, &addr) + .send_to(packet, &source) .await .map_err(Error::SendTo) } - /// Sends a packet to the Session's dest. - pub fn send<'buf>( + /// Returns a map of active sessions. + pub fn sessions(&self) -> &SessionMap { + &self.session_map + } + + /// Sends packet data to the appropiate session based on its `key`. + pub async fn send( + self: &Arc, + key: SessionKey, + asn_info: Option, + packet: &[u8], + ) -> Result { + self.get(key, asn_info) + .await? + .send(packet) + .await + .map_err(From::from) + } + + /// Returns whether the pool contains any sockets allocated to a destination. + #[cfg(test)] + fn has_no_allocated_sockets(&self) -> bool { + let storage = self.storage.read(); + let is_empty = storage.destination_to_sockets.is_empty(); + // These should always be the same. + debug_assert!(!(is_empty ^ storage.sockets_to_destination.is_empty())); + is_empty + } + + /// Forces removal of session to make testing quicker. + #[cfg(test)] + fn drop_session(&self, key: SessionKey, session: SessionRef) -> bool { + drop(session); + self.session_map.remove(key) + } + + /// Handles the logic of releasing a socket back into the pool. + fn release_socket( &self, - buf: &'buf [u8], - ) -> impl std::future::Future> + 'buf { - tracing::trace!( - dest_address = %self.dest.address, - contents = %crate::utils::base64_encode(buf), - "sending packet upstream"); - - let socket = self.upstream_socket(); - async move { socket.await?.send(buf).await.map_err(From::from) } + SessionKey { + ref source, + ref dest, + }: SessionKey, + port: u16, + ) { + let mut storage = self.storage.write(); + let socket_set = storage.destination_to_sockets.get_mut(dest).unwrap(); + + assert!(socket_set.remove(&port)); + + if socket_set.is_empty() { + storage.destination_to_sockets.remove(dest).unwrap(); + } + + let dest_set = storage.sockets_to_destination.get_mut(&port).unwrap(); + + assert!(dest_set.remove(dest)); + + if dest_set.is_empty() { + storage.sockets_to_destination.remove(&port).unwrap(); + } + + // Not asserted because the source might not have GeoIP info. + storage.sources_to_asn_info.remove(source); + assert!(storage + .destination_to_sources + .remove(&(*dest, port)) + .is_some()); + } +} + +/// Session encapsulates a UDP stream session +#[derive(Debug)] +pub struct Session { + /// created_at is time at which the session was created + created_at: Instant, + /// The source and destination pair. + key: SessionKey, + /// The socket port of the session. + socket_port: u16, + /// The socket of the session. + socket: Arc, + /// The GeoIP information of the source. + asn_info: Option, + /// The socket pool of the session. + pool: Arc, +} + +impl Session { + pub fn new( + key: SessionKey, + socket: Arc, + socket_port: u16, + pool: Arc, + asn_info: Option, + ) -> Result { + tracing::debug!(source = %key.source, dest = %key.dest, "Session created"); + + let s = Self { + key, + socket, + pool, + socket_port, + asn_info, + created_at: Instant::now(), + }; + + if let Some(asn) = &s.asn_info { + tracing::debug!( + number = asn.r#as, + organization = asn.as_name, + country_code = asn.as_cc, + prefix = asn.prefix, + prefix_entity = asn.prefix_entity, + prefix_name = asn.prefix_name, + "maxmind information" + ); + } + + self::metrics::total_sessions().inc(); + s.active_session_metric().inc(); + Ok(s) + } + + pub async fn send(&self, packet: &[u8]) -> std::io::Result { + self.socket.send_to(packet, self.key.dest).await + } + + fn active_session_metric(&self) -> prometheus::IntGauge { + metrics::active_sessions(self.asn_info.as_ref()) } } @@ -273,19 +448,26 @@ impl Drop for Session { fn drop(&mut self) { self.active_session_metric().dec(); metrics::duration_secs().observe(self.created_at.elapsed().as_secs() as f64); + tracing::debug!(source = %self.key.source, dest_address = %self.key.dest, "Session closed"); + self.pool.release_socket(self.key, self.socket_port); + } +} - if let Err(error) = self.shutdown_tx.send(()) { - tracing::warn!(%error, "Error sending session shutdown signal"); - } +// A (source, destination) address pair that uniquely identifies a session. +#[derive(Clone, Copy, Eq, Hash, PartialEq, Debug, PartialOrd, Ord)] +pub struct SessionKey { + pub source: SocketAddr, + pub dest: SocketAddr, +} - tracing::debug!(source = %self.source, dest_address = %self.dest.address, "Session closed"); +impl From<(SocketAddr, SocketAddr)> for SessionKey { + fn from((source, dest): (SocketAddr, SocketAddr)) -> Self { + Self { source, dest } } } #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("failed to convert endpoint to socket address: {0}")] - ToSocketAddr(std::io::Error), #[error("failed to send packet downstream: {0}")] SendTo(std::io::Error), #[error("filter {0}")] @@ -295,7 +477,7 @@ pub enum Error { impl Loggable for Error { fn log(&self) { match self { - Self::ToSocketAddr(error) | Self::SendTo(error) => { + Self::SendTo(error) => { tracing::error!(kind=%error.kind(), "{}", self) } Self::Filter(_) => { @@ -307,100 +489,151 @@ impl Loggable for Error { #[cfg(test)] mod tests { - use std::{str::from_utf8, sync::Arc, time::Duration}; + use super::*; + use crate::test_utils::AddressType; + use std::sync::Arc; + + async fn new_pool(config: impl Into>) -> Arc { + SessionPool::new( + Arc::new(config.into().unwrap_or_default()), + DualStackLocalSocket::new( + crate::test_utils::available_addr(&AddressType::Random) + .await + .port(), + ) + .unwrap(), + watch::channel(()).1, + ) + } + + #[tokio::test] + async fn insert_and_release_single_socket() { + let pool = new_pool(None).await; + let key = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); - use tokio::time::timeout; + let session = pool.get(key, None).await.unwrap(); - use crate::{ - endpoint::{Endpoint, EndpointAddress}, - proxy::sessions::ReceivedPacketContext, - test_utils::{create_socket, new_test_config, AddressType, TestHelper}, - }; + assert!(pool.drop_session(key, session)); - use super::*; + assert!(pool.has_no_allocated_sockets()); + } #[tokio::test] - async fn session_send_and_receive() { - let mut t = TestHelper::default(); - let addr = t.run_echo_server(&AddressType::Random).await; - let endpoint = Endpoint::new(addr.clone()); - let socket = Arc::new(create_socket().await); - let msg = "hello"; + async fn insert_and_release_multiple_sockets() { + let pool = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); - let sess = - Session::new(<_>::default(), addr.clone(), socket.clone(), endpoint, None).unwrap(); + let session1 = pool.get(key1, None).await.unwrap(); + let session2 = pool.get(key2, None).await.unwrap(); - sess.upstream_socket() - .await - .unwrap() - .send(msg.as_bytes()) - .await - .unwrap(); + assert!(pool.drop_session(key1, session1)); + assert!(!pool.has_no_allocated_sockets()); + assert!(pool.drop_session(key2, session2)); - let mut buf = vec![0; 1024]; - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .unwrap() - .unwrap(); - let packet = &buf[..size]; - assert_eq!(msg, from_utf8(packet).unwrap()); - assert_eq!(addr.port(), recv_addr.port()); + assert!(pool.has_no_allocated_sockets()); + drop(pool); } #[tokio::test] - async fn process_recv_packet() { - crate::test_utils::load_test_filters(); - - let socket = Arc::new(create_socket().await); - let endpoint = Endpoint::new("127.0.1.1:80".parse().unwrap()); - let dest: EndpointAddress = socket.local_ipv4_addr().unwrap().into(); - - // first test with no filtering - let msg = "hello"; - Session::process_recv_packet( - &socket, - ReceivedPacketContext { - config: <_>::default(), - packet: msg.as_bytes(), - endpoint: &endpoint, - source: endpoint.address.clone(), - dest: dest.clone(), - }, + async fn same_address_uses_different_sockets() { + let pool = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), ) - .await - .unwrap(); + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); - let mut buf = vec![0; 1024]; - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .expect("Should receive a packet") - .unwrap(); - assert_eq!(msg, from_utf8(&buf[..size]).unwrap()); - assert_eq!(dest.port(), recv_addr.port()); - - // add filter - let config = Arc::new(new_test_config()); - Session::process_recv_packet( - &socket, - ReceivedPacketContext { - config, - packet: msg.as_bytes(), - endpoint: &endpoint, - source: endpoint.address.clone(), - dest: dest.clone(), - }, + let socket1 = pool.get(key1, None).await.unwrap(); + let socket2 = pool.get(key2, None).await.unwrap(); + + assert_ne!(socket1.socket_port, socket2.socket_port); + } + + #[tokio::test] + async fn different_addresses_uses_same_socket() { + let pool = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), ) - .await - .unwrap(); + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8081u16).into(), + ) + .into(); - let (size, recv_addr) = timeout(Duration::from_secs(5), socket.recv_from(&mut buf)) - .await - .expect("Should receive a packet") - .unwrap(); - assert_eq!( - format!("{}:our:{}:{}", msg, endpoint.address, dest), - from_utf8(&buf[..size]).unwrap() - ); - assert_eq!(dest.port(), recv_addr.port()); + let socket1 = pool.get(key1, None).await.unwrap(); + let socket2 = pool.get(key2, None).await.unwrap(); + + assert_eq!(socket1.socket_port, socket2.socket_port); + } + + #[tokio::test] + async fn spawn_safe_same_destination() { + let pool = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + + let socket1 = pool.get(key1, None).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = socket1; + }); + + let _socket2 = pool.get(key2, None).await.unwrap(); + + task.await.unwrap(); + } + + #[tokio::test] + async fn spawn_safe_different_destination() { + let pool = new_pool(None).await; + let key1 = ( + (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(), + ) + .into(); + let key2 = ( + (std::net::Ipv4Addr::LOCALHOST, 8081u16).into(), + (std::net::Ipv4Addr::UNSPECIFIED, 8081u16).into(), + ) + .into(); + + let socket1 = pool.get(key1, None).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = socket1; + }); + + let _socket2 = pool.get(key2, None).await.unwrap(); + + task.await.unwrap(); } } diff --git a/src/test_utils.rs b/src/test_utils.rs index a780ce4521..252881d642 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -322,7 +322,6 @@ where let endpoint = "127.0.0.1:90".parse::().unwrap(); let contents = "hello".to_string().into_bytes(); let mut context = WriteContext::new( - endpoint.clone(), endpoint.address, "127.0.0.1:70".parse().unwrap(), contents.clone(), @@ -359,6 +358,7 @@ pub fn ep(id: u8) -> Endpoint { } } +#[track_caller] pub fn new_test_config() -> crate::Config { crate::Config { filters: crate::config::Slot::new( diff --git a/src/ttl_map.rs b/src/ttl_map.rs index 7e8e2e12c8..e8d703ae3d 100644 --- a/src/ttl_map.rs +++ b/src/ttl_map.rs @@ -73,6 +73,15 @@ impl Value { } } +impl std::fmt::Debug for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Value") + .field("value", &self.value) + .field("expires_at", &self.expires_at) + .finish() + } +} + impl std::ops::Deref for Value { type Target = V; @@ -89,6 +98,18 @@ struct Map { shutdown_tx: Option>, } +impl std::fmt::Debug + for Map +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Map") + .field("inner", &self.inner) + .field("ttl", &self.ttl) + .field("shutdown_tx", &self.shutdown_tx) + .finish() + } +} + impl Drop for Map { fn drop(&mut self) { if let Some(shutdown_tx) = self.shutdown_tx.take() { @@ -211,6 +232,11 @@ where .map(|value| value.value) } + /// Removes a key-value pair from the map. + pub fn remove(&self, key: K) -> bool { + self.0.inner.remove(&key).is_some() + } + /// Returns an entry for in-place updates of the specified key-value pair. /// Note: This acquires a write lock on the map's shard that corresponds /// to the entry. @@ -231,6 +257,14 @@ where } } +impl std::fmt::Debug + for TtlMap +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TtlMap").field("inner", &self.0).finish() + } +} + impl Clone for TtlMap { fn clone(&self) -> Self { Self(self.0.clone()) diff --git a/src/utils/net.rs b/src/utils/net.rs index 0789b2a71e..5609359cb2 100644 --- a/src/utils/net.rs +++ b/src/utils/net.rs @@ -60,6 +60,7 @@ fn enable_reuse(sock: &Socket) -> io::Result<()> { /// An ipv6 socket that can accept and send data from either a local ipv4 address or ipv6 address /// with port reuse enabled and only_v6 set to false. +#[derive(Debug)] pub struct DualStackLocalSocket { socket: UdpSocket, }