Skip to content

Commit

Permalink
VoteState::deserialize_into: reset to VoteState::default() on failure
Browse files Browse the repository at this point in the history
On failure we must ensure that `vote_state` is left in a valid state to
avoid UB.
  • Loading branch information
alessandrod committed Jul 31, 2024
1 parent 38be6b8 commit 354d61f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sdk/program/src/serialize_utils/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ pub(crate) fn read_pubkey_into(

match cursor.fill_buf() {
Ok(buf) if buf.len() >= PUBKEY_SIZE => {
// Safety: `buf` is guaranteed to be at least `PUBKEY_SIZE` bytes long
// Safety: `buf` is guaranteed to be at least `PUBKEY_SIZE` bytes
// long. Pubkey a #[repr(transparent)] wrapper around a byte array,
// so this is a byte to byte copy and it's safe.
unsafe {
ptr::copy_nonoverlapping(buf.as_ptr(), pubkey as *mut u8, PUBKEY_SIZE);
}
Expand Down
70 changes: 69 additions & 1 deletion sdk/program/src/vote/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,18 +483,42 @@ impl VoteState {
///
/// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for
/// compatibility with `bincode::deserialize`.
///
/// On success, `vote_state` reflects the state of the input data. On failure, `vote_state` is
/// reset to `VoteState::default()`.
pub fn deserialize_into(
input: &[u8],
vote_state: &mut VoteState,
) -> Result<(), InstructionError> {
// Rebind vote_state to *mut VoteState so that the &mut binding isn't
// accessible anymore, preventing accidental use after this point.
let vote_state = vote_state as *mut VoteState;

// Safety: vote_state is valid to_drop (see drop_in_place() docs). After
// dropping, the pointer is treated as uninitialized and only accessed
// through ptr::write, which is safe as per drop_in_place docs.
unsafe {
std::ptr::drop_in_place(vote_state);
}

VoteState::deserialize_into_ptr(input, vote_state)
match VoteState::deserialize_into_ptr(input, vote_state) {
Ok(()) => Ok(()),
Err(err) => {
// Safety:
//
// Deserialize failed so at this point vote_state is uninitialized. We must write a
// new _valid_ value into it or after returning from this function the caller is
// left with an uninitialized `&mut VoteState`, which is UB (references must always
// be valid).
//
// This is always safe and doesn't leak memory because deserialize_into_ptr() writes
// into the fields that heap alloc only when it returns Ok().
unsafe {
vote_state.write(VoteState::default());
}
Err(err)
}
}
}

/// Deserializes the input `VoteStateVersions` buffer directly into the provided
Expand Down Expand Up @@ -1135,6 +1159,50 @@ mod tests {
);
}

#[test]
fn test_vote_deserialize_into() {
// base case
let target_vote_state = VoteState::default();
let vote_state_buf =
bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();

assert_eq!(target_vote_state, test_vote_state);

// variant
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState>() * 4;
for _ in 0..1 {
let raw_data: Vec<u8> = (0..struct_bytes_x4).map(|_| rand::random::<u8>()).collect();
let mut unstructured = Unstructured::new(&raw_data);

let target_vote_state_versions =
VoteStateVersions::arbitrary(&mut unstructured).unwrap();
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();

assert_eq!(target_vote_state, test_vote_state);
}
}

#[test]
fn test_vote_deserialize_into_error() {
let target_vote_state = VoteState::new_rand_for_tests(Pubkey::new_unique(), 42);
let mut vote_state_buf =
bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap();
let len = vote_state_buf.len();
vote_state_buf.truncate(len - 1);

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap_err();
assert_eq!(test_vote_state, VoteState::default());
}

#[test]
fn test_vote_deserialize_into_uninit() {
// base case
Expand Down

0 comments on commit 354d61f

Please sign in to comment.