Introduce VoteState::deserialize_into_uninit #2272

Merged
31 changes: 29 additions & 2 deletions sdk/program/src/serialize_utils/cursor.rs
@@ -1,6 +1,12 @@
use {
crate::{instruction::InstructionError, pubkey::Pubkey},
std::io::{Cursor, Read},
crate::{
instruction::InstructionError,
pubkey::{Pubkey, PUBKEY_BYTES},
},
std::{
io::{BufRead as _, Cursor, Read},
ptr,
},
};

pub(crate) fn read_u8<T: AsRef<[u8]>>(cursor: &mut Cursor<T>) -> Result<u8, InstructionError> {
@@ -50,6 +56,27 @@ pub(crate) fn read_i64<T: AsRef<[u8]>>(cursor: &mut Cursor<T>) -> Result<i64, In
Ok(i64::from_le_bytes(buf))
}

pub(crate) fn read_pubkey_into(
cursor: &mut Cursor<&[u8]>,
pubkey: *mut Pubkey,
) -> Result<(), InstructionError> {
match cursor.fill_buf() {
Ok(buf) if buf.len() >= PUBKEY_BYTES => {
// Safety: `buf` is guaranteed to be at least `PUBKEY_BYTES` bytes
// long. Pubkey is a #[repr(transparent)] wrapper around a byte array,
// so this is a byte-for-byte copy and it's safe.
unsafe {
ptr::copy_nonoverlapping(buf.as_ptr(), pubkey as *mut u8, PUBKEY_BYTES);
}

cursor.consume(PUBKEY_BYTES);
}
_ => return Err(InstructionError::InvalidAccountData),
}

Ok(())
}
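
The bounds-check-then-copy pattern used by `read_pubkey_into` can be sketched in isolation roughly as follows. This is a minimal illustration, not the code from this PR: `Key`, `KEY_BYTES`, `read_key_into` and `read_key` are hypothetical names, and unlike the PR the destination is taken as `&mut MaybeUninit<Key>` rather than a raw pointer so that the function has a safe signature.

```rust
use std::io::{BufRead, Cursor};
use std::mem::MaybeUninit;
use std::ptr;

pub const KEY_BYTES: usize = 32;

// Hypothetical stand-in for `Pubkey`: a transparent wrapper around a byte array.
#[repr(transparent)]
#[derive(Debug, PartialEq, Eq)]
pub struct Key(pub [u8; KEY_BYTES]);

/// Copies `KEY_BYTES` bytes from the cursor into `dst` and advances the cursor.
/// On error nothing is written and the cursor position is unchanged.
pub fn read_key_into(
    cursor: &mut Cursor<&[u8]>,
    dst: &mut MaybeUninit<Key>,
) -> Result<(), &'static str> {
    match cursor.fill_buf() {
        Ok(buf) if buf.len() >= KEY_BYTES => {
            // Safety: `buf` holds at least KEY_BYTES readable bytes, `dst` is
            // writable storage for one `Key`, and `Key` has the same layout
            // as `[u8; KEY_BYTES]` because it is #[repr(transparent)].
            unsafe {
                ptr::copy_nonoverlapping(buf.as_ptr(), dst.as_mut_ptr().cast::<u8>(), KEY_BYTES);
            }
            cursor.consume(KEY_BYTES);
            Ok(())
        }
        _ => Err("not enough bytes"),
    }
}

/// Convenience wrapper that returns the key by value.
pub fn read_key(cursor: &mut Cursor<&[u8]>) -> Result<Key, &'static str> {
    let mut slot = MaybeUninit::<Key>::uninit();
    read_key_into(cursor, &mut slot)?;
    // Safety: `read_key_into` returned Ok, so every byte of `slot` was written.
    Ok(unsafe { slot.assume_init() })
}
```

Checking the length before copying means a short buffer leaves both the destination and the cursor untouched, which is what lets callers treat an error as "nothing happened".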

pub(crate) fn read_pubkey<T: AsRef<[u8]>>(
cursor: &mut Cursor<T>,
) -> Result<Pubkey, InstructionError> {
165 changes: 147 additions & 18 deletions sdk/program/src/vote/state/mod.rs
@@ -20,7 +20,12 @@ use {
},
bincode::{serialize_into, ErrorKind},
serde_derive::{Deserialize, Serialize},
std::{collections::VecDeque, fmt::Debug, io::Cursor},
std::{
collections::VecDeque,
fmt::Debug,
io::Cursor,
mem::{self, MaybeUninit},
},
};

mod vote_state_0_23_5;
@@ -479,13 +484,81 @@ impl VoteState {
}
}

/// Deserializes the input `VoteStateVersions` buffer directly into a provided `VoteState` struct
/// Deserializes the input `VoteStateVersions` buffer directly into the provided `VoteState`.
///
/// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for
/// compatibility with `bincode::deserialize`.
///
/// In a BPF context, V0_23_5 is not supported, but in non-BPF, all versions are supported for
/// compatibility with `bincode::deserialize`
/// On success, `vote_state` reflects the state of the input data. On failure, `vote_state` is
/// reset to `VoteState::default()`.
pub fn deserialize_into(
input: &[u8],
vote_state: &mut VoteState,
) -> Result<(), InstructionError> {
// Rebind vote_state to *mut VoteState so that the &mut binding isn't
// accessible anymore, preventing accidental use after this point.
//
// NOTE: switch to ptr::from_mut() once platform-tools moves to rustc >= 1.76
let vote_state = vote_state as *mut VoteState;
Member
nit: I like this recently stabilized one over the good old as *mut T casts, as documented there and for better readability (imo): https://doc.rust-lang.org/std/ptr/fn.from_mut.html

+1

I'm a fan of ptr::from_mut(var).cast() too (vs var as *mut _ as *mut u8)

Author
Nod, I'll just never remember that ptr::from_mut exists

Author
done

Author
@alessandrod Jul 31, 2024
This is giving me

error: current MSRV (Minimum Supported Rust Version) is `1.75.0` but this item is stable since `1.76.0`
   --> sdk/program/src/vote/state/mod.rs:497:26
    |
497 |         let vote_state = ptr::from_mut(vote_state);
    |                          ^^^^^^^^^^^^^
    |
    = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#incompatible_msrv
    = note: `-D clippy::incompatible-msrv` implied by `-D warnings`
    = help: to override `-D warnings` add `#[allow(clippy::incompatible_msrv)]`

which is weird since there are already uses of it from accounts-db?

Author
Reverted since platform-tools is still on rust 1.75
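
For reference, the two spellings discussed in this thread can be compared side by side. This is a small sketch with hypothetical helper names (`bump_with_as_cast`, `bump_with_from_mut`, `fill_bytes`); the `ptr::from_mut` variants need Rust 1.76 or newer, which is why they could not be used here.

```rust
use std::ptr;

/// Increments the value through a raw pointer obtained with an `as` cast.
/// This form works on any Rust version.
pub fn bump_with_as_cast(value: &mut u32) {
    let p = value as *mut u32;
    // Safety: `p` comes from a live `&mut u32`, so it is valid and aligned.
    unsafe { *p += 1 };
}

/// Same operation using `ptr::from_mut`, stable since Rust 1.76.
pub fn bump_with_from_mut(value: &mut u32) {
    let p: *mut u32 = ptr::from_mut(value);
    // Safety: same reasoning as above.
    unsafe { *p += 1 };
}

/// Overwrites all four bytes of `value`, using `ptr::from_mut(..).cast()`
/// instead of `value as *mut _ as *mut u8`.
pub fn fill_bytes(value: &mut u32, byte: u8) {
    let bytes: *mut u8 = ptr::from_mut(value).cast();
    // Safety: a `u32` is four writable bytes and every bit pattern is a valid `u32`.
    unsafe { ptr::write_bytes(bytes, byte, 4) };
}
```

Unlike a chain of `as` casts, `ptr::from_mut` only accepts a mutable reference and `cast` cannot change mutability, so the intent is harder to get wrong.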


// Safety: vote_state is valid to_drop (see drop_in_place() docs). After
// dropping, the pointer is treated as uninitialized and only accessed
// through ptr::write, which is safe as per drop_in_place docs.
unsafe {
std::ptr::drop_in_place(vote_state);
}

// This is to reset vote_state to VoteState::default() if deserialize fails or panics.
struct DropGuard {
vote_state: *mut VoteState,
}

impl Drop for DropGuard {
fn drop(&mut self) {
// Safety:
//
// Deserialize failed or panicked so at this point vote_state is uninitialized. We
// must write a new _valid_ value into it or after returning (or unwinding) from
// this function the caller is left with an uninitialized `&mut VoteState`, which is
// UB (references must always be valid).
//
// This is always safe and doesn't leak memory because deserialize_into_ptr() only
// writes into the heap-allocating fields when it returns Ok().
unsafe {
self.vote_state.write(VoteState::default());
}
}
}

let guard = DropGuard { vote_state };

let res = VoteState::deserialize_into_ptr(input, vote_state);
if res.is_ok() {
mem::forget(guard);
}

res
}
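
The drop, guard, and forget sequence in `deserialize_into` can be reproduced on a much smaller type. The sketch below is an illustration under stated assumptions, not the PR's code: `State`, `parse_into_ptr`, `parse_into` and `ResetGuard` are hypothetical, and the toy parser writes to the destination only on success, which is the property the guard relies on.

```rust
use std::mem;
use std::ptr;

// Hypothetical stand-in for `VoteState`: it owns a heap allocation.
#[derive(Debug, Default, PartialEq)]
pub struct State {
    pub votes: Vec<u64>,
}

// Hypothetical parser: the input is a sequence of 8-byte little-endian values.
// It writes to `dst` only when it is about to return Ok.
fn parse_into_ptr(input: &[u8], dst: *mut State) -> Result<(), &'static str> {
    if input.len() % 8 != 0 {
        return Err("invalid data");
    }
    let votes = input
        .chunks_exact(8)
        .map(|chunk| {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(chunk);
            u64::from_le_bytes(bytes)
        })
        .collect();
    // Safety: callers pass a pointer valid for writes. `write` does not read
    // or drop whatever the slot contained before.
    unsafe { dst.write(State { votes }) };
    Ok(())
}

pub fn parse_into(input: &[u8], state: &mut State) -> Result<(), &'static str> {
    // Shadow the reference with a raw pointer so it cannot be used by accident.
    let state = state as *mut State;

    // Safety: `state` points to a valid, initialized value. From here on the
    // slot is treated as uninitialized and is only accessed with `write`.
    unsafe { ptr::drop_in_place(state) };

    // Restores a valid value if parsing fails or panics.
    struct ResetGuard(*mut State);
    impl Drop for ResetGuard {
        fn drop(&mut self) {
            // Safety: the slot is uninitialized here, so writing without
            // dropping is correct and leaves the caller with a valid value.
            unsafe { self.0.write(State::default()) };
        }
    }

    let guard = ResetGuard(state);
    let res = parse_into_ptr(input, state);
    if res.is_ok() {
        // The slot already holds the parsed value, so skip the reset.
        mem::forget(guard);
    }
    res
}
```

Putting the reset in a `Drop` impl rather than in an `Err` branch is what covers the unwinding case as well as the ordinary error return.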

/// Deserializes the input `VoteStateVersions` buffer directly into the provided
/// `MaybeUninit<VoteState>`.
///
/// In a SBPF context, V0_23_5 is not supported, but in non-SBPF, all versions are supported for
/// compatibility with `bincode::deserialize`.
///
/// On success, `vote_state` is fully initialized and can be converted to `VoteState` using
/// [MaybeUninit::assume_init]. On failure, `vote_state` may still be uninitialized and must not
/// be converted to `VoteState`.
pub fn deserialize_into_uninit(
input: &[u8],
vote_state: &mut MaybeUninit<VoteState>,
) -> Result<(), InstructionError> {
VoteState::deserialize_into_ptr(input, vote_state.as_mut_ptr())
}
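
From the caller's side, the contract documented for `deserialize_into_uninit` (initialized on `Ok`, possibly uninitialized on `Err`) might be used along these lines. The `Record` type and both functions are hypothetical and exist only to show where `assume_init` is allowed.

```rust
use std::mem::MaybeUninit;

// Hypothetical record type used only for this illustration.
#[derive(Debug, PartialEq)]
pub struct Record {
    pub id: u32,
    pub flag: bool,
}

/// Hypothetical parser: a 4-byte little-endian id followed by one flag byte (0 or 1).
/// On Ok, `out` is fully initialized. On Err, `out` is left untouched.
pub fn parse_into_uninit(
    input: &[u8],
    out: &mut MaybeUninit<Record>,
) -> Result<(), &'static str> {
    if input.len() < 5 {
        return Err("too short");
    }
    let flag = match input[4] {
        0 => false,
        1 => true,
        _ => return Err("bad flag"),
    };
    let mut id = [0u8; 4];
    id.copy_from_slice(&input[..4]);
    out.write(Record { id: u32::from_le_bytes(id), flag });
    Ok(())
}

/// Caller pattern: only call `assume_init` after the parser returned Ok.
pub fn parse(input: &[u8]) -> Result<Record, &'static str> {
    let mut slot = MaybeUninit::uninit();
    parse_into_uninit(input, &mut slot)?;
    // Safety: the `?` above returned early on failure, so `slot` is initialized.
    Ok(unsafe { slot.assume_init() })
}
```

Starting from `MaybeUninit::uninit()` avoids constructing and then dropping a default value, which is the cost the `&mut VoteState` variant has to pay.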

fn deserialize_into_ptr(
input: &[u8],
vote_state: *mut VoteState,
) -> Result<(), InstructionError> {
let mut cursor = Cursor::new(input);

@@ -496,10 +569,18 @@ impl VoteState {
0 => {
#[cfg(not(target_os = "solana"))]
{
*vote_state = bincode::deserialize::<VoteStateVersions>(input)
.map(|versioned| versioned.convert_to_current())
.map_err(|_| InstructionError::InvalidAccountData)?;

// Safety: vote_state is valid as it comes from `&mut MaybeUninit<VoteState>` or
// `&mut VoteState`. In the first case, the value is uninitialized so we write()
// to avoid dropping invalid data; in the latter case, we `drop_in_place()`
// before writing so the value has already been dropped and we just write a new
// one in place.
unsafe {
vote_state.write(
bincode::deserialize::<VoteStateVersions>(input)
.map(|versioned| versioned.convert_to_current())
.map_err(|_| InstructionError::InvalidAccountData)?,
);
}
Ok(())
}
#[cfg(target_os = "solana")]
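
The reason this hunk switches from `*vote_state = ...` to `vote_state.write(...)` is that assigning through a pointer drops the previous value first, while `write` does not. The difference can be observed with a drop counter; `Tracked`, `drops_on_assign` and `drops_on_write` below are hypothetical names for this sketch.

```rust
use std::cell::Cell;
use std::mem::MaybeUninit;

// Hypothetical type that counts how many of its values have been dropped.
pub struct Tracked<'a>(pub &'a Cell<u32>);

impl Drop for Tracked<'_> {
    fn drop(&mut self) {
        self.0.set(self.0.get() + 1);
    }
}

/// Assigning through a pointer to an initialized slot drops the old value.
/// Returns the drop count observed right after the assignment.
pub fn drops_on_assign(counter: &Cell<u32>) -> u32 {
    let mut slot = Tracked(counter);
    let p: *mut Tracked<'_> = &mut slot;
    // Safety: `p` points to an initialized value, so dropping it is sound.
    unsafe { *p = Tracked(counter) };
    counter.get()
    // `slot` (now holding the new value) is dropped when the function returns.
}

/// `write` stores the new value without reading or dropping the old contents,
/// which is what an uninitialized slot requires.
/// Returns the drop count observed right after the write.
pub fn drops_on_write(counter: &Cell<u32>) -> u32 {
    let mut slot = MaybeUninit::<Tracked<'_>>::uninit();
    // Safety: the pointer is valid for writes of one `Tracked`.
    unsafe { slot.as_mut_ptr().write(Tracked(counter)) };
    let seen = counter.get();
    // Safety: the slot was initialized by the write above.
    drop(unsafe { slot.assume_init() });
    seen
}
```

With a slot that is uninitialized, or that has already been dropped in place, the assignment form would run a destructor on invalid data, so `write` is the only correct choice there.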
@@ -1129,10 +1210,56 @@ mod tests {
}

#[test]
fn test_vote_deserialize_into_nopanic() {
// base case
fn test_vote_deserialize_into_error() {
let target_vote_state = VoteState::new_rand_for_tests(Pubkey::new_unique(), 42);
let mut vote_state_buf =
bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap();
let len = vote_state_buf.len();
vote_state_buf.truncate(len - 1);

let mut test_vote_state = VoteState::default();
let e = VoteState::deserialize_into(&[], &mut test_vote_state).unwrap_err();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap_err();
assert_eq!(test_vote_state, VoteState::default());
}

#[test]
fn test_vote_deserialize_into_uninit() {
// base case
let target_vote_state = VoteState::default();
let vote_state_buf =
bincode::serialize(&VoteStateVersions::new_current(target_vote_state.clone())).unwrap();

let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);

// variant
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState>() * 4;
for _ in 0..1000 {
let raw_data: Vec<u8> = (0..struct_bytes_x4).map(|_| rand::random::<u8>()).collect();
let mut unstructured = Unstructured::new(&raw_data);

let target_vote_state_versions =
VoteStateVersions::arbitrary(&mut unstructured).unwrap();
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);
}
}

#[test]
fn test_vote_deserialize_into_uninit_nopanic() {
// base case
let mut test_vote_state = MaybeUninit::uninit();
let e = VoteState::deserialize_into_uninit(&[], &mut test_vote_state).unwrap_err();
assert_eq!(e, InstructionError::InvalidAccountData);

// variant
@@ -1153,21 +1280,22 @@ mod tests {

// it is extremely improbable, though theoretically possible, for random bytes to be syntactically valid
// so we only check that the parser does not panic and that it succeeds or fails exactly in line with bincode
let mut test_vote_state = VoteState::default();
let test_res = VoteState::deserialize_into(&raw_data, &mut test_vote_state);
let mut test_vote_state = MaybeUninit::uninit();
let test_res = VoteState::deserialize_into_uninit(&raw_data, &mut test_vote_state);
let bincode_res = bincode::deserialize::<VoteStateVersions>(&raw_data)
.map(|versioned| versioned.convert_to_current());

if test_res.is_err() {
assert!(bincode_res.is_err());
} else {
let test_vote_state = unsafe { test_vote_state.assume_init() };
assert_eq!(test_vote_state, bincode_res.unwrap());
}
}
}

#[test]
fn test_vote_deserialize_into_ill_sized() {
fn test_vote_deserialize_into_uninit_ill_sized() {
// provide 4x the minimum struct size in bytes to ensure we typically touch every field
let struct_bytes_x4 = std::mem::size_of::<VoteState>() * 4;
for _ in 0..1000 {
@@ -1185,20 +1313,21 @@ mod tests {
expanded_buf.resize(original_buf.len() + 8, 0);

// truncated fails
let mut test_vote_state = VoteState::default();
let test_res = VoteState::deserialize_into(&truncated_buf, &mut test_vote_state);
let mut test_vote_state = MaybeUninit::uninit();
let test_res = VoteState::deserialize_into_uninit(&truncated_buf, &mut test_vote_state);
let bincode_res = bincode::deserialize::<VoteStateVersions>(&truncated_buf)
.map(|versioned| versioned.convert_to_current());

assert!(test_res.is_err());
assert!(bincode_res.is_err());

// expanded succeeds
let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&expanded_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&expanded_buf, &mut test_vote_state).unwrap();
let bincode_res = bincode::deserialize::<VoteStateVersions>(&expanded_buf)
.map(|versioned| versioned.convert_to_current());

let test_vote_state = unsafe { test_vote_state.assume_init() };
assert_eq!(test_vote_state, bincode_res.unwrap());
}
}
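
The property exercised by the ill-sized test (a truncated buffer must fail, a buffer with trailing bytes must still parse) can be shown on a toy format without bincode. Everything in this sketch is hypothetical: the length-prefixed format, `encode`, `parse` and `check_ill_sized` are assumptions made for illustration only.

```rust
/// Hypothetical format: a 4-byte little-endian length followed by that many
/// payload bytes. Bytes after the payload are ignored.
pub fn parse(input: &[u8]) -> Result<Vec<u8>, &'static str> {
    if input.len() < 4 {
        return Err("missing length prefix");
    }
    let mut len = [0u8; 4];
    len.copy_from_slice(&input[..4]);
    let len = u32::from_le_bytes(len) as usize;
    let body = &input[4..];
    if body.len() < len {
        return Err("truncated payload");
    }
    Ok(body[..len].to_vec())
}

/// Serializes `payload` in the hypothetical format above.
pub fn encode(payload: &[u8]) -> Vec<u8> {
    let mut out = (payload.len() as u32).to_le_bytes().to_vec();
    out.extend_from_slice(payload);
    out
}

/// Returns true if dropping the last byte makes parsing fail and appending
/// eight zero bytes still yields the original payload.
pub fn check_ill_sized(payload: &[u8]) -> bool {
    let original = encode(payload);
    let truncated = &original[..original.len() - 1];
    let mut expanded = original.clone();
    expanded.resize(original.len() + 8, 0);
    parse(truncated).is_err() && parse(&expanded) == Ok(payload.to_vec())
}
```

Checking both directions guards against a parser that reads past the end of short input as well as one that rejects valid data followed by padding.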
10 changes: 6 additions & 4 deletions sdk/program/src/vote/state/vote_state_0_23_5.rs
@@ -75,8 +75,9 @@ mod tests {
let target_vote_state_versions = VoteStateVersions::V0_23_5(Box::new(target_vote_state));
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(
target_vote_state_versions.convert_to_current(),
@@ -97,8 +98,9 @@ mod tests {
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);
}
10 changes: 6 additions & 4 deletions sdk/program/src/vote/state/vote_state_1_14_11.rs
@@ -94,8 +94,9 @@ mod tests {
let target_vote_state_versions = VoteStateVersions::V1_14_11(Box::new(target_vote_state));
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(
target_vote_state_versions.convert_to_current(),
@@ -116,8 +117,9 @@ mod tests {
let vote_state_buf = bincode::serialize(&target_vote_state_versions).unwrap();
let target_vote_state = target_vote_state_versions.convert_to_current();

let mut test_vote_state = VoteState::default();
VoteState::deserialize_into(&vote_state_buf, &mut test_vote_state).unwrap();
let mut test_vote_state = MaybeUninit::uninit();
VoteState::deserialize_into_uninit(&vote_state_buf, &mut test_vote_state).unwrap();
let test_vote_state = unsafe { test_vote_state.assume_init() };

assert_eq!(target_vote_state, test_vote_state);
}