diff --git a/core/src/tpu.rs b/core/src/tpu.rs index 9872c1e4e16875..7f0be6319eac0d 100644 --- a/core/src/tpu.rs +++ b/core/src/tpu.rs @@ -44,7 +44,7 @@ use { pub const DEFAULT_TPU_COALESCE_MS: u64 = 5; // allow multiple connections for NAT and any open/close overlap -pub const MAX_QUIC_CONNECTIONS_PER_IP: usize = 8; +pub const MAX_QUIC_CONNECTIONS_PER_PEER: usize = 8; pub struct TpuSockets { pub transactions: Vec, @@ -161,7 +161,7 @@ impl Tpu { cluster_info.my_contact_info().tpu.ip(), packet_sender, exit.clone(), - MAX_QUIC_CONNECTIONS_PER_IP, + MAX_QUIC_CONNECTIONS_PER_PEER, staked_nodes.clone(), MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS, @@ -177,7 +177,7 @@ impl Tpu { cluster_info.my_contact_info().tpu_forwards.ip(), forwarded_packet_sender, exit.clone(), - MAX_QUIC_CONNECTIONS_PER_IP, + MAX_QUIC_CONNECTIONS_PER_PEER, staked_nodes, MAX_STAKED_CONNECTIONS.saturating_add(MAX_UNSTAKED_CONNECTIONS), 0, // Prevent unstaked nodes from forwarding transactions diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index 42bc7f9ad7f27c..302c9ba3973162 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -16,6 +16,7 @@ use { solana_perf::packet::PacketBatch, solana_sdk::{ packet::{Packet, PACKET_DATA_SIZE}, + pubkey::Pubkey, quic::{QUIC_CONNECTION_HANDSHAKE_TIMEOUT_MS, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS}, signature::Keypair, timing, @@ -44,7 +45,7 @@ pub fn spawn_server( gossip_host: IpAddr, packet_sender: Sender, exit: Arc, - max_connections_per_ip: usize, + max_connections_per_peer: usize, staked_nodes: Arc>, max_staked_connections: usize, max_unstaked_connections: usize, @@ -61,7 +62,7 @@ pub fn spawn_server( incoming, packet_sender, exit, - max_connections_per_ip, + max_connections_per_peer, staked_nodes, max_staked_connections, max_unstaked_connections, @@ -74,7 +75,7 @@ pub async fn run_server( mut incoming: Incoming, packet_sender: Sender, exit: Arc, - max_connections_per_ip: usize, + max_connections_per_peer: usize, staked_nodes: Arc>, max_staked_connections: usize, max_unstaked_connections: usize, @@ -107,7 +108,7 @@ pub async fn run_server( unstaked_connection_table.clone(), staked_connection_table.clone(), packet_sender.clone(), - max_connections_per_ip, + max_connections_per_peer, staked_nodes.clone(), max_staked_connections, max_unstaked_connections, @@ -133,7 +134,10 @@ fn prune_unstaked_connection_table( } } -fn get_connection_stake(connection: &Connection, staked_nodes: Arc>) -> u64 { +fn get_connection_stake( + connection: &Connection, + staked_nodes: Arc>, +) -> Option<(Pubkey, u64)> { connection .peer_identity() .and_then(|der_cert_any| der_cert_any.downcast::>().ok()) @@ -142,10 +146,12 @@ fn get_connection_stake(connection: &Connection, staked_nodes: Arc>, staked_connection_table: Arc>, packet_sender: Sender, - max_connections_per_ip: usize, + max_connections_per_peer: usize, staked_nodes: Arc>, max_staked_connections: usize, max_unstaked_connections: usize, @@ -175,10 +181,13 @@ async fn setup_connection( } = new_connection; let remote_addr = connection.remote_address(); + let mut remote_pubkey = None; let table_and_stake = { - let stake = get_connection_stake(&connection, staked_nodes.clone()); + let (some_pubkey, stake) = get_connection_stake(&connection, staked_nodes.clone()) + .map_or((None, 0), |(pubkey, stake)| (Some(pubkey), stake)); if stake > 0 { + remote_pubkey = some_pubkey; let mut connection_table_l = staked_connection_table.lock().unwrap(); if connection_table_l.total_size >= max_staked_connections { let num_pruned = connection_table_l.prune_random(stake); @@ -238,13 +247,13 @@ async fn setup_connection( if let Ok(max_uni_streams) = max_uni_streams { connection.set_max_concurrent_uni_streams(max_uni_streams); - if let Some((last_update, stream_exit)) = connection_table_l.try_add_connection( - &remote_addr, + ConnectionTableKey::new(remote_addr.ip(), remote_pubkey), + remote_addr.port(), Some(connection), stake, timing::timestamp(), - max_connections_per_ip, + max_connections_per_peer, ) { drop(connection_table_l); let stats = stats.clone(); @@ -266,6 +275,7 @@ async fn setup_connection( uni_streams, packet_sender, remote_addr, + remote_pubkey, last_update, connection_table, stream_exit, @@ -300,6 +310,7 @@ async fn handle_connection( mut uni_streams: IncomingUniStreams, packet_sender: Sender, remote_addr: SocketAddr, + remote_pubkey: Option, last_update: Arc, connection_table: Arc>, stream_exit: Arc, @@ -369,11 +380,11 @@ async fn handle_connection( } } } - if connection_table - .lock() - .unwrap() - .remove_connection(&remote_addr) - { + + if connection_table.lock().unwrap().remove_connection( + ConnectionTableKey::new(remote_addr.ip(), remote_pubkey), + remote_addr.port(), + ) { stats.connection_removed.fetch_add(1, Ordering::Relaxed); } else { stats @@ -510,9 +521,23 @@ enum ConnectionPeerType { Staked, } +#[derive(Copy, Clone, Eq, Hash, PartialEq)] +enum ConnectionTableKey { + IP(IpAddr), + Pubkey(Pubkey), +} + +impl ConnectionTableKey { + fn new(ip: IpAddr, maybe_pubkey: Option) -> Self { + maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| { + ConnectionTableKey::Pubkey(pubkey) + }) + } +} + // Map of IP to list of connection entries struct ConnectionTable { - table: IndexMap>, + table: IndexMap>, total_size: usize, peer_type: ConnectionPeerType, } @@ -532,23 +557,23 @@ impl ConnectionTable { let mut num_pruned = 0; while self.total_size > max_size { let mut oldest = std::u64::MAX; - let mut oldest_ip = None; - for (ip, connections) in self.table.iter() { + let mut oldest_index = None; + for (index, (_key, connections)) in self.table.iter().enumerate() { for entry in connections { let last_update = entry.last_update(); if last_update < oldest { oldest = last_update; - oldest_ip = Some(*ip); + oldest_index = Some(index); } } } - if let Some(oldest_ip) = oldest_ip { - if let Some(removed) = self.table.remove(&oldest_ip) { + if let Some(oldest_index) = oldest_index { + if let Some((_, removed)) = self.table.swap_remove_index(oldest_index) { self.total_size -= removed.len(); num_pruned += removed.len(); } } else { - // No valid entries in the table with an IP address. Continuing the loop will cause + // No valid entries in the table. Continuing the loop will cause // infinite looping. break; } @@ -594,17 +619,18 @@ impl ConnectionTable { fn try_add_connection( &mut self, - addr: &SocketAddr, + key: ConnectionTableKey, + port: u16, connection: Option, stake: u64, last_update: u64, - max_connections_per_ip: usize, + max_connections_per_peer: usize, ) -> Option<(Arc, Arc)> { - let connection_entry = self.table.entry(addr.ip()).or_insert_with(Vec::new); + let connection_entry = self.table.entry(key).or_insert_with(Vec::new); let has_connection_capacity = connection_entry .len() .checked_add(1) - .map(|c| c <= max_connections_per_ip) + .map(|c| c <= max_connections_per_peer) .unwrap_or(false); if has_connection_capacity { let exit = Arc::new(AtomicBool::new(false)); @@ -613,7 +639,7 @@ impl ConnectionTable { exit.clone(), stake, last_update.clone(), - addr.port(), + port, connection, )); self.total_size += 1; @@ -623,11 +649,11 @@ impl ConnectionTable { } } - fn remove_connection(&mut self, addr: &SocketAddr) -> bool { - if let Entry::Occupied(mut e) = self.table.entry(addr.ip()) { + fn remove_connection(&mut self, key: ConnectionTableKey, port: u16) -> bool { + if let Entry::Occupied(mut e) = self.table.entry(key) { let e_ref = e.get_mut(); let old_size = e_ref.len(); - e_ref.retain(|connection| connection.port != addr.port()); + e_ref.retain(|connection| connection.port != port); let new_size = e_ref.len(); if e_ref.is_empty() { e.remove_entry(); @@ -1092,24 +1118,38 @@ pub mod test { } #[test] - fn test_prune_table() { + fn test_prune_table_with_ip() { use std::net::Ipv4Addr; solana_logger::setup(); let mut table = ConnectionTable::new(ConnectionPeerType::Staked); let mut num_entries = 5; - let max_connections_per_ip = 10; + let max_connections_per_peer = 10; let sockets: Vec<_> = (0..num_entries) .into_iter() .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0)) .collect(); for (i, socket) in sockets.iter().enumerate() { table - .try_add_connection(socket, None, 0, i as u64, max_connections_per_ip) + .try_add_connection( + ConnectionTableKey::IP(socket.ip()), + socket.port(), + None, + 0, + i as u64, + max_connections_per_peer, + ) .unwrap(); } num_entries += 1; table - .try_add_connection(&sockets[0], None, 0, 5, max_connections_per_ip) + .try_add_connection( + ConnectionTableKey::IP(sockets[0].ip()), + sockets[0].port(), + None, + 0, + 5, + max_connections_per_peer, + ) .unwrap(); let new_size = 3; @@ -1123,18 +1163,115 @@ pub mod test { assert_eq!(table.table.len(), new_size); assert_eq!(table.total_size, new_size); for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) { - table.remove_connection(socket); + table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port()); } assert_eq!(table.total_size, 0); } + #[test] + fn test_prune_table_with_unique_pubkeys() { + solana_logger::setup(); + let mut table = ConnectionTable::new(ConnectionPeerType::Staked); + + // We should be able to add more entries than max_connections_per_peer, since each entry is + // from a different peer pubkey. + let num_entries = 15; + let max_connections_per_peer = 10; + + let pubkeys: Vec<_> = (0..num_entries) + .into_iter() + .map(|_| Pubkey::new_unique()) + .collect(); + for (i, pubkey) in pubkeys.iter().enumerate() { + table + .try_add_connection( + ConnectionTableKey::Pubkey(*pubkey), + 0, + None, + 0, + i as u64, + max_connections_per_peer, + ) + .unwrap(); + } + + let new_size = 3; + let pruned = table.prune_oldest(new_size); + assert_eq!(pruned, num_entries as usize - new_size); + assert_eq!(table.table.len(), new_size); + assert_eq!(table.total_size, new_size); + for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) { + table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0); + } + assert_eq!(table.total_size, 0); + } + + #[test] + fn test_prune_table_with_non_unique_pubkeys() { + solana_logger::setup(); + let mut table = ConnectionTable::new(ConnectionPeerType::Staked); + + let max_connections_per_peer = 10; + let pubkey = Pubkey::new_unique(); + (0..max_connections_per_peer).for_each(|i| { + table + .try_add_connection( + ConnectionTableKey::Pubkey(pubkey), + 0, + None, + 0, + i as u64, + max_connections_per_peer, + ) + .unwrap(); + }); + + // We should NOT be able to add more entries than max_connections_per_peer, since we are + // using the same peer pubkey. + assert!(table + .try_add_connection( + ConnectionTableKey::Pubkey(pubkey), + 0, + None, + 0, + 10, + max_connections_per_peer, + ) + .is_none()); + + // We should be able to add an entry from another peer pubkey + let num_entries = max_connections_per_peer + 1; + let pubkey2 = Pubkey::new_unique(); + assert!(table + .try_add_connection( + ConnectionTableKey::Pubkey(pubkey2), + 0, + None, + 0, + 10, + max_connections_per_peer, + ) + .is_some()); + + assert_eq!(table.total_size, num_entries as usize); + + let new_max_size = 3; + let pruned = table.prune_oldest(new_max_size); + assert!(pruned >= num_entries as usize - new_max_size); + assert!(table.table.len() <= new_max_size); + assert!(table.total_size <= new_max_size); + + table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0); + assert_eq!(table.total_size, 0); + } + #[test] fn test_prune_table_random() { use std::net::Ipv4Addr; solana_logger::setup(); let mut table = ConnectionTable::new(ConnectionPeerType::Staked); let num_entries = 5; - let max_connections_per_ip = 10; + let max_connections_per_peer = 10; let sockets: Vec<_> = (0..num_entries) .into_iter() .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0)) @@ -1142,11 +1279,12 @@ pub mod test { for (i, socket) in sockets.iter().enumerate() { table .try_add_connection( - socket, + ConnectionTableKey::IP(socket.ip()), + socket.port(), None, (i + 1) as u64, i as u64, - max_connections_per_ip, + max_connections_per_peer, ) .unwrap(); } @@ -1168,18 +1306,32 @@ pub mod test { solana_logger::setup(); let mut table = ConnectionTable::new(ConnectionPeerType::Staked); let num_ips = 5; - let max_connections_per_ip = 10; + let max_connections_per_peer = 10; let mut sockets: Vec<_> = (0..num_ips) .into_iter() .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0)) .collect(); for (i, socket) in sockets.iter().enumerate() { table - .try_add_connection(socket, None, 0, (i * 2) as u64, max_connections_per_ip) + .try_add_connection( + ConnectionTableKey::IP(socket.ip()), + socket.port(), + None, + 0, + (i * 2) as u64, + max_connections_per_peer, + ) .unwrap(); table - .try_add_connection(socket, None, 0, (i * 2 + 1) as u64, max_connections_per_ip) + .try_add_connection( + ConnectionTableKey::IP(socket.ip()), + socket.port(), + None, + 0, + (i * 2 + 1) as u64, + max_connections_per_peer, + ) .unwrap(); } @@ -1187,11 +1339,12 @@ pub mod test { SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0); table .try_add_connection( - &single_connection_addr, + ConnectionTableKey::IP(single_connection_addr.ip()), + single_connection_addr.port(), None, 0, (num_ips * 2) as u64, - max_connections_per_ip, + max_connections_per_peer, ) .unwrap(); @@ -1202,7 +1355,7 @@ pub mod test { sockets.push(zero_connection_addr); for socket in sockets.iter() { - table.remove_connection(socket); + table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port()); } assert_eq!(table.total_size, 0); } diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index 2aef26c8e1b26f..5c6a64cd89a6f4 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -277,7 +277,7 @@ pub fn spawn_server( gossip_host: IpAddr, packet_sender: Sender, exit: Arc, - max_connections_per_ip: usize, + max_connections_per_peer: usize, staked_nodes: Arc>, max_staked_connections: usize, max_unstaked_connections: usize, @@ -292,7 +292,7 @@ pub fn spawn_server( gossip_host, packet_sender, exit, - max_connections_per_ip, + max_connections_per_peer, staked_nodes, max_staked_connections, max_unstaked_connections,