diff --git a/quinn-proto/src/config.rs b/quinn-proto/src/config.rs index 54e12c60a1..f45c236977 100644 --- a/quinn-proto/src/config.rs +++ b/quinn-proto/src/config.rs @@ -20,8 +20,8 @@ use crate::{ congestion, crypto::{self, HandshakeTokenKey, HmacKey}, shared::ConnectionId, - Duration, RandomConnectionIdGenerator, TokenLog, VarInt, VarIntBoundsExceeded, - DEFAULT_SUPPORTED_VERSIONS, INITIAL_MTU, MAX_CID_SIZE, MAX_UDP_PAYLOAD, + Duration, RandomConnectionIdGenerator, TokenLog, TokenMemoryCache, TokenStore, VarInt, + VarIntBoundsExceeded, DEFAULT_SUPPORTED_VERSIONS, INITIAL_MTU, MAX_CID_SIZE, MAX_UDP_PAYLOAD, }; /// Parameters governing the core QUIC state machine @@ -1066,6 +1066,9 @@ pub struct ClientConfig { /// Cryptographic configuration to use pub(crate) crypto: Arc, + /// Validation token store to use + pub(crate) token_store: Option>, + /// Provider that populates the destination connection ID of Initial Packets pub(crate) initial_dst_cid_provider: Arc ConnectionId + Send + Sync>, @@ -1079,6 +1082,7 @@ impl ClientConfig { Self { transport: Default::default(), crypto, + token_store: Some(Arc::new(TokenMemoryCache::<2>::default())), initial_dst_cid_provider: Arc::new(|| { RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid() }), @@ -1108,6 +1112,19 @@ impl ClientConfig { self } + /// Set a custom [`TokenStore`] + /// + /// Defaults to a [`TokenMemoryCache`] limited to 256 servers and 2 tokens per server. This + /// default is chosen to complement `rustls`'s default [`ClientSessionStore`]. + /// + /// [`ClientSessionStore`]: rustls::client::ClientSessionStore + /// + /// Setting to `None` disables the use of tokens from NEW_TOKEN frames as a client. + pub fn token_store(&mut self, store: Option>) -> &mut Self { + self.token_store = store; + self + } + /// Set the QUIC version to use pub fn version(&mut self, version: u32) -> &mut Self { self.version = version; @@ -1139,6 +1156,7 @@ impl fmt::Debug for ClientConfig { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("ClientConfig") .field("transport", &self.transport) + // token_store not debug // crypto not debug .field("version", &self.version) .finish_non_exhaustive() diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index fdcf40ba7e..8b36695c23 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -31,9 +31,9 @@ use crate::{ }, token::{ResetToken, Token, TokenInner}, transport_parameters::TransportParameters, - Dir, Duration, EndpointConfig, Frame, Instant, Side, StreamId, SystemTime, Transmit, - TransportError, TransportErrorCode, VarInt, INITIAL_MTU, MAX_CID_SIZE, MAX_STREAM_COUNT, - MIN_INITIAL_SIZE, TIMER_GRANULARITY, + Dir, Duration, EndpointConfig, Frame, Instant, Side, StreamId, SystemTime, TokenStore, + Transmit, TransportError, TransportErrorCode, VarInt, INITIAL_MTU, MAX_CID_SIZE, + MAX_STREAM_COUNT, MIN_INITIAL_SIZE, TIMER_GRANULARITY, }; mod ack_frequency; @@ -193,7 +193,7 @@ pub struct Connection { error: Option, /// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are /// discarded. - retry_token: Bytes, + token: Bytes, /// Identifies Data-space packet numbers to skip. Not used in earlier spaces. packet_number_filter: PacketNumberFilter, @@ -226,6 +226,9 @@ pub struct Connection { /// no outgoing application data. app_limited: bool, + token_store: Option>, + server_name: Option, + streams: StreamsState, /// Surplus remote CIDs for future use on new paths rem_cids: CidQueue, @@ -257,6 +260,8 @@ impl Connection { allow_mtud: bool, rng_seed: [u8; 32], path_validated: bool, + token_store: Option>, + server_name: Option, ) -> Self { let side = if server_config.is_some() { Side::Server @@ -273,6 +278,10 @@ impl Connection { client_hello: None, }); let mut rng = StdRng::from_seed(rng_seed); + let token = token_store + .as_ref() + .and_then(|store| store.take(server_name.as_ref().unwrap())) + .unwrap_or_default(); let mut this = Self { endpoint_config, server_config, @@ -323,7 +332,7 @@ impl Connection { timers: TimerTable::default(), authentication_failures: 0, error: None, - retry_token: Bytes::new(), + token, #[cfg(test)] packet_number_filter: match config.deterministic_packet_numbers { false => PacketNumberFilter::new(&mut rng), @@ -345,6 +354,9 @@ impl Connection { receiving_ecn: false, total_authed_packets: 0, + token_store, + server_name, + streams: StreamsState::new( side, config.max_concurrent_uni_streams, @@ -2104,7 +2116,7 @@ impl Connection { trace!("discarding {:?} keys", space_id); if space_id == SpaceId::Initial { // No longer needed - self.retry_token = Bytes::new(); + self.token = Bytes::new(); } let space = &mut self.spaces[space_id]; space.crypto = None; @@ -2423,7 +2435,7 @@ impl Connection { self.streams.retransmit_all_for_0rtt(); let token_len = packet.payload.len() - 16; - self.retry_token = packet.payload.freeze().split_to(token_len); + self.token = packet.payload.freeze().split_to(token_len); self.state = State::Handshake(state::Handshake { expected_token: Bytes::new(), rem_cid_set: false, @@ -2865,7 +2877,9 @@ impl Connection { return Err(TransportError::FRAME_ENCODING_ERROR("empty token")); } trace!("got new token"); - // TODO: Cache, or perhaps forward to user? + if let Some(store) = self.token_store.as_ref() { + store.store(self.server_name.as_ref().unwrap(), token); + } } Frame::Datagram(datagram) => { if self diff --git a/quinn-proto/src/connection/packet_builder.rs b/quinn-proto/src/connection/packet_builder.rs index 868a8c7ca8..71079037eb 100644 --- a/quinn-proto/src/connection/packet_builder.rs +++ b/quinn-proto/src/connection/packet_builder.rs @@ -113,7 +113,7 @@ impl PacketBuilder { SpaceId::Initial => Header::Initial(InitialHeader { src_cid: conn.handshake_cid, dst_cid, - token: conn.retry_token.clone(), + token: conn.token.clone(), number, version, }), diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 3ee1fd4486..ee6b79e653 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -31,7 +31,7 @@ use crate::{ }, token::{TokenDecodeError, TokenInner}, transport_parameters::{PreferredAddress, TransportParameters}, - Duration, Instant, ResetToken, Side, SystemTime, Token, Transmit, TransportConfig, + Duration, Instant, ResetToken, Side, SystemTime, Token, TokenStore, Transmit, TransportConfig, TransportError, INITIAL_MTU, MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, }; @@ -433,6 +433,8 @@ impl Endpoint { None, config.transport, true, + config.token_store, + Some(server_name.into()), ); Ok((ch, conn)) } @@ -687,6 +689,8 @@ impl Endpoint { Some(server_config), transport_config, remote_address_validated, + None, + None, ); self.index.insert_initial(dst_cid, ch); @@ -853,6 +857,8 @@ impl Endpoint { server_config: Option>, transport_config: Arc, path_validated: bool, + token_store: Option>, + server_name: Option, ) -> Connection { let mut rng_seed = [0; 32]; self.rng.fill_bytes(&mut rng_seed); @@ -877,6 +883,8 @@ impl Endpoint { self.allow_mtud, rng_seed, path_validated, + token_store, + server_name, ); let mut cids_issued = 0; diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 9625f64587..c0252aa6ef 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -96,6 +96,9 @@ mod bloom_token_log; #[cfg(feature = "fastbloom")] pub use bloom_token_log::BloomTokenLog; +mod token_store; +pub use token_store::{TokenMemoryCache, TokenStore}; + #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; diff --git a/quinn-proto/src/token_store.rs b/quinn-proto/src/token_store.rs new file mode 100644 index 0000000000..b0a3b5ef81 --- /dev/null +++ b/quinn-proto/src/token_store.rs @@ -0,0 +1,247 @@ +//! Storing tokens sent from servers in NEW_TOKEN frames and using them in subsequent connections + +use bytes::Bytes; +use slab::Slab; +use std::{ + collections::{hash_map, HashMap}, + mem::take, + sync::{Arc, Mutex}, +}; + +/// Responsible for storing address validation tokens received from servers and retrieving them for +/// use in subsequent connections +pub trait TokenStore: Send + Sync { + /// Potentially store a token for later one-time use + /// + /// Called when a NEW_TOKEN frame is received from the server. + fn store(&self, server_name: &str, token: Bytes); + + /// Try to find and take a token that was stored with the given server name + /// + /// The same token must never be returned from `take` twice, as doing so can be used to + /// de-anonymize a client's traffic. + /// + /// Called when trying to connect to a server. It is always ok for this to return `None`. + fn take(&self, server_name: &str) -> Option; +} + +/// `TokenStore` implementation that stores up to `N` tokens per server name for up to a +/// limited number of server names, in-memory +pub struct TokenMemoryCache(Mutex>); + +impl TokenMemoryCache { + /// Construct empty + pub fn new(max_server_names: usize) -> Self { + Self(Mutex::new(State::new(max_server_names))) + } +} + +impl TokenStore for TokenMemoryCache { + fn store(&self, server_name: &str, token: Bytes) { + self.0.lock().unwrap().store(server_name, token) + } + + fn take(&self, server_name: &str) -> Option { + self.0.lock().unwrap().take(server_name) + } +} + +/// Defaults to a maximum of 256 servers +impl Default for TokenMemoryCache { + fn default() -> Self { + Self::new(256) + } +} + +/// Lockable inner state of `TokenMemoryCache` +#[derive(Debug)] +struct State { + max_server_names: usize, + // map from server name to slab index in linked + lookup: HashMap, usize>, + linked: LinkedCache, +} + +impl State { + fn new(max_server_names: usize) -> Self { + assert!(max_server_names > 0, "size limit cannot be 0"); + Self { + max_server_names, + lookup: HashMap::new(), + linked: LinkedCache::default(), + } + } + + fn store(&mut self, server_name: &str, token: Bytes) { + let server_name = Arc::::from(server_name); + let idx = match self.lookup.entry(server_name.clone()) { + hash_map::Entry::Occupied(hmap_entry) => { + // key already exists, add the new token to its token stack + let entry = &mut self.linked.entries[*hmap_entry.get()]; + entry.tokens.push(token); + + // unlink the entry and set it up to be linked as the most recently used + self.linked.unlink(*hmap_entry.get()); + *hmap_entry.get() + } + hash_map::Entry::Vacant(hmap_entry) => { + // key does not yet exist, create a new one, evicting the oldest if necessary + let removed_key = if self.linked.entries.len() >= self.max_server_names { + // unwrap safety: max_server_names is > 0, so there's at least one entry, so + // oldest_newest is some + let oldest = self.linked.oldest_newest.unwrap().0; + self.linked.unlink(oldest); + Some(self.linked.entries.remove(oldest).server_name) + } else { + None + }; + + let cache_entry = CacheEntry::new(server_name, token); + let idx = self.linked.entries.insert(cache_entry); + hmap_entry.insert(idx); + + // for borrowing reasons, we must defer removing the evicted hmap entry + if let Some(removed_key) = removed_key { + let removed = self.lookup.remove(&removed_key); + debug_assert!(removed.is_some()); + } + + idx + } + }; + + // link it as the newest entry + self.linked.link(idx); + } + + fn take(&mut self, server_name: &str) -> Option { + if let hash_map::Entry::Occupied(hmap_entry) = self.lookup.entry(server_name.into()) { + let entry = &mut self.linked.entries[*hmap_entry.get()]; + // pop from entry's token stack + let token = entry.tokens.pop(); + if entry.tokens.len > 1 { + // re-link entry as most recently used + self.linked.unlink(*hmap_entry.get()); + self.linked.link(*hmap_entry.get()); + } else { + // token stack emptied, remove entry + self.linked.unlink(*hmap_entry.get()); + self.linked.entries.remove(*hmap_entry.get()); + hmap_entry.remove(); + } + Some(token) + } else { + None + } + } +} + +/// Cache entry within `LinkedCache` +#[derive(Debug)] +struct CacheEntry { + older: Option, + newer: Option, + server_name: Arc, + tokens: Queue, +} + +impl CacheEntry { + /// Construct with a single token, not linked + fn new(server_name: Arc, token: Bytes) -> Self { + let mut tokens = Queue::new(); + tokens.push(token); + Self { + server_name, + older: None, + newer: None, + tokens, + } + } +} + +/// Slab-based linked LRU cache of `CacheEntry` +#[derive(Debug, Default)] +struct LinkedCache { + entries: Slab>, + oldest_newest: Option<(usize, usize)>, +} + +impl LinkedCache { + /// Unlink an entry's neighbors from it + fn unlink(&mut self, idx: usize) { + // re-link older's newer + if let Some(older) = self.entries[idx].older { + self.entries[older].newer = self.entries[idx].newer; + } else { + // unwrap safety: entries[idx] exists, therefore oldest_newest is some + self.oldest_newest = self.entries[idx] + .newer + .map(|newer| (self.oldest_newest.unwrap().0, newer)); + } + // re-link newer's older + if let Some(newer) = self.entries[idx].newer { + self.entries[newer].older = self.entries[idx].older; + } else { + // unwrap safety: oldest_newest is none iff entries[idx] was the only entry. + // if entries[idx].older is some, entries[idx] was not the only entry + // therefore oldest_newest is some. + self.oldest_newest = self.entries[idx] + .older + .map(|older| (older, self.oldest_newest.unwrap().1)); + } + } + + /// Link an entry as the most recently used entry + /// + /// Assumes any pre-existing neighbors are already unlinked. + fn link(&mut self, idx: usize) { + self.entries[idx].newer = None; + self.entries[idx].older = self.oldest_newest.map(|(_, newest)| newest); + if let Some((_, ref mut newest)) = self.oldest_newest.as_mut() { + *newest = idx; + } else { + self.oldest_newest = Some((idx, idx)); + } + } +} + +/// In-place deque queue of up to `N` `Bytes` +#[derive(Debug)] +struct Queue { + elems: [Bytes; N], + // if len > 0, front is elems[start] + // invariant: start < N + start: usize, + // if len > 0, back is elems[(start + len - 1) % N] + len: usize, +} + +impl Queue { + /// Construct empty + fn new() -> Self { + const EMPTY_BYTES: Bytes = Bytes::new(); + Self { + elems: [EMPTY_BYTES; N], + start: 0, + len: 0, + } + } + + /// Push to back, popping from front first if already at capacity + fn push(&mut self, elem: Bytes) { + self.elems[(self.start + self.len) % N] = elem; + if self.len < N { + self.len += 1; + } else { + self.start += 1; + self.start %= N; + } + } + + /// Pop from front, panicking if empty + fn pop(&mut self) -> Bytes { + const PANIC_MSG: &str = "TokenMemoryCache popped from empty Queue, this is a bug!"; + self.len = self.len.checked_sub(1).expect(PANIC_MSG); + take(&mut self.elems[(self.start + self.len) % N]) + } +} diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 0019b03524..4cb0b2caa2 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -68,8 +68,8 @@ pub use proto::{ ConfigError, ConnectError, ConnectionClose, ConnectionError, ConnectionId, ConnectionIdGenerator, ConnectionStats, Dir, EcnCodepoint, EndpointConfig, FrameStats, FrameType, IdleTimeout, MtuDiscoveryConfig, PathStats, ServerConfig, Side, StreamId, TokenLog, - TokenReuseError, Transmit, TransportConfig, TransportErrorCode, UdpStats, VarInt, - VarIntBoundsExceeded, Written, + TokenMemoryCache, TokenReuseError, TokenStore, Transmit, TransportConfig, TransportErrorCode, + UdpStats, VarInt, VarIntBoundsExceeded, Written, }; #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))] pub use rustls;