diff --git a/Cargo.lock b/Cargo.lock index 5c7f9d36bd68eb..0455ad9a1f632b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6929,6 +6929,7 @@ dependencies = [ "solana-secp256k1-recover", "solana-short-vec", "static_assertions", + "test-case", "thiserror", "wasm-bindgen", ] diff --git a/sdk/program/Cargo.toml b/sdk/program/Cargo.toml index e6c76f6d2a42a5..954c2fa6b6511e 100644 --- a/sdk/program/Cargo.toml +++ b/sdk/program/Cargo.toml @@ -83,6 +83,7 @@ itertools = { workspace = true } serde_json = { workspace = true } serial_test = { workspace = true } static_assertions = { workspace = true } +test-case = { workspace = true } [build-dependencies] rustc_version = { workspace = true } diff --git a/sdk/program/src/sysvar/slot_hashes.rs b/sdk/program/src/sysvar/slot_hashes.rs index 97a465165314e3..69ab49dc4e124a 100644 --- a/sdk/program/src/sysvar/slot_hashes.rs +++ b/sdk/program/src/sysvar/slot_hashes.rs @@ -58,6 +58,8 @@ use { bytemuck_derive::{Pod, Zeroable}, }; +const U64_SIZE: usize = std::mem::size_of::(); + crate::declare_sysvar_id!("SysvarS1otHashes111111111111111111111111111", SlotHashes); impl Sysvar for SlotHashes { @@ -72,16 +74,104 @@ impl Sysvar for SlotHashes { } } +/// A bytemuck-compatible (plain old data) version of `SlotHash`. #[derive(Copy, Clone, Default, Pod, Zeroable)] #[repr(C)] -struct PodSlotHash { - slot: Slot, - hash: Hash, +pub struct PodSlotHash { + pub slot: Slot, + pub hash: Hash, +} + +/// API for querying of the `SlotHashes` sysvar by on-chain programs. +/// +/// Hangs onto the allocated raw buffer from the account data, which can be +/// queried or accessed directly as a slice of `PodSlotHash`. +#[derive(Default)] +pub struct PodSlotHashes { + data: Vec, + slot_hashes_start: usize, + slot_hashes_end: usize, +} + +impl PodSlotHashes { + /// Fetch all of the raw sysvar data using the `sol_get_sysvar` syscall. + pub fn fetch() -> Result { + // Allocate an uninitialized buffer for the raw sysvar data. + let sysvar_len = SlotHashes::size_of(); + let mut data = vec![0; sysvar_len]; + + // Ensure the created buffer is aligned to 8. + if data.as_ptr().align_offset(8) != 0 { + return Err(ProgramError::InvalidAccountData); + } + + // Populate the buffer by fetching all sysvar data using the + // `sol_get_sysvar` syscall. + get_sysvar( + &mut data, + &SlotHashes::id(), + /* offset */ 0, + /* length */ sysvar_len as u64, + )?; + + // Get the number of slot hashes present in the data by reading the + // `u64` length at the beginning of the data, then use that count to + // calculate the length of the slot hashes data. + // + // The rest of the buffer is uninitialized and should not be accessed. + let length = data + .get(..U64_SIZE) + .and_then(|bytes| bytes.try_into().ok()) + .map(u64::from_le_bytes) + .and_then(|length| length.checked_mul(std::mem::size_of::() as u64)) + .ok_or(ProgramError::InvalidAccountData)?; + + let slot_hashes_start = U64_SIZE; + let slot_hashes_end = slot_hashes_start.saturating_add(length as usize); + + Ok(Self { + data, + slot_hashes_start, + slot_hashes_end, + }) + } + + /// Return the `SlotHashes` sysvar data as a slice of `PodSlotHash`. + /// Returns a slice of only the initialized sysvar data. + pub fn as_slice(&self) -> Result<&[PodSlotHash], ProgramError> { + self.data + .get(self.slot_hashes_start..self.slot_hashes_end) + .and_then(|data| bytemuck::try_cast_slice(data).ok()) + .ok_or(ProgramError::InvalidAccountData) + } + + /// Given a slot, get its corresponding hash in the `SlotHashes` sysvar + /// data. Returns `None` if the slot is not found. + pub fn get(&self, slot: &Slot) -> Result, ProgramError> { + self.as_slice().map(|pod_hashes| { + pod_hashes + .binary_search_by(|PodSlotHash { slot: this, .. }| slot.cmp(this)) + .map(|idx| pod_hashes[idx].hash) + .ok() + }) + } + + /// Given a slot, get its position in the `SlotHashes` sysvar data. Returns + /// `None` if the slot is not found. + pub fn position(&self, slot: &Slot) -> Result, ProgramError> { + self.as_slice().map(|pod_hashes| { + pod_hashes + .binary_search_by(|PodSlotHash { slot: this, .. }| slot.cmp(this)) + .ok() + }) + } } /// API for querying the `SlotHashes` sysvar. +#[deprecated(since = "2.1.0", note = "Please use `PodSlotHashes` instead")] pub struct SlotHashesSysvar; +#[allow(deprecated)] impl SlotHashesSysvar { /// Get a value from the sysvar entries by its key. /// Returns `None` if the key is not found. @@ -134,6 +224,7 @@ mod tests { sysvar::tests::mock_get_sysvar_syscall, }, serial_test::serial, + test_case::test_case, }; #[test] @@ -149,6 +240,86 @@ mod tests { ); } + fn mock_slot_hashes(slot_hashes: &SlotHashes) { + // The data is always `SlotHashes::size_of()`. + let mut data = vec![0; SlotHashes::size_of()]; + bincode::serialize_into(&mut data[..], slot_hashes).unwrap(); + mock_get_sysvar_syscall(&data); + } + + #[test_case(0)] + #[test_case(1)] + #[test_case(2)] + #[test_case(5)] + #[test_case(10)] + #[test_case(64)] + #[test_case(128)] + #[test_case(192)] + #[test_case(256)] + #[test_case(384)] + #[test_case(MAX_ENTRIES)] + #[serial] + fn test_pod_slot_hashes(num_entries: usize) { + let mut slot_hashes = vec![]; + for i in 0..num_entries { + slot_hashes.push(( + i as u64, + hash(&[(i >> 24) as u8, (i >> 16) as u8, (i >> 8) as u8, i as u8]), + )); + } + + let check_slot_hashes = SlotHashes::new(&slot_hashes); + mock_slot_hashes(&check_slot_hashes); + + let pod_slot_hashes = PodSlotHashes::fetch().unwrap(); + + // Assert the slice of `PodSlotHash` has the same length as + // `SlotHashes`. + let pod_slot_hashes_slice = pod_slot_hashes.as_slice().unwrap(); + assert_eq!(pod_slot_hashes_slice.len(), slot_hashes.len()); + + // Assert `PodSlotHashes` and `SlotHashes` contain the same slot hashes + // in the same order. + for slot in slot_hashes.iter().map(|(slot, _hash)| slot) { + // `get`: + assert_eq!( + pod_slot_hashes.get(slot).unwrap().as_ref(), + check_slot_hashes.get(slot), + ); + // `position`: + assert_eq!( + pod_slot_hashes.position(slot).unwrap(), + check_slot_hashes.position(slot), + ); + } + + // Check a few `None` values. + let not_a_slot = num_entries.saturating_add(1) as u64; + assert_eq!( + pod_slot_hashes.get(¬_a_slot).unwrap().as_ref(), + check_slot_hashes.get(¬_a_slot), + ); + assert_eq!(pod_slot_hashes.get(¬_a_slot).unwrap(), None); + assert_eq!( + pod_slot_hashes.position(¬_a_slot).unwrap(), + check_slot_hashes.position(¬_a_slot), + ); + assert_eq!(pod_slot_hashes.position(¬_a_slot).unwrap(), None); + + let not_a_slot = num_entries.saturating_add(2) as u64; + assert_eq!( + pod_slot_hashes.get(¬_a_slot).unwrap().as_ref(), + check_slot_hashes.get(¬_a_slot), + ); + assert_eq!(pod_slot_hashes.get(¬_a_slot).unwrap(), None); + assert_eq!( + pod_slot_hashes.position(¬_a_slot).unwrap(), + check_slot_hashes.position(¬_a_slot), + ); + assert_eq!(pod_slot_hashes.position(¬_a_slot).unwrap(), None); + } + + #[allow(deprecated)] #[serial] #[test] fn test_slot_hashes_sysvar() {