Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

token 2022: add support for _writing_ repeating fixed-length extensions #5837

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 233 additions & 0 deletions token/program-2022/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,37 @@ fn get_extension_bytes_mut<S: BaseState, V: Extension>(
Ok(&mut tlv_data[value_start..value_end])
}

fn get_repeating_extension_bytes_mut<S: BaseState, V: Extension>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit for consistency

Suggested change
fn get_repeating_extension_bytes_mut<S: BaseState, V: Extension>(
fn get_extension_bytes_with_repetition_mut<S: BaseState, V: Extension>(

tlv_data: &mut [u8],
repetition: usize,
) -> Result<&mut [u8], ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}

let mut start_index = 0;
let mut value_start = 0;
let mut value_end = 0;

for _ in 0..repetition {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's stick with 0-indexing (or use 0..=repetition if you prefer)

Suggested change
for _ in 0..repetition {
for _ in 0..repetition+1 {

let indices = get_extension_indices::<V>(&tlv_data[start_index..], false)?;

let global_length_start = indices.length_start.saturating_add(start_index);
value_start = indices.value_start.saturating_add(start_index);

let length = pod_from_bytes::<Length>(&tlv_data[global_length_start..value_start])?;
value_end = value_start.saturating_add(usize::from(*length));

if tlv_data.len() < value_end {
return Err(ProgramError::InvalidAccountData);
}

start_index = value_end;
}

Ok(&mut tlv_data[value_start..value_end])
}

/// Calculate the new expected size if the state allocates the given number
/// of bytes for the given extension type.
///
Expand Down Expand Up @@ -719,6 +750,39 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
}

/// Unpack a portion of the TLV data as the base mutable bytes,
/// for a repeating extension
fn get_repeating_extension_bytes_mut<V: Extension + Pod>(
&mut self,
repetition: usize,
) -> Result<&mut [u8], ProgramError> {
get_repeating_extension_bytes_mut::<S, V>(self.tlv_data, repetition)
}
Comment on lines +753 to +760
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed? If not, let's remove it, no need to provide more footguns


/// Unpack a portion of the TLV data as the desired type that allows
/// modifying the type, for a repeating extension
pub fn get_repeating_extension_mut<V: Extension + Pod>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
pub fn get_repeating_extension_mut<V: Extension + Pod>(
pub fn get_extension_with_repetition_mut<V: Extension + Pod>(

&mut self,
repetition: usize,
) -> Result<&mut V, ProgramError> {
pod_from_bytes_mut::<V>(self.get_repeating_extension_bytes_mut::<V>(repetition)?)
}

/// Returns an unpacked portion of TLV data that allows modifying the type,
/// based on the specified match criteria
pub fn get_first_matched_repeating_extension_mut<V: Extension + Pod>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: how about just find_extension_mut?

&mut self,
match_critera: impl Fn(&V) -> bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit-up-to-you: Iterator calls this predicate https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.find

) -> Result<&mut V, ProgramError> {
for (index, extension) in self.get_all_extensions::<V>()?.iter().enumerate() {
let repetition = index + 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Let's stick with 0-indexing to align with the TLV library

if match_critera(extension) {
return self.get_repeating_extension_mut(repetition);
}
}
Err(TokenError::ExtensionNotFound.into())
}

/// Packs a variable-length extension into its appropriate data segment.
/// Fails if space hasn't already been allocated for the given extension
pub fn pack_variable_len_extension<V: Extension + VariableLenPack>(
Expand Down Expand Up @@ -752,6 +816,18 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
Ok(extension_ref)
}

/// Packs the default extension data into an open slot, disregarding if
/// the extension has already been found in the data buffer.
pub fn init_extension_allow_repeating<V: Extension + Pod + Default>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
pub fn init_extension_allow_repeating<V: Extension + Pod + Default>(
pub fn init_extension_with_repetition<V: Extension + Pod + Default>(

&mut self,
) -> Result<&mut V, ProgramError> {
let length = pod_get_packed_len::<V>();
let buffer = self.alloc_allow_repeating::<V>(length)?;
let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
*extension_ref = V::default();
Ok(extension_ref)
}

/// Reallocate and overwite the TLV entry for the given variable-length
/// extension.
///
Expand Down Expand Up @@ -880,6 +956,53 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
}
}

fn alloc_allow_repeating<V: Extension>(
&mut self,
length: usize,
) -> Result<&mut [u8], ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}

let mut start_index = 0;
let mut type_start = 0;
let mut length_start = 0;
let mut value_start = 0;
let mut extension_type = V::TYPE;
let required_len = add_type_and_length_to_len(length);

while extension_type != ExtensionType::Uninitialized {
let indices = get_extension_indices::<V>(&self.tlv_data[start_index..], true)?;
(type_start, length_start, value_start) = (
indices.type_start.saturating_add(start_index),
indices.length_start.saturating_add(start_index),
indices.value_start.saturating_add(start_index),
);

if self.tlv_data[type_start..].len() < required_len {
return Err(ProgramError::InvalidAccountData);
}

extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
start_index = value_start.saturating_add(usize::from(*pod_from_bytes::<Length>(
&self.tlv_data[length_start..value_start],
)?));
}

// write extension type
let extension_type_array: [u8; 2] = V::TYPE.into();
let extension_type_ref = &mut self.tlv_data[type_start..length_start];
extension_type_ref.copy_from_slice(&extension_type_array);

// write length
let length_ref =
pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
*length_ref = Length::try_from(length)?;
let value_end = value_start.saturating_add(length);

Ok(&mut self.tlv_data[value_start..value_end])
}
Comment on lines +959 to +1004
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it would avoid a lot of copy-pasta to change the flag in get_extension_indices to be an enum with three variants: Get, Init, InitWithRepetition, and GetWithRepetition(usize). Do you see what I mean?

Then you can get rid of get_repeating_extension_bytes_mut and have get_repeating_extension avoid fetching all of the repetitions just to extract one.

Orrrrr, we decide to roll members without repetition to begin with 👀


/// If `extension_type` is an Account-associated ExtensionType that requires
/// initialization on InitializeAccount, this method packs the default
/// relevant Extension of an ExtensionType into an open slot if not
Expand Down Expand Up @@ -1883,6 +2006,116 @@ mod test {
);
}

#[test]
fn mint_with_repeating_extensions_pack_unpack() {
// Have to manually add the other two repeating entries, since
// `try_calculate_account_len` will skip duplicates.
let mint_size =
ExtensionType::try_calculate_account_len::<Mint>(&[ExtensionType::MintCloseAuthority])
.unwrap()
.saturating_add(36)
.saturating_add(36);
let mut buffer = vec![0; mint_size];

let close_authority1 =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
let close_authority2 =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([2; 32]))).unwrap();
let close_authority3 =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([3; 32]))).unwrap();

let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
let extension = state
.init_extension_allow_repeating::<MintCloseAuthority>()
.unwrap();
extension.close_authority = close_authority1;
let extension = state
.init_extension_allow_repeating::<MintCloseAuthority>()
.unwrap();
extension.close_authority = close_authority2;
let extension = state
.init_extension_allow_repeating::<MintCloseAuthority>()
.unwrap();
extension.close_authority = close_authority3;

assert_eq!(
&state.get_extension_types().unwrap(),
&[
ExtensionType::MintCloseAuthority,
ExtensionType::MintCloseAuthority,
ExtensionType::MintCloseAuthority,
]
);

let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
state.base = TEST_MINT;
state.pack_base();
state.init_account_type().unwrap();

let mint_close_auth_type_bytes = (ExtensionType::MintCloseAuthority as u16).to_le_bytes();
let mint_close_auth_len_bytes =
(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes();

let mut expect = TEST_MINT_SLICE.to_vec();
expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding
expect.push(AccountType::Mint.into());
expect.extend_from_slice(&mint_close_auth_type_bytes);
expect.extend_from_slice(&mint_close_auth_len_bytes);
expect.extend_from_slice(&[1; 32]);
expect.extend_from_slice(&mint_close_auth_type_bytes);
expect.extend_from_slice(&mint_close_auth_len_bytes);
expect.extend_from_slice(&[2; 32]);
expect.extend_from_slice(&mint_close_auth_type_bytes);
expect.extend_from_slice(&mint_close_auth_len_bytes);
expect.extend_from_slice(&[3; 32]);
assert_eq!(expect, buffer);

// check unpacking
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
let unpacked_extension = state
.get_repeating_extension_mut::<MintCloseAuthority>(1)
.unwrap();
assert_eq!(
*unpacked_extension,
MintCloseAuthority {
close_authority: close_authority1
}
);

// update extension
let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
unpacked_extension.close_authority = close_authority;

// check updates are propagated
let base = state.base;
let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
assert_eq!(state.base, base);
let unpacked_extension = state
.get_repeating_extension::<MintCloseAuthority>(1)
.unwrap();
assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

// check the rest
let unpacked_extension = state
.get_repeating_extension::<MintCloseAuthority>(2)
.unwrap();
assert_eq!(
*unpacked_extension,
MintCloseAuthority {
close_authority: close_authority2
}
);
let unpacked_extension = state
.get_repeating_extension::<MintCloseAuthority>(3)
.unwrap();
assert_eq!(
*unpacked_extension,
MintCloseAuthority {
close_authority: close_authority3
}
);
}

#[test]
fn mint_extension_any_order() {
let mint_size = ExtensionType::try_calculate_account_len::<Mint>(&[
Expand Down
Loading