diff --git a/runtime/src/epoch_stakes.rs b/runtime/src/epoch_stakes.rs index 015daabe7f86c3..4841b2713c34e7 100644 --- a/runtime/src/epoch_stakes.rs +++ b/runtime/src/epoch_stakes.rs @@ -57,6 +57,12 @@ impl EpochStakes { &self.node_id_to_vote_accounts } + pub fn node_id_to_stake(&self, node_id: &Pubkey) -> Option { + self.node_id_to_vote_accounts + .get(node_id) + .map(|x| x.total_stake) + } + pub fn epoch_authorized_voters(&self) -> &Arc { &self.epoch_authorized_voters } @@ -218,9 +224,10 @@ pub(crate) mod tests { use { super::*, crate::{stake_account::StakeAccount, stakes::StakesCache}, + im::HashMap as ImHashMap, solana_sdk::{account::AccountSharedData, rent::Rent}, solana_stake_program::stake_state::{self, Delegation}, - solana_vote::vote_account::VoteAccount, + solana_vote::vote_account::{VoteAccount, VoteAccounts}, solana_vote_program::vote_state::{self, create_account_with_authorized}, std::iter, }; @@ -231,12 +238,12 @@ pub(crate) mod tests { authorized_voter: Pubkey, } - #[test] - fn test_parse_epoch_vote_accounts() { - let stake_per_account = 100; - let num_vote_accounts_per_node = 2; + fn new_vote_accounts( + num_nodes: usize, + num_vote_accounts_per_node: usize, + ) -> HashMap> { // Create some vote accounts for each pubkey - let vote_accounts_map: HashMap> = (0..10) + (0..num_nodes) .map(|_| { let node_id = solana_sdk::pubkey::new_rand(); ( @@ -259,7 +266,32 @@ pub(crate) mod tests { .collect(), ) }) - .collect(); + .collect() + } + + fn new_epoch_vote_accounts( + vote_accounts_map: &HashMap>, + node_id_to_stake_fn: impl Fn(&Pubkey) -> u64, + ) -> VoteAccountsHashMap { + // Create and process the vote accounts + vote_accounts_map + .iter() + .flat_map(|(node_id, vote_accounts)| { + vote_accounts.iter().map(|v| { + let vote_account = VoteAccount::try_from(v.account.clone()).unwrap(); + (v.vote_account, (node_id_to_stake_fn(node_id), vote_account)) + }) + }) + .collect() + } + + #[test] + fn test_parse_epoch_vote_accounts() { + let stake_per_account = 100; + let num_vote_accounts_per_node = 2; + let num_nodes = 10; + + let vote_accounts_map = new_vote_accounts(num_nodes, num_vote_accounts_per_node); let expected_authorized_voters: HashMap<_, _> = vote_accounts_map .iter() @@ -286,16 +318,8 @@ pub(crate) mod tests { }) .collect(); - // Create and process the vote accounts - let epoch_vote_accounts: HashMap<_, _> = vote_accounts_map - .iter() - .flat_map(|(_, vote_accounts)| { - vote_accounts.iter().map(|v| { - let vote_account = VoteAccount::try_from(v.account.clone()).unwrap(); - (v.vote_account, (stake_per_account, vote_account)) - }) - }) - .collect(); + let epoch_vote_accounts = + new_epoch_vote_accounts(&vote_accounts_map, |_| stake_per_account); let (total_stake, mut node_id_to_vote_accounts, epoch_authorized_voters) = EpochStakes::parse_epoch_vote_accounts(&epoch_vote_accounts, 0); @@ -319,7 +343,7 @@ pub(crate) mod tests { ); assert_eq!( total_stake, - vote_accounts_map.len() as u64 * num_vote_accounts_per_node as u64 * 100 + num_nodes as u64 * num_vote_accounts_per_node as u64 * 100 ); } @@ -485,4 +509,36 @@ pub(crate) mod tests { assert!(versioned.contains_key(&epoch2)); assert!(versioned.contains_key(&epoch3)); } + + #[test] + fn test_node_id_to_stake() { + let num_nodes = 10; + let num_vote_accounts_per_node = 2; + + let vote_accounts_map = new_vote_accounts(num_nodes, num_vote_accounts_per_node); + let node_id_to_stake_map = vote_accounts_map + .keys() + .enumerate() + .map(|(index, node_id)| (*node_id, ((index + 1) * 100) as u64)) + .collect::>(); + let epoch_vote_accounts = new_epoch_vote_accounts(&vote_accounts_map, |node_id| { + *node_id_to_stake_map.get(node_id).unwrap() + }); + let epoch_stakes = EpochStakes::new( + Arc::new(StakesEnum::Accounts(Stakes::new_for_tests( + 0, + VoteAccounts::from(Arc::new(epoch_vote_accounts)), + ImHashMap::default(), + ))), + 0, + ); + + assert_eq!(epoch_stakes.total_stake(), 11000); + for (node_id, stake) in node_id_to_stake_map.iter() { + assert_eq!( + epoch_stakes.node_id_to_stake(node_id), + Some(*stake * num_vote_accounts_per_node as u64) + ); + } + } } diff --git a/runtime/src/stakes.rs b/runtime/src/stakes.rs index 4f2dedf1facb07..0e4d7b6109ef41 100644 --- a/runtime/src/stakes.rs +++ b/runtime/src/stakes.rs @@ -316,6 +316,21 @@ impl Stakes { }) } + #[cfg(test)] + pub fn new_for_tests( + epoch: Epoch, + vote_accounts: VoteAccounts, + stake_delegations: ImHashMap, + ) -> Self { + Self { + vote_accounts, + stake_delegations, + unused: 0, + epoch, + stake_history: StakeHistory::default(), + } + } + pub(crate) fn history(&self) -> &StakeHistory { &self.stake_history } diff --git a/wen-restart/src/heaviest_fork_aggregate.rs b/wen-restart/src/heaviest_fork_aggregate.rs index 0b43b800d18573..dac13bd8274568 100644 --- a/wen-restart/src/heaviest_fork_aggregate.rs +++ b/wen-restart/src/heaviest_fork_aggregate.rs @@ -44,7 +44,7 @@ impl HeaviestForkAggregate { let mut block_stake_map = HashMap::new(); block_stake_map.insert( (my_heaviest_fork_slot, my_heaviest_fork_hash), - Self::validator_stake(epoch_stakes, my_pubkey), + epoch_stakes.node_id_to_stake(my_pubkey).unwrap_or(0), ); Self { supermajority_threshold: wait_for_supermajority_threshold_percent as f64 / 100.0, @@ -58,15 +58,6 @@ impl HeaviestForkAggregate { } } - // TODO(wen): this will a function in separate EpochStakesMap class later. - fn validator_stake(epoch_stakes: &EpochStakes, pubkey: &Pubkey) -> u64 { - epoch_stakes - .node_id_to_vote_accounts() - .get(pubkey) - .map(|x| x.total_stake) - .unwrap_or_default() - } - pub(crate) fn aggregate_from_record( &mut self, key_string: &str, @@ -110,7 +101,7 @@ impl HeaviestForkAggregate { ) -> Option { let total_stake = self.epoch_stakes.total_stake(); let from = &received_heaviest_fork.from; - let sender_stake = Self::validator_stake(&self.epoch_stakes, from); + let sender_stake = self.epoch_stakes.node_id_to_stake(from).unwrap_or(0); if sender_stake == 0 { warn!( "Gossip should not accept zero-stake RestartLastVotedFork from {:?}", @@ -183,7 +174,7 @@ impl HeaviestForkAggregate { // TODO(wen): use better epoch stake and add a test later. pub(crate) fn total_active_stake(&self) -> u64 { self.active_peers.iter().fold(0, |sum: u64, pubkey| { - sum.saturating_add(Self::validator_stake(&self.epoch_stakes, pubkey)) + sum.saturating_add(self.epoch_stakes.node_id_to_stake(pubkey).unwrap_or(0)) }) } @@ -191,7 +182,7 @@ impl HeaviestForkAggregate { self.active_peers_seen_supermajority .iter() .fold(0, |sum: u64, pubkey| { - sum.saturating_add(Self::validator_stake(&self.epoch_stakes, pubkey)) + sum.saturating_add(self.epoch_stakes.node_id_to_stake(pubkey).unwrap_or(0)) }) }