diff --git a/sdk/program/src/serialize_utils/cursor.rs b/sdk/program/src/serialize_utils/cursor.rs index 079eab8f433dc5..1c0e4b9572d7dc 100644 --- a/sdk/program/src/serialize_utils/cursor.rs +++ b/sdk/program/src/serialize_utils/cursor.rs @@ -61,7 +61,9 @@ pub(crate) fn read_pubkey_into( match cursor.fill_buf() { Ok(buf) if buf.len() >= PUBKEY_SIZE => { - // Safety: `buf` is guaranteed to be at least `PUBKEY_SIZE` bytes long + // Safety: `buf` is guaranteed to be at least `PUBKEY_SIZE` bytes + // long. Pubkey a #[repr(transparent)] wrapper around a byte array, + // so this is a byte to byte copy and it's safe. unsafe { ptr::copy_nonoverlapping(buf.as_ptr(), pubkey as *mut u8, PUBKEY_SIZE); } diff --git a/sdk/program/src/vote/state/mod.rs b/sdk/program/src/vote/state/mod.rs index 3f6c402c6aef95..ee61cc84c4cc17 100644 --- a/sdk/program/src/vote/state/mod.rs +++ b/sdk/program/src/vote/state/mod.rs @@ -483,10 +483,17 @@ impl VoteState { /// /// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for /// compatibility with `bincode::deserialize`. + /// + /// On success, `vote_state` reflects the state of the input data. On failure, `vote_state` is + /// reset to `VoteState::default()`. pub fn deserialize_into( input: &[u8], vote_state: &mut VoteState, ) -> Result<(), InstructionError> { + // Rebind vote_state to *mut VoteState so that the &mut binding isn't + // accessible anymore, preventing accidental use after this point. + let vote_state = vote_state as *mut VoteState; + // Safety: vote_state is valid to_drop (see drop_in_place() docs). After // dropping, the pointer is treated as uninitialized and only accessed // through ptr::write, which is safe as per drop_in_place docs. @@ -494,7 +501,24 @@ impl VoteState { std::ptr::drop_in_place(vote_state); } - VoteState::deserialize_into_ptr(input, vote_state) + match VoteState::deserialize_into_ptr(input, vote_state) { + Ok(()) => Ok(()), + Err(err) => { + // Safety: + // + // Deserialize failed so at this point vote_state is uninitialized. We must write a + // new _valid_ value into it or after returning from this function the caller is + // left with an uninitialized `&mut VoteState`, which is UB (references must always + // be valid). + // + // This is always safe and doesn't leak memory because deserialize_into_ptr() writes + // into the fields that heap alloc only when it returns Ok(). + unsafe { + vote_state.write(VoteState::default()); + } + Err(err) + } + } } /// Deserializes the input `VoteStateVersions` buffer directly into the provided @@ -1135,6 +1159,50 @@ mod tests { ); } + #[test] + fn test_vote_deserialize_into() { + // base case + let target_vote_state = VoteState::default(); + let vote_state_buf = + bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap(); + + let mut test_vote_state = VoteState::default(); + VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap(); + + assert_eq!(target_vote_state, test_vote_state); + + // variant + // provide 4x the minimum struct size in bytes to ensure we typically touch every field + let struct_bytes_x4 = std::mem::size_of::() * 4; + for _ in 0..1 { + let raw_data: Vec = (0..struct_bytes_x4).map(|_| rand::random::()).collect(); + let mut unstructured = Unstructured::new(&raw_data); + + let target_vote_state_versions = + VoteStateVersions::arbitrary(&mut unstructured).unwrap(); + let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap(); + let target_vote_state = target_vote_state_versions.convert_to_current(); + + let mut test_vote_state = VoteState::default(); + VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap(); + + assert_eq!(target_vote_state, test_vote_state); + } + } + + #[test] + fn test_vote_deserialize_into_error() { + let target_vote_state = VoteState::new_rand_for_tests(Pubkey::new_unique(), 42); + let mut vote_state_buf = + bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap(); + let len = vote_state_buf.len(); + vote_state_buf.truncate(len - 1); + + let mut test_vote_state = VoteState::default(); + VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap_err(); + assert_eq!(test_vote_state, VoteState::default()); + } + #[test] fn test_vote_deserialize_into_uninit() { // base case