diff --git a/utils/src/lib.rs b/utils/src/lib.rs index a3d68ba0c1..7a43db1915 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -48,13 +48,9 @@ impl From for solana_program::program_error::ProgramError { } } -pub fn is_smaller_than_bn254_field_size_be(bytes: &[u8; 32]) -> Result { +pub fn is_smaller_than_bn254_field_size_be(bytes: &[u8; 32]) -> bool { let bigint = BigUint::from_bytes_be(bytes); - if bigint < ark_bn254::Fr::MODULUS.into() { - Ok(true) - } else { - Ok(false) - } + bigint < ark_bn254::Fr::MODULUS.into() } pub fn hash_to_bn254_field_size_be(bytes: &[u8]) -> Option<([u8; 32], u8)> { @@ -67,7 +63,7 @@ pub fn hash_to_bn254_field_size_be(bytes: &[u8]) -> Option<([u8; 32], u8)> { // Truncates to 31 bytes so that value is less than bn254 Fr modulo // field size. hashed_value[0] = 0; - if let Ok(true) = is_smaller_than_bn254_field_size_be(&hashed_value) { + if is_smaller_than_bn254_field_size_be(&hashed_value) { return Some((hashed_value, bump_seed[0])); } } @@ -109,11 +105,28 @@ pub fn rustfmt(code: String) -> Result, anyhow::Error> { #[cfg(test)] mod tests { - + use num_bigint::ToBigUint; use solana_program::pubkey::Pubkey; + use crate::bigint::bigint_to_be_bytes_array; + use super::*; + #[test] + fn test_is_smaller_than_bn254_field_size_be() { + let modulus: BigUint = ark_bn254::Fr::MODULUS.into(); + let modulus_bytes: [u8; 32] = bigint_to_be_bytes_array(&modulus).unwrap(); + assert!(!is_smaller_than_bn254_field_size_be(&modulus_bytes)); + + let bigint = modulus.clone() - 1.to_biguint().unwrap(); + let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap(); + assert!(is_smaller_than_bn254_field_size_be(&bigint_bytes)); + + let bigint = modulus + 1.to_biguint().unwrap(); + let bigint_bytes: [u8; 32] = bigint_to_be_bytes_array(&bigint).unwrap(); + assert!(!is_smaller_than_bn254_field_size_be(&bigint_bytes)); + } + #[test] fn test_hash_to_bn254_field_size_be() { for _ in 0..10_000 { @@ -122,9 +135,37 @@ mod tests { .expect("Failed to find a hash within BN254 field size"); assert_eq!(bump, 255, "Bump seed should be 0"); assert!( - is_smaller_than_bn254_field_size_be(&hashed_value).unwrap(), + is_smaller_than_bn254_field_size_be(&hashed_value), "Hashed value should be within BN254 field size" ); } + + let max_input = [u8::MAX; 32]; + let (hashed_value, bump) = hash_to_bn254_field_size_be(max_input.as_slice()) + .expect("Failed to find a hash within BN254 field size"); + assert_eq!(bump, 255, "Bump seed should be 255"); + assert!( + is_smaller_than_bn254_field_size_be(&hashed_value), + "Hashed value should be within BN254 field size" + ); + } + + #[test] + fn test_rustfmt() { + let unformatted_code = "use std::mem; + +fn main() { println!(\"{}\", mem::size_of::()); } + " + .to_string(); + let formatted_code = rustfmt(unformatted_code).unwrap(); + assert_eq!( + String::from_utf8_lossy(&formatted_code), + "use std::mem; + +fn main() { + println!(\"{}\", mem::size_of::()); +} +" + ); } }