diff --git a/Cargo.lock b/Cargo.lock index c3a33f5f2e4467..2c38818d289533 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6875,6 +6875,7 @@ dependencies = [ "solana-secp256k1-recover", "solana-short-vec", "static_assertions", + "test-case", "thiserror", "wasm-bindgen", ] diff --git a/sdk/program/Cargo.toml b/sdk/program/Cargo.toml index 64aa7dff559a0c..3828390bb7c2c9 100644 --- a/sdk/program/Cargo.toml +++ b/sdk/program/Cargo.toml @@ -81,6 +81,7 @@ itertools = { workspace = true } serde_json = { workspace = true } serial_test = { workspace = true } static_assertions = { workspace = true } +test-case = { workspace = true } [build-dependencies] rustc_version = { workspace = true } diff --git a/sdk/program/src/sysvar/slot_hashes.rs b/sdk/program/src/sysvar/slot_hashes.rs index 97a465165314e3..3e027fdd55fda9 100644 --- a/sdk/program/src/sysvar/slot_hashes.rs +++ b/sdk/program/src/sysvar/slot_hashes.rs @@ -52,7 +52,6 @@ use { clock::Slot, hash::Hash, program_error::ProgramError, - slot_hashes::MAX_ENTRIES, sysvar::{get_sysvar, Sysvar, SysvarId}, }, bytemuck_derive::{Pod, Zeroable}, @@ -72,21 +71,62 @@ impl Sysvar for SlotHashes { } } +/// A bytemuck-compatible representation of a `SlotHash`. #[derive(Copy, Clone, Default, Pod, Zeroable)] #[repr(C)] -struct PodSlotHash { - slot: Slot, - hash: Hash, +pub struct PodSlotHash { + pub slot: Slot, + pub hash: Hash, } +const U64_SIZE: usize = std::mem::size_of::(); + /// API for querying the `SlotHashes` sysvar. -pub struct SlotHashesSysvar; +pub struct SlotHashesSysvar { + data: Vec, + slot_hashes_start: usize, + slot_hashes_end: usize, +} impl SlotHashesSysvar { + /// Fetch the slot hashes sysvar data and use it to return an instance of + /// `SlotHashesSysvar`. + pub fn fetch() -> Result { + // First fetch all the sysvar data. + let sysvar_len = SlotHashes::size_of(); + let mut data = vec![0; sysvar_len]; + get_sysvar( + &mut data, + &SlotHashes::id(), + /* offset */ 0, + /* length */ sysvar_len as u64, + )?; + + // Read the sysvar's vector length (u64). + let slot_hash_count = data + .get(..U64_SIZE) + .and_then(|bytes| bytes.try_into().ok()) + .map(u64::from_le_bytes) + .ok_or(ProgramError::InvalidAccountData)?; + + // From the vector length, determine the expected length of the data. + let length = (slot_hash_count as usize) + .checked_mul(std::mem::size_of::()) + .ok_or(ProgramError::ArithmeticOverflow)?; + let slot_hashes_start = U64_SIZE; + let slot_hashes_end = slot_hashes_start.saturating_add(length); + + Ok(Self { + data, + slot_hashes_start, + slot_hashes_end, + }) + } + /// Get a value from the sysvar entries by its key. /// Returns `None` if the key is not found. - pub fn get(slot: &Slot) -> Result, ProgramError> { - get_pod_slot_hashes().map(|pod_hashes| { + pub fn get(&self, slot: &Slot) -> Result, ProgramError> { + self.as_slice().map(|pod_hashes| { pod_hashes .binary_search_by(|PodSlotHash { slot: this, .. }| slot.cmp(this)) .map(|idx| pod_hashes[idx].hash) @@ -96,31 +136,21 @@ impl SlotHashesSysvar { /// Get the position of an entry in the sysvar by its key. /// Returns `None` if the key is not found. - pub fn position(slot: &Slot) -> Result, ProgramError> { - get_pod_slot_hashes().map(|pod_hashes| { + pub fn position(&self, slot: &Slot) -> Result, ProgramError> { + self.as_slice().map(|pod_hashes| { pod_hashes .binary_search_by(|PodSlotHash { slot: this, .. }| slot.cmp(this)) .ok() }) } -} - -fn get_pod_slot_hashes() -> Result, ProgramError> { - let mut pod_hashes = vec![PodSlotHash::default(); MAX_ENTRIES]; - { - let data = bytemuck::try_cast_slice_mut::(&mut pod_hashes) - .map_err(|_| ProgramError::InvalidAccountData)?; - - // Ensure the created buffer is aligned to 8. - if data.as_ptr().align_offset(8) != 0 { - return Err(ProgramError::InvalidAccountData); - } - let offset = 8; // Vector length as `u64`. - let length = (SlotHashes::size_of() as u64).saturating_sub(offset); - get_sysvar(data, &SlotHashes::id(), offset, length)?; + /// Return the slot hashes sysvar as a vector of `PodSlotHash`. + pub fn as_slice(&self) -> Result<&[PodSlotHash], ProgramError> { + self.data + .get(self.slot_hashes_start..self.slot_hashes_end) + .and_then(|data| bytemuck::try_cast_slice(data).ok()) + .ok_or(ProgramError::InvalidAccountData) } - Ok(pod_hashes) } #[cfg(test)] @@ -134,6 +164,7 @@ mod tests { sysvar::tests::mock_get_sysvar_syscall, }, serial_test::serial, + test_case::test_case, }; #[test] @@ -149,11 +180,29 @@ mod tests { ); } + fn mock_slot_hashes(slot_hashes: &SlotHashes) { + // The data is always `SlotHashes::size_of()`. + let mut data = vec![0; SlotHashes::size_of()]; + bincode::serialize_into(&mut data[..], slot_hashes).unwrap(); + mock_get_sysvar_syscall(&data); + } + + #[allow(clippy::arithmetic_side_effects)] + #[test_case(0)] + #[test_case(1)] + #[test_case(2)] + #[test_case(5)] + #[test_case(10)] + #[test_case(64)] + #[test_case(128)] + #[test_case(192)] + #[test_case(256)] + #[test_case(384)] + #[test_case(MAX_ENTRIES)] #[serial] - #[test] - fn test_slot_hashes_sysvar() { + fn test_slot_hashes_sysvar(num_entries: usize) { let mut slot_hashes = vec![]; - for i in 0..MAX_ENTRIES { + for i in 0..num_entries { slot_hashes.push(( i as u64, hash(&[(i >> 24) as u8, (i >> 16) as u8, (i >> 8) as u8, i as u8]), @@ -161,44 +210,46 @@ mod tests { } let check_slot_hashes = SlotHashes::new(&slot_hashes); - mock_get_sysvar_syscall(&bincode::serialize(&check_slot_hashes).unwrap()); - - // `get`: - assert_eq!( - SlotHashesSysvar::get(&0).unwrap().as_ref(), - check_slot_hashes.get(&0), - ); - assert_eq!( - SlotHashesSysvar::get(&256).unwrap().as_ref(), - check_slot_hashes.get(&256), - ); - assert_eq!( - SlotHashesSysvar::get(&511).unwrap().as_ref(), - check_slot_hashes.get(&511), - ); - // `None`. - assert_eq!( - SlotHashesSysvar::get(&600).unwrap().as_ref(), - check_slot_hashes.get(&600), - ); + mock_slot_hashes(&check_slot_hashes); + + // Fetch the slot hashes sysvar. + let slot_hashes_sysvar = SlotHashesSysvar::fetch().unwrap(); + let pod_slot_hashes = slot_hashes_sysvar.as_slice().unwrap(); + + // `pod_slot_hashes` should match the slot hashes. + // Note slot hashes are stored largest slot to smallest. + for (i, pod_slot_hash) in pod_slot_hashes.iter().enumerate() { + let check = slot_hashes[num_entries - 1 - i]; + assert_eq!(pod_slot_hash.slot, check.0); + assert_eq!(pod_slot_hash.hash, check.1); + } - // `position`: - assert_eq!( - SlotHashesSysvar::position(&0).unwrap(), - check_slot_hashes.position(&0), - ); - assert_eq!( - SlotHashesSysvar::position(&256).unwrap(), - check_slot_hashes.position(&256), - ); - assert_eq!( - SlotHashesSysvar::position(&511).unwrap(), - check_slot_hashes.position(&511), - ); - // `None`. - assert_eq!( - SlotHashesSysvar::position(&600).unwrap(), - check_slot_hashes.position(&600), - ); + // Check some arbitrary slots in the created slot hashes. + let num_entries = num_entries as Slot; + let check_slots = if num_entries == 0 { + vec![num_entries, num_entries + 100] + } else { + vec![ + 0, + num_entries / 4, + num_entries / 2, + num_entries - 1, + num_entries, + num_entries + 100, + ] + }; + + for slot in check_slots.iter() { + // `get`: + assert_eq!( + slot_hashes_sysvar.get(slot).unwrap().as_ref(), + check_slot_hashes.get(slot), + ); + // `position`: + assert_eq!( + slot_hashes_sysvar.position(slot).unwrap(), + check_slot_hashes.position(slot), + ); + } } }