From 05c11b30a7020e65910ed886496ff17b0dbf1ee9 Mon Sep 17 00:00:00 2001 From: Lucas Date: Thu, 1 Aug 2024 14:26:30 -0300 Subject: [PATCH] Remove loop from elapsed slots calculation --- sdk/program/src/epoch_schedule.rs | 148 +++++++++++++++++++++++++++++- sdk/src/rent_collector.rs | 9 +- 2 files changed, 150 insertions(+), 7 deletions(-) diff --git a/sdk/program/src/epoch_schedule.rs b/sdk/program/src/epoch_schedule.rs index 2e1398d86b9050..fd4ca5dea21fd7 100644 --- a/sdk/program/src/epoch_schedule.rs +++ b/sdk/program/src/epoch_schedule.rs @@ -110,6 +110,78 @@ impl EpochSchedule { } } + /// Returns the number of elapsed slots between the start epoch and end epoch. + pub fn calculate_elapsed_slots(&self, start_epoch: Epoch, end_epoch: Epoch) -> u64 { + // This original code for this calculation was the following: + // + // fn calculate(start_epoch: Epoch, end_epoch: Epoch, schedule: &EpochSchedule) -> u64 { + // (start_epoch..=end_epoch) + // .map(|epoch| { + // schedule.get_slots_in_epoch(epoch.saturating_add(1)) + // }).sum() + // } + // + // It can be very slow if the difference between start epoch and end epoch is too big. + // We can derive a mathematical expression to perform the calculation without a loop. + // Let S be the start epoch, E the end epoch, N the self.first_normal_epoch, C + // MINIMUM_SLOTS_PER_EPOCH.trailing_zeros(), and O the number of slots per epoch, + // then we want to know: + // elapsed_slots = 2^(S+1+C) + 2^(S+2+C) + ... + 2^(N-1+C) + O + O + ... + O + // Let's divide the work: + // before_first_normal = 2^(S+1+C) + 2^(S+2+C) + ... + 2^(N-1+C) + // after_first_normal = O + O + ... + O + // so that elapsed_slots = before_first_normal+after_first_normal + // + // before_first_normal is a geometric progression + // (https://en.wikipedia.org/wiki/Geometric_progression), whose sum is a well known value. + // before_first_normal = 2^C * (2^(S+1) + 2^(S+2) + ... + 2^(N-1)) + // before_first_normal = 2^C * (2^(S+1) * (2^(N-1-S-1+1) - 1)/(2-1)) + // before_first_normal = 2^C * (2^N - 2^(S+1)) + // + // after_first_normal is simply a sum of terms, so we can do: + // after_first_normal = O * ((E + 1) - N + 1) + // after_first_normal = O * (E + 2 - N) + // + // [1] Note that if end_epoch is less than self.first_normal_epoch, after_first_normal is zero, + // and end_epoch+1 would assume the value of N-1 in before_first_normal. + // + // [2] Likewise, if start_epoch is greater than self.first_normal_epoch, before_first_normal is + // zero, and start_epoch+1 would assume the value of N in after_first_normal. + + let n = if end_epoch.saturating_add(1) < self.first_normal_epoch { + // As in [1], E+1 should be N-1 here, so N = E + 2 + end_epoch.saturating_add(2) + } else { + // N is first_normal_epoch when end_epoch+1 is greater than first_normal_slot + self.first_normal_epoch + }; + + // This is 2^(N) + let two_power_of_n = 2u64.saturating_pow(n as u32); + // This is 2^(S+1) + let two_power_of_s_1 = 2u64.saturating_pow(start_epoch.saturating_add(1) as u32); + + // This is 2^N - 2^(S+1) + let two_powers_sub = two_power_of_n.saturating_sub(two_power_of_s_1); + // As C is log2(MINIMUM_SLOTS_PER_EPOCH), 2^C equals MINIMUM_SLOTS_PER_EPOCH + // This is 2^(C) * (2^N - 2^(S+1)) + let before_first_normal = two_powers_sub.saturating_mul(MINIMUM_SLOTS_PER_EPOCH); + + let n = if self.first_normal_epoch < start_epoch.saturating_add(1) { + // As in [2] (see my explanation), S+1 should be N here, so N = S + 1 + start_epoch.saturating_add(1) + } else { + // N equals first_normal_epoch when the latter is less that start_epoch +1 + self.first_normal_epoch + }; + + // This is (E + 1) - N + 1 => E + 2 - N + let e_plus_two_minus_n = end_epoch.saturating_add(2).saturating_sub(n); + let after_first_normal = e_plus_two_minus_n.saturating_mul(self.slots_per_epoch); + + before_first_normal.saturating_add(after_first_normal) + } + /// get the epoch for which the given slot should save off /// information about stakers pub fn get_leader_schedule_epoch(&self, slot: Slot) -> Epoch { @@ -187,7 +259,10 @@ impl EpochSchedule { #[cfg(test)] mod tests { - use super::*; + use { + super::*, + rand::distributions::{Distribution, Uniform}, + }; #[test] fn test_epoch_schedule() { @@ -260,4 +335,75 @@ mod tests { let cloned_epoch_schedule = epoch_schedule.clone(); assert_eq!(cloned_epoch_schedule, epoch_schedule); } + + fn check_elapsed_epochs(start: Epoch, end: Epoch, schedule: &EpochSchedule) -> u64 { + (start..=end) + .map(|epoch| schedule.get_slots_in_epoch(epoch.saturating_add(1))) + .sum() + } + + #[test] + fn test_calculate_elapsed_slots_sanity() { + let epoch_schedule = EpochSchedule { + slots_per_epoch: 5, + leader_schedule_slot_offset: 2, + warmup: true, + first_normal_epoch: 10, + first_normal_slot: 5, + }; + + let cases = vec![ + (0, 5), + (1, 8), + (1, 9), + (1, 10), + (10, 15), + (12, 20), + (2, 30), + (1, 1), + (10, 10), + (20, 20), + (0, 0), + (50, 20), + (20, 0), + (20, 10), + (20, 5), + (10, 5), + (8, 3), + ]; + + for item in &cases { + assert_eq!( + check_elapsed_epochs(item.0, item.1, &epoch_schedule), + epoch_schedule.calculate_elapsed_slots(item.0, item.1) + ); + } + } + + #[test] + #[cfg(not(target_os = "solana"))] + fn test_calculate_elapsed_slots_fuzzy() { + let mut rng = rand::thread_rng(); + let slots_per_epoch_dist = Uniform::from(1..=20); + let first_normal_epoch_dist = Uniform::from(1..=60); + + let epoch_schedule = EpochSchedule { + slots_per_epoch: slots_per_epoch_dist.sample(&mut rng), + leader_schedule_slot_offset: 2, + warmup: true, + first_normal_epoch: first_normal_epoch_dist.sample(&mut rng), + first_normal_slot: 5, + }; + + let start_epoch_dist = Uniform::from(0..=125); + for _ in 0..5000 { + let start_epoch = start_epoch_dist.sample(&mut rng); + let end_epoch_dist = Uniform::from(start_epoch..=125); + let end_epoch = end_epoch_dist.sample(&mut rng); + assert_eq!( + check_elapsed_epochs(start_epoch, end_epoch, &epoch_schedule), + epoch_schedule.calculate_elapsed_slots(start_epoch, end_epoch) + ); + } + } } diff --git a/sdk/src/rent_collector.rs b/sdk/src/rent_collector.rs index 96a7e479db76a6..1bb0fec03401c0 100644 --- a/sdk/src/rent_collector.rs +++ b/sdk/src/rent_collector.rs @@ -90,12 +90,9 @@ impl RentCollector { if self.rent.is_exempt(lamports, data_len) { RentDue::Exempt } else { - let slots_elapsed: u64 = (account_rent_epoch..=self.epoch) - .map(|epoch| { - self.epoch_schedule - .get_slots_in_epoch(epoch.saturating_add(1)) - }) - .sum(); + let slots_elapsed = self + .epoch_schedule + .calculate_elapsed_slots(account_rent_epoch, self.epoch); // avoid infinite rent in rust 1.45 let years_elapsed = if self.slots_per_year != 0.0 {