diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 9083bd8a..ae089e82 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -1970,6 +1970,13 @@ pub mod min_sig { ); } +pub trait MultiScalar { + type Output; + + fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output; + fn add(&self) -> Self::Output; +} + #[cfg(feature = "std")] include!("pippenger.rs"); diff --git a/bindings/rust/src/pippenger-no_std.rs b/bindings/rust/src/pippenger-no_std.rs index c316e87a..b0708866 100644 --- a/bindings/rust/src/pippenger-no_std.rs +++ b/bindings/rust/src/pippenger-no_std.rs @@ -60,15 +60,26 @@ macro_rules! pippenger_mult_impl { } pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point { - let npoints = self.points.len(); + self.as_slice().mult(scalars, nbits) + } + + pub fn add(&self) -> $point { + self.as_slice().add() + } + } + + impl MultiScalar for [$point_affine] { + type Output = $point; + + fn mult(&self, scalars: &[u8], nbits: usize) -> $point { + let npoints = self.len(); let nbytes = (nbits + 7) / 8; if scalars.len() < nbytes * npoints { panic!("scalars length mismatch"); } - let p: [*const $point_affine; 2] = - [&self.points[0], ptr::null()]; + let p: [*const $point_affine; 2] = [&self[0], ptr::null()]; let s: [*const u8; 2] = [&scalars[0], ptr::null()]; let mut ret = <$point>::default(); @@ -89,10 +100,10 @@ macro_rules! pippenger_mult_impl { ret } - pub fn add(&self) -> $point { - let npoints = self.points.len(); + fn add(&self) -> $point { + let npoints = self.len(); - let p: [*const _; 2] = [&self.points[0], ptr::null()]; + let p: [*const _; 2] = [&self[0], ptr::null()]; let mut ret = <$point>::default(); unsafe { $add(&mut ret, &p[0], npoints) }; diff --git a/bindings/rust/src/pippenger.rs b/bindings/rust/src/pippenger.rs index bdaec95a..d135de37 100644 --- a/bindings/rust/src/pippenger.rs +++ b/bindings/rust/src/pippenger.rs @@ -114,7 +114,19 @@ macro_rules! pippenger_mult_impl { } pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point { - let npoints = self.points.len(); + self.as_slice().mult(scalars, nbits) + } + + pub fn add(&self) -> $point { + self.as_slice().add() + } + } + + impl MultiScalar for [$point_affine] { + type Output = $point; + + fn mult(&self, scalars: &[u8], nbits: usize) -> $point { + let npoints = self.len(); let nbytes = (nbits + 7) / 8; if scalars.len() < nbytes * npoints { @@ -124,8 +136,7 @@ macro_rules! pippenger_mult_impl { let pool = mt::da_pool(); let ncpus = pool.max_count(); if ncpus < 2 || npoints < 32 { - let p: [*const $point_affine; 2] = - [&self.points[0], ptr::null()]; + let p: [*const $point_affine; 2] = [&self[0], ptr::null()]; let s: [*const u8; 2] = [&scalars[0], ptr::null()]; unsafe { @@ -178,7 +189,7 @@ macro_rules! pippenger_mult_impl { } let grid = &grid[..]; - let points = &self.points[..]; + let points = &self[..]; let sz = unsafe { $scratch_sizeof(0) / 8 }; let mut row_sync: Vec = Vec::with_capacity(ny); @@ -262,13 +273,13 @@ macro_rules! pippenger_mult_impl { ret } - pub fn add(&self) -> $point { - let npoints = self.points.len(); + fn add(&self) -> $point { + let npoints = self.len(); let pool = mt::da_pool(); let ncpus = pool.max_count(); if ncpus < 2 || npoints < 384 { - let p: [*const _; 2] = [&self.points[0], ptr::null()]; + let p: [*const _; 2] = [&self[0], ptr::null()]; let mut ret = <$point>::default(); unsafe { $add(&mut ret, &p[0], npoints) }; return ret; @@ -295,7 +306,7 @@ macro_rules! pippenger_mult_impl { if work >= npoints { break; } - p[0] = &self.points[work]; + p[0] = &self[work]; if work + chunk > npoints { chunk = npoints - work; }