Skip to content

Commit

Permalink
Remove loop from elapsed slots calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasSte committed Aug 1, 2024
1 parent fa9205e commit 05c11b3
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 7 deletions.
148 changes: 147 additions & 1 deletion sdk/program/src/epoch_schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,78 @@ impl EpochSchedule {
}
}

/// Returns the number of elapsed slots between the start epoch and end epoch.
pub fn calculate_elapsed_slots(&self, start_epoch: Epoch, end_epoch: Epoch) -> u64 {
// This original code for this calculation was the following:
//
// fn calculate(start_epoch: Epoch, end_epoch: Epoch, schedule: &EpochSchedule) -> u64 {
// (start_epoch..=end_epoch)
// .map(|epoch| {
// schedule.get_slots_in_epoch(epoch.saturating_add(1))
// }).sum()
// }
//
// It can be very slow if the difference between start epoch and end epoch is too big.
// We can derive a mathematical expression to perform the calculation without a loop.
// Let S be the start epoch, E the end epoch, N the self.first_normal_epoch, C
// MINIMUM_SLOTS_PER_EPOCH.trailing_zeros(), and O the number of slots per epoch,
// then we want to know:
// elapsed_slots = 2^(S+1+C) + 2^(S+2+C) + ... + 2^(N-1+C) + O + O + ... + O
// Let's divide the work:
// before_first_normal = 2^(S+1+C) + 2^(S+2+C) + ... + 2^(N-1+C)
// after_first_normal = O + O + ... + O
// so that elapsed_slots = before_first_normal+after_first_normal
//
// before_first_normal is a geometric progression
// (https://en.wikipedia.org/wiki/Geometric_progression), whose sum is a well known value.
// before_first_normal = 2^C * (2^(S+1) + 2^(S+2) + ... + 2^(N-1))
// before_first_normal = 2^C * (2^(S+1) * (2^(N-1-S-1+1) - 1)/(2-1))
// before_first_normal = 2^C * (2^N - 2^(S+1))
//
// after_first_normal is simply a sum of terms, so we can do:
// after_first_normal = O * ((E + 1) - N + 1)
// after_first_normal = O * (E + 2 - N)
//
// [1] Note that if end_epoch is less than self.first_normal_epoch, after_first_normal is zero,
// and end_epoch+1 would assume the value of N-1 in before_first_normal.
//
// [2] Likewise, if start_epoch is greater than self.first_normal_epoch, before_first_normal is
// zero, and start_epoch+1 would assume the value of N in after_first_normal.

let n = if end_epoch.saturating_add(1) < self.first_normal_epoch {
// As in [1], E+1 should be N-1 here, so N = E + 2
end_epoch.saturating_add(2)
} else {
// N is first_normal_epoch when end_epoch+1 is greater than first_normal_slot
self.first_normal_epoch
};

// This is 2^(N)
let two_power_of_n = 2u64.saturating_pow(n as u32);
// This is 2^(S+1)
let two_power_of_s_1 = 2u64.saturating_pow(start_epoch.saturating_add(1) as u32);

// This is 2^N - 2^(S+1)
let two_powers_sub = two_power_of_n.saturating_sub(two_power_of_s_1);
// As C is log2(MINIMUM_SLOTS_PER_EPOCH), 2^C equals MINIMUM_SLOTS_PER_EPOCH
// This is 2^(C) * (2^N - 2^(S+1))
let before_first_normal = two_powers_sub.saturating_mul(MINIMUM_SLOTS_PER_EPOCH);

let n = if self.first_normal_epoch < start_epoch.saturating_add(1) {
// As in [2] (see my explanation), S+1 should be N here, so N = S + 1
start_epoch.saturating_add(1)
} else {
// N equals first_normal_epoch when the latter is less that start_epoch +1
self.first_normal_epoch
};

// This is (E + 1) - N + 1 => E + 2 - N
let e_plus_two_minus_n = end_epoch.saturating_add(2).saturating_sub(n);
let after_first_normal = e_plus_two_minus_n.saturating_mul(self.slots_per_epoch);

before_first_normal.saturating_add(after_first_normal)
}

/// get the epoch for which the given slot should save off
/// information about stakers
pub fn get_leader_schedule_epoch(&self, slot: Slot) -> Epoch {
Expand Down Expand Up @@ -187,7 +259,10 @@ impl EpochSchedule {

#[cfg(test)]
mod tests {
use super::*;
use {
super::*,
rand::distributions::{Distribution, Uniform},
};

#[test]
fn test_epoch_schedule() {
Expand Down Expand Up @@ -260,4 +335,75 @@ mod tests {
let cloned_epoch_schedule = epoch_schedule.clone();
assert_eq!(cloned_epoch_schedule, epoch_schedule);
}

fn check_elapsed_epochs(start: Epoch, end: Epoch, schedule: &EpochSchedule) -> u64 {
(start..=end)
.map(|epoch| schedule.get_slots_in_epoch(epoch.saturating_add(1)))
.sum()
}

#[test]
fn test_calculate_elapsed_slots_sanity() {
let epoch_schedule = EpochSchedule {
slots_per_epoch: 5,
leader_schedule_slot_offset: 2,
warmup: true,
first_normal_epoch: 10,
first_normal_slot: 5,
};

let cases = vec![
(0, 5),
(1, 8),
(1, 9),
(1, 10),
(10, 15),
(12, 20),
(2, 30),
(1, 1),
(10, 10),
(20, 20),
(0, 0),
(50, 20),
(20, 0),
(20, 10),
(20, 5),
(10, 5),
(8, 3),
];

for item in &cases {
assert_eq!(
check_elapsed_epochs(item.0, item.1, &epoch_schedule),
epoch_schedule.calculate_elapsed_slots(item.0, item.1)
);
}
}

#[test]
#[cfg(not(target_os = "solana"))]
fn test_calculate_elapsed_slots_fuzzy() {
let mut rng = rand::thread_rng();
let slots_per_epoch_dist = Uniform::from(1..=20);
let first_normal_epoch_dist = Uniform::from(1..=60);

let epoch_schedule = EpochSchedule {
slots_per_epoch: slots_per_epoch_dist.sample(&mut rng),
leader_schedule_slot_offset: 2,
warmup: true,
first_normal_epoch: first_normal_epoch_dist.sample(&mut rng),
first_normal_slot: 5,
};

let start_epoch_dist = Uniform::from(0..=125);
for _ in 0..5000 {
let start_epoch = start_epoch_dist.sample(&mut rng);
let end_epoch_dist = Uniform::from(start_epoch..=125);
let end_epoch = end_epoch_dist.sample(&mut rng);
assert_eq!(
check_elapsed_epochs(start_epoch, end_epoch, &epoch_schedule),
epoch_schedule.calculate_elapsed_slots(start_epoch, end_epoch)
);
}
}
}
9 changes: 3 additions & 6 deletions sdk/src/rent_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,9 @@ impl RentCollector {
if self.rent.is_exempt(lamports, data_len) {
RentDue::Exempt
} else {
let slots_elapsed: u64 = (account_rent_epoch..=self.epoch)
.map(|epoch| {
self.epoch_schedule
.get_slots_in_epoch(epoch.saturating_add(1))
})
.sum();
let slots_elapsed = self
.epoch_schedule
.calculate_elapsed_slots(account_rent_epoch, self.epoch);

// avoid infinite rent in rust 1.45
let years_elapsed = if self.slots_per_year != 0.0 {
Expand Down

0 comments on commit 05c11b3

Please sign in to comment.