Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
token 2022: add support for _writing_ repeating fixed-length extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
buffalojoec committed Nov 14, 2023
1 parent 5114e94 commit 7fcf604
Showing 1 changed file with 233 additions and 0 deletions.
233 changes: 233 additions & 0 deletions token/program-2022/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,37 @@ fn get_extension_bytes_mut<S: BaseState, V: Extension>(
Ok(&mut tlv_data[value_start..value_end])
}

fn get_repeating_extension_bytes_mut<S: BaseState, V: Extension>(
tlv_data: &mut [u8],
repetition: usize,
) -> Result<&mut [u8], ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}

let mut start_index = 0;
let mut value_start = 0;
let mut value_end = 0;

for _ in 0..repetition {
let indices = get_extension_indices::<V>(&tlv_data[start_index..], false)?;

let global_length_start = indices.length_start.saturating_add(start_index);
value_start = indices.value_start.saturating_add(start_index);

let length = pod_from_bytes::<Length>(&tlv_data[global_length_start..value_start])?;
value_end = value_start.saturating_add(usize::from(*length));

if tlv_data.len() < value_end {
return Err(ProgramError::InvalidAccountData);
}

start_index = value_end;
}

Ok(&mut tlv_data[value_start..value_end])
}

/// Calculate the new expected size if the state allocates the given number
/// of bytes for the given extension type.
///
Expand Down Expand Up @@ -719,6 +750,39 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
pod_from_bytes_mut::<V>(self.get_extension_bytes_mut::<V>()?)
}

/// Unpack a portion of the TLV data as the base mutable bytes,
/// for a repeating extension
fn get_repeating_extension_bytes_mut<V: Extension + Pod>(
&mut self,
repetition: usize,
) -> Result<&mut [u8], ProgramError> {
get_repeating_extension_bytes_mut::<S, V>(self.tlv_data, repetition)
}

/// Unpack a portion of the TLV data as the desired type that allows
/// modifying the type, for a repeating extension
pub fn get_repeating_extension_mut<V: Extension + Pod>(
&mut self,
repetition: usize,
) -> Result<&mut V, ProgramError> {
pod_from_bytes_mut::<V>(self.get_repeating_extension_bytes_mut::<V>(repetition)?)
}

/// Returns an unpacked portion of TLV data that allows modifying the type,
/// based on the specified match criteria
pub fn get_first_matched_repeating_extension_mut<V: Extension + Pod>(
&mut self,
match_critera: impl Fn(&V) -> bool,
) -> Result<&mut V, ProgramError> {
for (index, extension) in self.get_all_extensions::<V>()?.iter().enumerate() {
let repetition = index + 1;
if match_critera(extension) {
return self.get_repeating_extension_mut(repetition);
}
}
Err(TokenError::ExtensionNotFound.into())
}

/// Packs a variable-length extension into its appropriate data segment.
/// Fails if space hasn't already been allocated for the given extension
pub fn pack_variable_len_extension<V: Extension + VariableLenPack>(
Expand Down Expand Up @@ -752,6 +816,18 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
Ok(extension_ref)
}

/// Packs the default extension data into an open slot, disregarding if
/// the extension has already been found in the data buffer.
pub fn init_extension_allow_repeating<V: Extension + Pod + Default>(
&mut self,
) -> Result<&mut V, ProgramError> {
let length = pod_get_packed_len::<V>();
let buffer = self.alloc_allow_repeating::<V>(length)?;
let extension_ref = pod_from_bytes_mut::<V>(buffer)?;
*extension_ref = V::default();
Ok(extension_ref)
}

/// Reallocate and overwite the TLV entry for the given variable-length
/// extension.
///
Expand Down Expand Up @@ -880,6 +956,53 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
}
}

fn alloc_allow_repeating<V: Extension>(
&mut self,
length: usize,
) -> Result<&mut [u8], ProgramError> {
if V::TYPE.get_account_type() != S::ACCOUNT_TYPE {
return Err(ProgramError::InvalidAccountData);
}

let mut start_index = 0;
let mut type_start = 0;
let mut length_start = 0;
let mut value_start = 0;
let mut extension_type = V::TYPE;
let required_len = add_type_and_length_to_len(length);

while extension_type != ExtensionType::Uninitialized {
let indices = get_extension_indices::<V>(&self.tlv_data[start_index..], true)?;
(type_start, length_start, value_start) = (
indices.type_start.saturating_add(start_index),
indices.length_start.saturating_add(start_index),
indices.value_start.saturating_add(start_index),
);

if self.tlv_data[type_start..].len() < required_len {
return Err(ProgramError::InvalidAccountData);
}

extension_type = ExtensionType::try_from(&self.tlv_data[type_start..length_start])?;
start_index = value_start.saturating_add(usize::from(*pod_from_bytes::<Length>(
&self.tlv_data[length_start..value_start],
)?));
}

// write extension type
let extension_type_array: [u8; 2] = V::TYPE.into();
let extension_type_ref = &mut self.tlv_data[type_start..length_start];
extension_type_ref.copy_from_slice(&extension_type_array);

// write length
let length_ref =
pod_from_bytes_mut::<Length>(&mut self.tlv_data[length_start..value_start])?;
*length_ref = Length::try_from(length)?;
let value_end = value_start.saturating_add(length);

Ok(&mut self.tlv_data[value_start..value_end])
}

/// If `extension_type` is an Account-associated ExtensionType that requires
/// initialization on InitializeAccount, this method packs the default
/// relevant Extension of an ExtensionType into an open slot if not
Expand Down Expand Up @@ -1883,6 +2006,116 @@ mod test {
);
}

#[test]
fn mint_with_repeating_extensions_pack_unpack() {
// Have to manually add the other two repeating entries, since
// `try_calculate_account_len` will skip duplicates.
let mint_size =
ExtensionType::try_calculate_account_len::<Mint>(&[ExtensionType::MintCloseAuthority])
.unwrap()
.saturating_add(36)
.saturating_add(36);
let mut buffer = vec![0; mint_size];

let close_authority1 =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([1; 32]))).unwrap();
let close_authority2 =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([2; 32]))).unwrap();
let close_authority3 =
OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([3; 32]))).unwrap();

let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
let extension = state
.init_extension_allow_repeating::<MintCloseAuthority>()
.unwrap();
extension.close_authority = close_authority1;
let extension = state
.init_extension_allow_repeating::<MintCloseAuthority>()
.unwrap();
extension.close_authority = close_authority2;
let extension = state
.init_extension_allow_repeating::<MintCloseAuthority>()
.unwrap();
extension.close_authority = close_authority3;

assert_eq!(
&state.get_extension_types().unwrap(),
&[
ExtensionType::MintCloseAuthority,
ExtensionType::MintCloseAuthority,
ExtensionType::MintCloseAuthority,
]
);

let mut state = StateWithExtensionsMut::<Mint>::unpack_uninitialized(&mut buffer).unwrap();
state.base = TEST_MINT;
state.pack_base();
state.init_account_type().unwrap();

let mint_close_auth_type_bytes = (ExtensionType::MintCloseAuthority as u16).to_le_bytes();
let mint_close_auth_len_bytes =
(pod_get_packed_len::<MintCloseAuthority>() as u16).to_le_bytes();

let mut expect = TEST_MINT_SLICE.to_vec();
expect.extend_from_slice(&[0; BASE_ACCOUNT_LENGTH - Mint::LEN]); // padding
expect.push(AccountType::Mint.into());
expect.extend_from_slice(&mint_close_auth_type_bytes);
expect.extend_from_slice(&mint_close_auth_len_bytes);
expect.extend_from_slice(&[1; 32]);
expect.extend_from_slice(&mint_close_auth_type_bytes);
expect.extend_from_slice(&mint_close_auth_len_bytes);
expect.extend_from_slice(&[2; 32]);
expect.extend_from_slice(&mint_close_auth_type_bytes);
expect.extend_from_slice(&mint_close_auth_len_bytes);
expect.extend_from_slice(&[3; 32]);
assert_eq!(expect, buffer);

// check unpacking
let mut state = StateWithExtensionsMut::<Mint>::unpack(&mut buffer).unwrap();
let unpacked_extension = state
.get_repeating_extension_mut::<MintCloseAuthority>(1)
.unwrap();
assert_eq!(
*unpacked_extension,
MintCloseAuthority {
close_authority: close_authority1
}
);

// update extension
let close_authority = OptionalNonZeroPubkey::try_from(None).unwrap();
unpacked_extension.close_authority = close_authority;

// check updates are propagated
let base = state.base;
let state = StateWithExtensions::<Mint>::unpack(&buffer).unwrap();
assert_eq!(state.base, base);
let unpacked_extension = state
.get_repeating_extension::<MintCloseAuthority>(1)
.unwrap();
assert_eq!(*unpacked_extension, MintCloseAuthority { close_authority });

// check the rest
let unpacked_extension = state
.get_repeating_extension::<MintCloseAuthority>(2)
.unwrap();
assert_eq!(
*unpacked_extension,
MintCloseAuthority {
close_authority: close_authority2
}
);
let unpacked_extension = state
.get_repeating_extension::<MintCloseAuthority>(3)
.unwrap();
assert_eq!(
*unpacked_extension,
MintCloseAuthority {
close_authority: close_authority3
}
);
}

#[test]
fn mint_extension_any_order() {
let mint_size = ExtensionType::try_calculate_account_len::<Mint>(&[
Expand Down

0 comments on commit 7fcf604

Please sign in to comment.