diff --git a/ipa-core/benches/transpose.rs b/ipa-core/benches/transpose.rs index 9350e9ccc..e0b872e93 100644 --- a/ipa-core/benches/transpose.rs +++ b/ipa-core/benches/transpose.rs @@ -16,6 +16,7 @@ use std::{array, iter::repeat_with, time::Duration}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; use ipa_core::{ + error::UnwrapInfallible, ff::boolean_array::BA64, secret_sharing::{ vector::{transpose_16x16, transpose_8x8}, @@ -75,7 +76,7 @@ fn bench_8x8(c: &mut Criterion) { cols: 8, iters: 100, }, - transpose_8x8, + |m| transpose_8x8(m), ); } @@ -101,7 +102,7 @@ fn bench_64x64(c: &mut Criterion) { }, |src| { let mut dst = array::from_fn(|_| BA64::ZERO); - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }, ); diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 2ee1f9757..8b4145671 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -36,9 +36,12 @@ // This rule throws false positives on "MxN". #![allow(clippy::doc_markdown)] -use std::array; +#[cfg(any(all(test, unit_test), feature = "enable-benches"))] +use std::borrow::Borrow; +use std::{array, convert::Infallible}; use crate::{ + error::{LengthError, UnwrapInfallible}, ff::{ boolean::Boolean, boolean_array::{BA16, BA256, BA64}, @@ -51,15 +54,25 @@ use crate::{ /// Trait for overwriting a value with the transpose of a source value. pub trait TransposeFrom { - fn transpose_from(&mut self, src: T); - - fn transposed_from(src: T) -> Self + type Error; + + /// Overwrite `self` with the transpose of `src`. + /// + /// # Errors + /// If the size of the source and destination are not compatible. + fn transpose_from(&mut self, src: T) -> Result<(), Self::Error>; + + /// Fills a new `Self` with the transpose of `src`. + /// + /// # Errors + /// If the size of the source and destination are not compatible. + fn transposed_from(src: T) -> Result where Self: Default, { let mut dst = Self::default(); - dst.transpose_from(src); - dst + dst.transpose_from(src)?; + Ok(dst) } } @@ -67,10 +80,9 @@ pub trait TransposeFrom { // // From Hacker's Delight (2nd edition), Figure 7-6. #[cfg(any(all(test, unit_test), feature = "enable-benches"))] -#[allow(clippy::trivially_copy_pass_by_ref)] // Keeps interface consistent with other functions. #[inline] -pub fn transpose_8x8(x: &[u8; 8]) -> [u8; 8] { - let mut x = u64::from_le_bytes(*x); +pub fn transpose_8x8>(x: B) -> [u8; 8] { + let mut x = u64::from_le_bytes(*x.borrow()); x = x & 0xaa55_aa55_aa55_aa55 | (x & 0x00aa_00aa_00aa_00aa) << 7 @@ -143,14 +155,20 @@ pub fn transpose_16x16(src: &[u8; 32]) -> [u8; 32] { // Degenerate transposes. impl TransposeFrom> for Vec> { - fn transpose_from(&mut self, src: AdditiveShare) { + type Error = Infallible; + + fn transpose_from(&mut self, src: AdditiveShare) -> Result<(), Infallible> { *self = vec![src]; + Ok(()) } } impl TransposeFrom>> for Vec { - fn transpose_from(&mut self, src: Vec>) { + type Error = Infallible; + + fn transpose_from(&mut self, src: Vec>) -> Result<(), Infallible> { *self = vec![src.iter().map(Boolean::from_array).collect::()]; + Ok(()) } } @@ -182,7 +200,9 @@ fn do_transpose_16 [u8; 32], DF: FnMut(usize, usize, [u8 macro_rules! impl_transpose_ba_to_ba { ($dst_row:ty, $src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { impl TransposeFrom<&[$src_row; $src_rows]> for [$dst_row; $src_cols] { - fn transpose_from(&mut self, src: &[$src_row; $src_rows]) { + type Error = Infallible; + + fn transpose_from(&mut self, src: &[$src_row; $src_rows]) -> Result<(), Infallible> { do_transpose_16( $src_rows / 16, $src_cols / 16, @@ -202,6 +222,7 @@ macro_rules! impl_transpose_ba_to_ba { } }, ); + Ok(()) } } @@ -212,19 +233,27 @@ macro_rules! impl_transpose_ba_to_ba { } impl TransposeFrom<&BitDecomposed<$src_row>> for Vec<$dst_row> { - fn transpose_from(&mut self, src: &BitDecomposed<$src_row>) { + type Error = LengthError; + + fn transpose_from(&mut self, src: &BitDecomposed<$src_row>) -> Result<(), LengthError> { self.resize($src_cols, <$dst_row>::ZERO); - let src = <&[$src_row; $src_rows]>::try_from(&**src).unwrap(); + let src = <&[$src_row; $src_rows]>::try_from(&**src).map_err(|_| LengthError { + expected: $src_rows, + actual: src.len(), + })?; let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); + Ok(()) } } impl TransposeFrom<&[$src_row; $src_rows]> for Vec<$dst_row> { - fn transpose_from(&mut self, src: &[$src_row; $src_rows]) { + type Error = Infallible; + + fn transpose_from(&mut self, src: &[$src_row; $src_rows]) -> Result<(), Infallible> { self.resize($src_cols, <$dst_row>::ZERO); let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src); + dst.transpose_from(src) } } }; @@ -244,7 +273,12 @@ macro_rules! impl_transpose_shares_bool_to_ba { impl TransposeFrom<&[AdditiveShare; $src_rows]> for [AdditiveShare<$dst_row>; $src_cols] { - fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { // Transpose left share do_transpose_16( $src_rows / 16, @@ -287,6 +321,7 @@ macro_rules! impl_transpose_shares_bool_to_ba { } }, ); + Ok(()) } } @@ -299,13 +334,22 @@ macro_rules! impl_transpose_shares_bool_to_ba { impl TransposeFrom<&BitDecomposed>> for Vec> { - fn transpose_from(&mut self, src: &BitDecomposed>) { + type Error = LengthError; + + fn transpose_from( + &mut self, + src: &BitDecomposed>, + ) -> Result<(), LengthError> { self.resize($src_cols, AdditiveShare::<$dst_row>::ZERO); - let src = - <&[AdditiveShare; $src_rows]>::try_from(&**src).unwrap(); + let src = <&[AdditiveShare; $src_rows]>::try_from(&**src) + .map_err(|_| LengthError { + expected: $src_rows, + actual: src.len(), + })?; let dst = <&mut [AdditiveShare<$dst_row>; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); + Ok(()) } } }; @@ -327,7 +371,12 @@ macro_rules! impl_transpose_shares_ba_fn_to_bool { impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> for [AdditiveShare; $src_cols] { - fn transpose_from(&mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &dyn Fn(usize) -> AdditiveShare<$src_row>, + ) -> Result<(), Infallible> { // Transpose left share do_transpose_16( $src_rows / 16, @@ -368,18 +417,24 @@ macro_rules! impl_transpose_shares_ba_fn_to_bool { } }, ); + Ok(()) } } impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> for BitDecomposed> { - fn transpose_from(&mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &dyn Fn(usize) -> AdditiveShare<$src_row>, + ) -> Result<(), Infallible> { self.resize($src_cols, AdditiveShare::::ZERO); let dst = <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) .unwrap(); - dst.transpose_from(src); + dst.transpose_from(src) } } }; @@ -397,7 +452,12 @@ macro_rules! impl_transpose_shares_bool_to_bool { impl TransposeFrom<&[AdditiveShare; $src_rows]> for [AdditiveShare; $src_cols] { - fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { // Transpose left share do_transpose_16( $src_rows / 16, @@ -438,6 +498,7 @@ macro_rules! impl_transpose_shares_bool_to_bool { } }, ); + Ok(()) } } @@ -450,21 +511,34 @@ macro_rules! impl_transpose_shares_bool_to_bool { impl TransposeFrom<&[AdditiveShare]> for BitDecomposed> { - fn transpose_from(&mut self, src: &[AdditiveShare]) { - let src = <&[AdditiveShare; $src_rows]>::try_from(src).unwrap(); - self.transpose_from(src); + type Error = LengthError; + fn transpose_from( + &mut self, + src: &[AdditiveShare], + ) -> Result<(), LengthError> { + let src = <&[AdditiveShare; $src_rows]>::try_from(src) + .map_err(|_| LengthError { + expected: $src_rows, + actual: src.len(), + })?; + self.transpose_from(src).unwrap_infallible(); + Ok(()) } } impl TransposeFrom<&[AdditiveShare; $src_rows]> for BitDecomposed> { - fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + type Error = Infallible; + fn transpose_from( + &mut self, + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { self.resize($src_cols, AdditiveShare::::ZERO); let dst = <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) .unwrap(); - dst.transpose_from(src); + dst.transpose_from(src) } } }; @@ -603,7 +677,7 @@ mod tests { #[test] fn transpose_8x8() { - test_transpose_array::(super::transpose_8x8); + test_transpose_array::(|m| super::transpose_8x8(m)); } #[test] @@ -621,12 +695,12 @@ mod tests { where SR: PartialEq + SharedValue + ArrayAccess, DR: PartialEq + SharedValue + ArrayAccess, - [DR; DM]: for<'a> TransposeFrom<&'a [SR; SM]>, + [DR; DM]: for<'a> TransposeFrom<&'a [SR; SM], Error = Infallible>, Standard: Distribution, { let t_impl = |src| { let mut dst = [DR::ZERO; DM]; - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }; @@ -671,11 +745,12 @@ mod tests { Boolean: Vectorizable, >::Array: ArrayAccess, DR: SharedValue + ArrayAccess, - [AdditiveShare; DM]: for<'a> TransposeFrom<&'a [AdditiveShare; SM]>, + [AdditiveShare; DM]: + for<'a> TransposeFrom<&'a [AdditiveShare; SM], Error = Infallible>, { let t_impl = |src| { let mut dst = [AdditiveShare::::ZERO; DM]; - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }; @@ -737,11 +812,11 @@ mod tests { Boolean: Vectorizable, >::Array: ArrayAccess, [AdditiveShare; DM]: - for<'a> TransposeFrom<&'a [AdditiveShare; SM]>, + for<'a> TransposeFrom<&'a [AdditiveShare; SM], Error = Infallible>, { let t_impl = |src| { let mut dst = [AdditiveShare::::ZERO; DM]; - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst };