diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index e93fc651..a91c4b74 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -14,7 +14,7 @@ use alloc::boxed::Box; use alloc::vec; use alloc::vec::Vec; use core::any::Any; -use core::mem::MaybeUninit; +use core::mem::{transmute, MaybeUninit}; use core::ptr; use zeroize::Zeroize; @@ -34,7 +34,6 @@ trait ThreadPoolExt { #[cfg(all(not(feature = "no-threads"), feature = "std"))] mod mt { use super::*; - use core::mem::transmute; use std::sync::{Mutex, Once}; use threadpool::ThreadPool; @@ -951,6 +950,21 @@ macro_rules! sig_variant_impl { Ok(agg_pk) } + pub fn aggregate_with_randomness( + pks: &[PublicKey], + randomness: &[u8], + nbits: usize, + pks_groupcheck: bool, + ) -> Result { + if pks.len() == 0 { + return Err(BLST_ERROR::BLST_AGGR_TYPE_MISMATCH); + } + if pks_groupcheck { + pks.validate()?; + } + Ok(pks.mult(randomness, nbits)) + } + pub fn aggregate_serialized( pks: &[&[u8]], pks_validate: bool, @@ -1516,6 +1530,21 @@ macro_rules! sig_variant_impl { Ok(agg_sig) } + pub fn aggregate_with_randomness( + sigs: &[Signature], + randomness: &[u8], + nbits: usize, + sigs_groupcheck: bool, + ) -> Result { + if sigs.len() == 0 { + return Err(BLST_ERROR::BLST_AGGR_TYPE_MISMATCH); + } + if sigs_groupcheck { + sigs.validate()?; + } + Ok(sigs.mult(randomness, nbits)) + } + pub fn aggregate_serialized( sigs: &[&[u8]], sigs_groupcheck: bool, @@ -1585,21 +1614,21 @@ macro_rules! sig_variant_impl { fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output { Self::Output { - point: unsafe { - core::mem::transmute::<&[_], &[$pk_aff]>(self) - } - .mult(scalars, nbits), + point: unsafe { transmute::<&[_], &[$pk_aff]>(self) } + .mult(scalars, nbits), } } fn add(&self) -> Self::Output { Self::Output { - point: unsafe { - core::mem::transmute::<&[_], &[$pk_aff]>(self) - } - .add(), + point: unsafe { transmute::<&[_], &[$pk_aff]>(self) } + .add(), } } + + fn validate(&self) -> Result<(), BLST_ERROR> { + unsafe { transmute::<&[_], &[$pk_aff]>(self) }.validate() + } } impl MultiPoint for [Signature] { @@ -1607,21 +1636,21 @@ macro_rules! sig_variant_impl { fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output { Self::Output { - point: unsafe { - core::mem::transmute::<&[_], &[$sig_aff]>(self) - } - .mult(scalars, nbits), + point: unsafe { transmute::<&[_], &[$sig_aff]>(self) } + .mult(scalars, nbits), } } fn add(&self) -> Self::Output { Self::Output { - point: unsafe { - core::mem::transmute::<&[_], &[$sig_aff]>(self) - } - .add(), + point: unsafe { transmute::<&[_], &[$sig_aff]>(self) } + .add(), } } + + fn validate(&self) -> Result<(), BLST_ERROR> { + unsafe { transmute::<&[_], &[$sig_aff]>(self) }.validate() + } } #[cfg(test)]