From 6301f533929c19eda86f6ba4ea05b65c218712e4 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 11 Mar 2024 17:46:10 -0700 Subject: [PATCH 1/3] Bit matrix transposes for converting to/from vectorized shares --- Cargo.toml | 4 + ipa-core/Cargo.toml | 5 + ipa-core/benches/transpose.rs | 113 +++ ipa-core/src/ff/boolean_array.rs | 21 + ipa-core/src/secret_sharing/decomposed.rs | 20 +- ipa-core/src/secret_sharing/mod.rs | 6 +- .../replicated/semi_honest/additive_share.rs | 22 + ipa-core/src/secret_sharing/vector/array.rs | 85 +- ipa-core/src/secret_sharing/vector/impls.rs | 23 +- ipa-core/src/secret_sharing/vector/mod.rs | 4 + .../src/secret_sharing/vector/transpose.rs | 800 ++++++++++++++++++ 11 files changed, 1067 insertions(+), 36 deletions(-) create mode 100644 ipa-core/benches/transpose.rs create mode 100644 ipa-core/src/secret_sharing/vector/transpose.rs diff --git a/Cargo.toml b/Cargo.toml index 0a09034d9..a0a6301e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,10 @@ members = ["ipa-core", "ipa-macros"] incremental = true lto = "thin" +[profile.release-max] +inherits = "release" +codegen-units = 1 + [profile.bench-dhat] inherits = "bench" incremental = true diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 6b064d09f..2eca7d8e7 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -216,6 +216,11 @@ path = "benches/oneshot/ipa.rs" harness = false required-features = ["enable-benches", "descriptive-gate"] +[[bench]] +name = "transpose" +harness = false +required-features = ["enable-benches"] + [[test]] name = "helper_networks" required-features = [ diff --git a/ipa-core/benches/transpose.rs b/ipa-core/benches/transpose.rs new file mode 100644 index 000000000..9350e9ccc --- /dev/null +++ b/ipa-core/benches/transpose.rs @@ -0,0 +1,113 @@ +//! Benchmarks for bit matrix transpose operations. +//! +//! Some of these routines run very fast, which doesn't work well with the default Criterion settings. +//! The warm up time and measurement time are reduced, because the defaults will produce a very large +//! number of samples (which in turn will take Criterion a long time to analyze). +//! +//! Some of the benchmark routines are looped so that the running time is long enough for Criterion +//! to measure reliably. When too short, Criterion complains that some measurements take zero time. +//! Presumably, the behavior of the underlying system clock is a contributing factor here. +//! +//! There is also a panic in the `plotters` crate used by Criterion to produce HTML reports that can +//! occur with very fast-running routines. This can be worked around by passing the `-n` option to +//! Criterion to disable HTML reports. + +use std::{array, iter::repeat_with, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; +use ipa_core::{ + ff::boolean_array::BA64, + secret_sharing::{ + vector::{transpose_16x16, transpose_8x8}, + SharedValue, TransposeFrom, + }, +}; +use rand::{ + distributions::{Distribution, Standard}, + thread_rng, Rng, +}; + +fn random_array() -> [T; N] +where + Standard: Distribution, +{ + let mut rng = thread_rng(); + array::from_fn(|_| rng.gen()) +} + +struct Params { + rows: usize, + cols: usize, + iters: usize, +} + +fn do_benchmark( + c: &mut Criterion, + Params { rows, cols, iters }: Params, + routine: fn(&[T; N]) -> O, +) where + Standard: Distribution, +{ + let mut group = c.benchmark_group(format!("{rows}x{cols}")); + group.warm_up_time(Duration::from_millis(200)); + group.measurement_time(Duration::from_millis(200)); + group.throughput(Throughput::Elements((rows * cols * iters) as u64)); + + group.bench_with_input( + BenchmarkId::new("transpose", format!("{iters}x")), + &(), + move |b, _| { + b.iter_batched_ref( + || repeat_with(random_array).take(iters).collect::>(), + |input| input.iter().map(routine).count(), + BatchSize::SmallInput, + ) + }, + ); + group.finish(); +} + +fn bench_8x8(c: &mut Criterion) { + do_benchmark( + c, + Params { + rows: 8, + cols: 8, + iters: 100, + }, + transpose_8x8, + ); +} + +fn bench_16x16(c: &mut Criterion) { + do_benchmark( + c, + Params { + rows: 16, + cols: 16, + iters: 50, + }, + transpose_16x16, + ); +} + +fn bench_64x64(c: &mut Criterion) { + do_benchmark( + c, + Params { + rows: 64, + cols: 64, + iters: 1, + }, + |src| { + let mut dst = array::from_fn(|_| BA64::ZERO); + dst.transpose_from(src); + dst + }, + ); +} + +criterion_group!(benches_8x8, bench_8x8); +criterion_group!(benches_16x16, bench_16x16); +criterion_group!(benches_64x64, bench_64x64); +criterion_main!(benches_8x8, benches_16x16, benches_64x64); diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index 4137d0eeb..8b421098b 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -269,6 +269,18 @@ macro_rules! boolean_array_impl { impl $name { #[cfg(all(test, unit_test))] const STORE_LEN: usize = bitvec::mem::elts::($bits); + + #[inline] + #[must_use] + pub fn as_raw_slice(&self) -> &[u8] { + self.0.as_raw_slice() + } + + #[inline] + #[must_use] + pub fn as_raw_mut_slice(&mut self) -> &mut [u8] { + self.0.as_raw_mut_slice() + } } impl ArrayAccess for $name { @@ -736,12 +748,21 @@ boolean_array_impl_small!(boolean_array_5, BA5, 5, fallible); boolean_array_impl_small!(boolean_array_6, BA6, 6, fallible); boolean_array_impl_small!(boolean_array_7, BA7, 7, fallible); boolean_array_impl_small!(boolean_array_8, BA8, 8, infallible); +boolean_array_impl_small!(boolean_array_16, BA16, 16, infallible); boolean_array_impl_small!(boolean_array_20, BA20, 20, fallible); boolean_array_impl_small!(boolean_array_32, BA32, 32, infallible); boolean_array_impl_small!(boolean_array_64, BA64, 64, infallible); boolean_array_impl_small!(boolean_array_112, BA112, 112, infallible); boolean_array_impl!(boolean_array_256, BA256, 256, infallible); +impl Vectorizable<256> for BA64 { + type Array = StdArray; +} + +impl Vectorizable<256> for BA256 { + type Array = StdArray; +} + // used to convert into Fp25519 impl From<(u128, u128)> for BA256 { fn from(value: (u128, u128)) -> Self { diff --git a/ipa-core/src/secret_sharing/decomposed.rs b/ipa-core/src/secret_sharing/decomposed.rs index 83555d3ab..f15455a87 100644 --- a/ipa-core/src/secret_sharing/decomposed.rs +++ b/ipa-core/src/secret_sharing/decomposed.rs @@ -1,4 +1,8 @@ -use std::{fmt::Debug, ops::Deref, slice}; +use std::{ + fmt::Debug, + ops::{Deref, DerefMut}, + slice, +}; use crate::{ error::Error, @@ -12,7 +16,7 @@ pub struct BitDecomposed { } impl BitDecomposed { - const MAX: usize = 64; + const MAX: usize = 256; /// Create a new value from an iterator. /// # Panics @@ -99,6 +103,12 @@ impl BitDecomposed { } } +impl BitDecomposed { + pub fn resize(&mut self, new_len: usize, value: S) { + self.bits.resize(new_len, value); + } +} + impl TryFrom> for BitDecomposed { type Error = Error; fn try_from(bits: Vec) -> Result { @@ -148,6 +158,12 @@ impl Deref for BitDecomposed { } } +impl DerefMut for BitDecomposed { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.bits + } +} + impl IntoIterator for BitDecomposed { type Item = S; type IntoIter = as IntoIterator>::IntoIter; diff --git a/ipa-core/src/secret_sharing/mod.rs b/ipa-core/src/secret_sharing/mod.rs index 64365bd51..eb1838f0e 100644 --- a/ipa-core/src/secret_sharing/mod.rs +++ b/ipa-core/src/secret_sharing/mod.rs @@ -3,7 +3,10 @@ pub mod replicated; mod decomposed; mod into_shares; mod scheme; +#[cfg(not(feature = "enable-benches"))] mod vector; +#[cfg(feature = "enable-benches")] +pub mod vector; use std::{ fmt::Debug, @@ -20,7 +23,8 @@ use rand::{ }; pub use scheme::{Bitwise, Linear, LinearRefOps, SecretSharing}; pub use vector::{ - FieldArray, FieldSimd, FieldVectorizable, SharedValueArray, StdArray, Vectorizable, + FieldArray, FieldSimd, FieldVectorizable, SharedValueArray, StdArray, TransposeFrom, + Vectorizable, }; #[cfg(any(test, feature = "test-fixture", feature = "cli"))] diff --git a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs index 817c870cd..c37bee4b1 100644 --- a/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/semi_honest/additive_share.rs @@ -102,6 +102,28 @@ impl, const N: usize> AdditiveShare { pub fn right_arr(&self) -> &>::Array { &self.1 } + + pub(in crate::secret_sharing) fn left_arr_mut(&mut self) -> &mut >::Array { + &mut self.0 + } + + pub(in crate::secret_sharing) fn right_arr_mut( + &mut self, + ) -> &mut >::Array { + &mut self.1 + } + + pub fn into_arr_tuple(self) -> (>::Array, >::Array) { + let Self(left, right) = self; + (left, right) + } + + pub fn from_fns V, RF: FnMut(usize) -> V>(lf: LF, rf: RF) -> Self { + Self( + >::Array::from_fn(lf), + >::Array::from_fn(rf), + ) + } } impl AdditiveShare diff --git a/ipa-core/src/secret_sharing/vector/array.rs b/ipa-core/src/secret_sharing/vector/array.rs index 13f476b2e..ce3c22f14 100644 --- a/ipa-core/src/secret_sharing/vector/array.rs +++ b/ipa-core/src/secret_sharing/vector/array.rs @@ -6,12 +6,12 @@ use std::{ }; use generic_array::{ArrayLength, GenericArray}; -use typenum::U32; +use typenum::{U16, U256, U32, U64}; use crate::{ error::LengthError, ff::{Field, Fp32BitPrime, Serializable}, - protocol::prss::{FromRandom, FromRandomU128}, + protocol::prss::FromRandom, secret_sharing::{FieldArray, Sendable, SharedValue, SharedValueArray}, }; @@ -277,14 +277,24 @@ impl FromRandom for StdArray { } } -impl FromRandom for StdArray { - type SourceLength = U32; +macro_rules! impl_from_random { + ($value_ty:ty, $width:expr, $source_len:ty, $item_len:expr) => { + impl FromRandom for StdArray<$value_ty, $width> { + type SourceLength = $source_len; - fn from_random(src: GenericArray) -> Self { - Self(array::from_fn(|i| Fp32BitPrime::from_random_u128(src[i]))) - } + fn from_random(src: GenericArray) -> Self { + Self(array::from_fn(|i| { + <$value_ty>::from_random( + GenericArray::from_slice(&src[$item_len * i..$item_len * (i + 1)]).clone(), + ) + })) + } + } + }; } +impl_from_random!(Fp32BitPrime, 32, U32, 1); + impl Serializable for StdArray { type Size = ::Size; type DeserializationError = ::DeserializationError; @@ -298,34 +308,45 @@ impl Serializable for StdArray { } } -impl Serializable for StdArray -where - V: SharedValue, - ::Size: Mul, - <::Size as Mul>::Output: ArrayLength, -{ - type Size = <::Size as Mul>::Output; - type DeserializationError = ::DeserializationError; - - fn serialize(&self, buf: &mut GenericArray) { - let sz: usize = (::BITS / 8).try_into().unwrap(); - for i in 0..32 { - self.0[i].serialize( - GenericArray::try_from_mut_slice(&mut buf[sz * i..sz * (i + 1)]).unwrap(), - ); - } - } - - fn deserialize(buf: &GenericArray) -> Result { - let sz: usize = (::BITS / 8).try_into().unwrap(); - let mut res = [V::ZERO; 32]; - for i in 0..32 { - res[i] = V::deserialize(GenericArray::from_slice(&buf[sz * i..sz * (i + 1)]))?; +macro_rules! impl_serializable { + ($width:expr, $width_ty:ty) => { + impl Serializable for StdArray + where + V: SharedValue, + ::Size: Mul<$width_ty>, + <::Size as Mul<$width_ty>>::Output: ArrayLength, + { + type Size = <::Size as Mul<$width_ty>>::Output; + type DeserializationError = ::DeserializationError; + + fn serialize(&self, buf: &mut GenericArray) { + let sz: usize = (::BITS / 8).try_into().unwrap(); + for i in 0..$width { + self.0[i].serialize( + GenericArray::try_from_mut_slice(&mut buf[sz * i..sz * (i + 1)]).unwrap(), + ); + } + } + + fn deserialize( + buf: &GenericArray, + ) -> Result { + let sz: usize = (::BITS / 8).try_into().unwrap(); + let mut res = [V::ZERO; $width]; + for i in 0..$width { + res[i] = V::deserialize(GenericArray::from_slice(&buf[sz * i..sz * (i + 1)]))?; + } + Ok(StdArray(res)) + } } - Ok(StdArray(res)) - } + }; } +impl_serializable!(16, U16); +impl_serializable!(32, U32); +impl_serializable!(64, U64); +impl_serializable!(256, U256); + #[cfg(all(test, unit_test))] mod test { use std::iter; diff --git a/ipa-core/src/secret_sharing/vector/impls.rs b/ipa-core/src/secret_sharing/vector/impls.rs index e29d8712b..2ada6571c 100644 --- a/ipa-core/src/secret_sharing/vector/impls.rs +++ b/ipa-core/src/secret_sharing/vector/impls.rs @@ -3,7 +3,7 @@ use crate::{ ff::{ boolean::Boolean, - boolean_array::{BA20, BA256, BA3, BA32, BA5, BA64, BA8}, + boolean_array::{BA16, BA20, BA256, BA3, BA32, BA5, BA64, BA8}, Fp32BitPrime, }, secret_sharing::{ @@ -43,7 +43,28 @@ macro_rules! boolean_vector { boolean_vector!(3, BA3); boolean_vector!(5, BA5); boolean_vector!(8, BA8); +boolean_vector!(16, BA16); boolean_vector!(20, BA20); boolean_vector!(32, BA32); boolean_vector!(64, BA64); boolean_vector!(256, BA256); + +/// Expands to the type for storing vectorized shares of a multi-bit boolean value. +/// +/// The "width" is the bit width of the value for each record. For example, a breakdown key might +/// have an 8-bit width. +/// +/// The "dimension" is the vectorization dimension, which is a number of records. For example, +/// there might be no vectorization (dimension = 1), or computation might be vectorized over +/// 256 records (dimension = 256). +/// +/// When the dimension is one, `BoolVector!(width, 1)` expands to an `AdditiveShare` of the Boolean +/// array type with the requested width. +/// +/// When the dimension is greater than one, `BoolVector!(width, dim)` expands to +/// `BitDecomposed>`. +#[macro_export] +macro_rules! BoolVector { + (16, 1) => { $crate::secret_sharing::replicated::semi_honest::AdditiveShare<$crate::ff::BA16> }; + ($width:expr, $dim:expr) => { BitDecomposed<$crate::secret_sharing::replicated::semi_honest::AdditiveShare<$crate::ff::boolean::Boolean, $dim>> }; +} diff --git a/ipa-core/src/secret_sharing/vector/mod.rs b/ipa-core/src/secret_sharing/vector/mod.rs index e276d7d22..4e4575af3 100644 --- a/ipa-core/src/secret_sharing/vector/mod.rs +++ b/ipa-core/src/secret_sharing/vector/mod.rs @@ -51,6 +51,10 @@ mod array; mod impls; mod traits; +mod transpose; pub use array::StdArray; pub use traits::{FieldArray, FieldSimd, FieldVectorizable, SharedValueArray, Vectorizable}; +pub use transpose::TransposeFrom; +#[cfg(feature = "enable-benches")] +pub use transpose::{transpose_16x16, transpose_8x8}; diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs new file mode 100644 index 000000000..2ee1f9757 --- /dev/null +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -0,0 +1,800 @@ +//! Bit matrix transpose operations. +//! +//! These are used to convert data to and from vectorized representation. +//! +//! For example, if there is a 32-bit value associated with each record, the natural storage +//! representation for those values is something like a `Vec`, with the vector indexed by +//! records. +//! +//! In vectorized code, we instead want to use something like `BitDecomposed`, where each +//! `BA16` holds the value of a particular bit for each of 16 records, and the `BitDecomposed` +//! (which is just a wrapper around a `Vec`) is indexed by bits. +//! +//! To convert between these representations we need to transpose a 16x32 bit matrix into a 32x16 +//! bit matrix. +//! +//! This module stores bytes and bits are in little-endian order. Less significant bytes store data +//! closer to the top or left of the matrix. Within each byte, the first (leftmost) column is in +//! the least significant bit, and the last (rightmost) column is in the most significant bit. +//! +//! These implementations are somewhat optimized; it is certainly possible to optimize further, but +//! that is only worthwhile if profiling indicates this is a significant contributor to our overall +//! performance. Also, for some functions, the generated code is much better with `codegen-units = +//! 1` than with the default codegen-units (unfortunately, `codegen-units = 1` significantly +//! increases compile time). See [rust issue 47745](https://github.com/rust-lang/rust/issues/47745). +//! +//! Some possibilities for further optimization: +//! * Use codegen-units = 1 or figure out how to get comparable codegen without it. +//! * Avoid cost of zero-initializing the transpose destination. +//! * Use Rust's portable SIMD abstraction (not yet stable as of early 2024), or code directly +//! against platform SIMD intrinsics. +//! +//! For more ideas on optimizing bit matrix transposes in rust, see: +//! * +//! * + +// This rule throws false positives on "MxN". +#![allow(clippy::doc_markdown)] + +use std::array; + +use crate::{ + ff::{ + boolean::Boolean, + boolean_array::{BA16, BA256, BA64}, + }, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + BitDecomposed, SharedValue, StdArray, + }, +}; + +/// Trait for overwriting a value with the transpose of a source value. +pub trait TransposeFrom { + fn transpose_from(&mut self, src: T); + + fn transposed_from(src: T) -> Self + where + Self: Default, + { + let mut dst = Self::default(); + dst.transpose_from(src); + dst + } +} + +/// 8x8 bit matrix transpose. +// +// From Hacker's Delight (2nd edition), Figure 7-6. +#[cfg(any(all(test, unit_test), feature = "enable-benches"))] +#[allow(clippy::trivially_copy_pass_by_ref)] // Keeps interface consistent with other functions. +#[inline] +pub fn transpose_8x8(x: &[u8; 8]) -> [u8; 8] { + let mut x = u64::from_le_bytes(*x); + + x = x & 0xaa55_aa55_aa55_aa55 + | (x & 0x00aa_00aa_00aa_00aa) << 7 + | (x >> 7) & 0x00aa_00aa_00aa_00aa; + + x = x & 0xcccc_3333_cccc_3333 + | (x & 0x0000_cccc_0000_cccc) << 14 + | (x >> 14) & 0x0000_cccc_0000_cccc; + + x = x & 0xf0f0_f0f0_0f0f_0f0f + | (x & 0x0000_0000_f0f0_f0f0) << 28 + | (x >> 28) & 0x0000_0000_f0f0_f0f0; + + x.to_le_bytes() +} + +/// 16x16 bit matrix transpose. +// +// Loosely based on Hacker's Delight (2nd edition), Figure 7-6. +#[inline] +pub fn transpose_16x16(src: &[u8; 32]) -> [u8; 32] { + let x: [u64; 4] = + array::from_fn(|i| u64::from_le_bytes(src[8 * i..8 * (i + 1)].try_into().unwrap())); + + let mut y0 = [0u64; 4]; + let s0 = 15; + let mut y1 = [0u64; 4]; + let s1 = 30; + for i in 0..4 { + y0[i] = x[i] & 0xaaaa_5555_aaaa_5555 + | (x[i] & 0x0000_aaaa_0000_aaaa) << s0 + | (x[i] >> s0) & 0x0000_aaaa_0000_aaaa; + + y1[i] = y0[i] & 0xcccc_cccc_3333_3333 + | (y0[i] & 0x0000_0000_cccc_cccc) << s1 + | (y0[i] >> s1) & 0x0000_0000_cccc_cccc; + } + + let y1_swp = [y1[1], y1[0], y1[3], y1[2]]; + let m2a = [ + 0x0f0f_0f0f_0f0f_0f0f, + 0xf0f0_f0f0_f0f0_f0f0, + 0x0f0f_0f0f_0f0f_0f0f, + 0xf0f0_f0f0_f0f0_f0f0, + ]; + let m2b = [0xf0f0_f0f0_f0f0_f0f0, 0, 0xf0f0_f0f0_f0f0_f0f0, 0]; + let m2c = [0, 0xf0f0_f0f0_f0f0_f0f0, 0, 0xf0f0_f0f0_f0f0_f0f0]; + let s2 = 4; + let mut y2 = [0u64; 4]; + for i in 0..4 { + y2[i] = y1[i] & m2a[i] | (y1_swp[i] << s2) & m2b[i] | (y1_swp[i] & m2c[i]) >> s2; + } + + let mut y3 = [0u64; 4]; + for i in 0..2 { + y3[i] = y2[i] & 0x00ff_00ff_00ff_00ff | (y2[i + 2] & 0x00ff_00ff_00ff_00ff) << 8; + } + for i in 0..2 { + y3[i + 2] = (y2[i] & 0xff00_ff00_ff00_ff00) >> 8 | y2[i + 2] & 0xff00_ff00_ff00_ff00; + } + + let mut dst = [0u8; 32]; + for i in 0..4 { + *<&mut [u8; 8] as TryFrom<&mut [u8]>>::try_from(&mut dst[8 * i..8 * (i + 1)]).unwrap() = + y3[i].to_le_bytes(); + } + dst +} + +// Degenerate transposes. + +impl TransposeFrom> for Vec> { + fn transpose_from(&mut self, src: AdditiveShare) { + *self = vec![src]; + } +} + +impl TransposeFrom>> for Vec { + fn transpose_from(&mut self, src: Vec>) { + *self = vec![src.iter().map(Boolean::from_array).collect::()]; + } +} + +/// Perform a larger transpose using an 16x16 kernel. +/// +/// Matrix height and width must be multiples of 16. +#[inline] +fn do_transpose_16 [u8; 32], DF: FnMut(usize, usize, [u8; 32])>( + rows_div16: usize, + cols_div16: usize, + read_src: SF, + mut write_dst: DF, +) { + for i in 0..rows_div16 { + for j in 0..cols_div16 { + let m = read_src(i, j); + let m_t = transpose_16x16(&m); + write_dst(j, i, m_t); + } + } +} + +/// Implement a transpose of a MxN bit matrix represented as `[BA{N}; {M}]` into a NxM bit matrix +/// represented as `[BA{M}; {N}]`. +/// +/// The invocation looks like `impl_transpose_ba_to_ba!(BA, BA, , )`. e.g. for MxN = 16x64, +/// `impl_transpose_ba_to_ba!(BA16, BA64, 16, 64)`. Or to put it differently, write the MxN dimensions +/// twice, first with BA in front, and then without. +macro_rules! impl_transpose_ba_to_ba { + ($dst_row:ty, $src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + impl TransposeFrom<&[$src_row; $src_rows]> for [$dst_row; $src_cols] { + fn transpose_from(&mut self, src: &[$src_row; $src_rows]) { + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src[16 * i + k].as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].as_raw_mut_slice()[2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + } + } + + #[cfg(all(test, unit_test))] + #[test] + fn $test_fn() { + tests::test_transpose_ba_to_ba::<$dst_row, $src_row, $src_rows, $src_cols>(); + } + + impl TransposeFrom<&BitDecomposed<$src_row>> for Vec<$dst_row> { + fn transpose_from(&mut self, src: &BitDecomposed<$src_row>) { + self.resize($src_cols, <$dst_row>::ZERO); + let src = <&[$src_row; $src_rows]>::try_from(&**src).unwrap(); + let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); + dst.transpose_from(src); + } + } + + impl TransposeFrom<&[$src_row; $src_rows]> for Vec<$dst_row> { + fn transpose_from(&mut self, src: &[$src_row; $src_rows]) { + self.resize($src_cols, <$dst_row>::ZERO); + let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); + dst.transpose_from(src); + } + } + }; +} + +impl_transpose_ba_to_ba!(BA16, BA64, 16, 64, test_transpose_ba_16x64); +impl_transpose_ba_to_ba!(BA64, BA64, 64, 64, test_transpose_ba_64x64); +impl_transpose_ba_to_ba!(BA256, BA64, 256, 64, test_transpose_ba_256x64); +impl_transpose_ba_to_ba!(BA256, BA256, 256, 256, test_transpose_ba_256x256); + +/// Implement a transpose of a MxN matrix of secret-shared bits represented as +/// `[AdditiveShare; ]` into a NxM bit matrix represented as `[AdditiveShare>; N]`. +/// +/// For MxN = 256x64, the invocation looks like `impl_transpose_bool_to_ba!(BA256, 256, 64)`. +macro_rules! impl_transpose_shares_bool_to_ba { + ($dst_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + impl TransposeFrom<&[AdditiveShare; $src_rows]> + for [AdditiveShare<$dst_row>; $src_cols] + { + fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + // Transpose left share + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src[16 * i + k].left_arr().as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].left_arr_mut().0[0].as_raw_mut_slice() + [2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + // Transpose right share + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src[16 * i + k].right_arr().as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].right_arr_mut().0[0].as_raw_mut_slice() + [2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + } + } + + #[cfg(all(test, unit_test))] + #[test] + fn $test_fn() { + tests::test_transpose_shares_bool_to_ba::<$dst_row, $src_rows, $src_cols>(); + } + + impl TransposeFrom<&BitDecomposed>> + for Vec> + { + fn transpose_from(&mut self, src: &BitDecomposed>) { + self.resize($src_cols, AdditiveShare::<$dst_row>::ZERO); + let src = + <&[AdditiveShare; $src_rows]>::try_from(&**src).unwrap(); + let dst = + <&mut [AdditiveShare<$dst_row>; $src_cols]>::try_from(&mut **self).unwrap(); + dst.transpose_from(src); + } + } + }; +} + +impl_transpose_shares_bool_to_ba!(BA256, 256, 16, test_transpose_shares_bool_to_ba_256x16); +impl_transpose_shares_bool_to_ba!(BA256, 256, 64, test_transpose_shares_bool_to_ba_256x64); +impl_transpose_shares_bool_to_ba!(BA256, 256, 256, test_transpose_shares_bool_to_ba_256x256); + +/// Implement a transpose of a MxN matrix of secret-shared bits accessed via +/// `Fn(usize) -> AdditiveShare` into a NxM bit matrix represented as `[AdditiveShare; N]`. +/// +/// For MxN = 256x64, the invocation looks like `impl_transpose_shares_ba_fn_to_bool!(BA64, 256, 64)`. +macro_rules! impl_transpose_shares_ba_fn_to_bool { + ($src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + // This function-based access to the source is useful when the source is not contiguous in + // memory (i.e. accessing the match key for each input record). However, it does not + // optimize as well as the other implementations (even without the dynamic dispatch). + impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> + for [AdditiveShare; $src_cols] + { + fn transpose_from(&mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>) { + // Transpose left share + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src(16 * i + k).left().as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].left_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + // Transpose right share + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src(16 * i + k).right().as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].right_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + } + } + + impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> + for BitDecomposed> + { + fn transpose_from(&mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>) { + self.resize($src_cols, AdditiveShare::::ZERO); + let dst = + <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) + .unwrap(); + dst.transpose_from(src); + } + } + }; +} + +impl_transpose_shares_ba_fn_to_bool!(BA64, 16, 64, test_transpose_shares_ba_fn_to_bool_16x64); +impl_transpose_shares_ba_fn_to_bool!(BA64, 256, 64, test_transpose_shares_ba_fn_to_bool_256x64); + +/// Implement a transpose of a MxN matrix of secret-shared bits represented as +/// `[AdditiveShare; ]` into a NxM bit matrix represented as `[AdditiveShare>; N]`. +/// +/// For MxN = 256x64, the invocation looks like `impl_transpose_bool_to_bool!(BA64, 256, 64)`. +macro_rules! impl_transpose_shares_bool_to_bool { + ($src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + impl TransposeFrom<&[AdditiveShare; $src_rows]> + for [AdditiveShare; $src_cols] + { + fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + // Transpose left share + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src[16 * i + k].left_arr().as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].left_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + // Transpose right share + do_transpose_16( + $src_rows / 16, + $src_cols / 16, + |i, j| { + let mut d = [0u8; 32]; + for k in 0..16 { + d[2 * k..2 * (k + 1)].copy_from_slice( + &src[16 * i + k].right_arr().as_raw_slice()[2 * j..2 * (j + 1)], + ); + } + d + }, + |i, j, d| { + for k in 0..16 { + self[16 * i + k].right_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] + .copy_from_slice(&d[2 * k..2 * (k + 1)]); + } + }, + ); + } + } + + #[cfg(all(test, unit_test))] + #[test] + fn $test_fn() { + tests::test_transpose_shares_bool_to_bool::<$src_rows, $src_cols>(); + } + + impl TransposeFrom<&[AdditiveShare]> + for BitDecomposed> + { + fn transpose_from(&mut self, src: &[AdditiveShare]) { + let src = <&[AdditiveShare; $src_rows]>::try_from(src).unwrap(); + self.transpose_from(src); + } + } + + impl TransposeFrom<&[AdditiveShare; $src_rows]> + for BitDecomposed> + { + fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + self.resize($src_cols, AdditiveShare::::ZERO); + let dst = + <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) + .unwrap(); + dst.transpose_from(src); + } + } + }; +} + +impl_transpose_shares_bool_to_bool!(BA64, 16, 64, test_transpose_shares_bool_to_bool_16x64); +impl_transpose_shares_bool_to_bool!(BA64, 64, 64, test_transpose_shares_bool_to_bool_64x64); +impl_transpose_shares_bool_to_bool!(BA64, 256, 64, test_transpose_shares_bool_to_bool_256x64); + +#[cfg(all(test, unit_test))] +mod tests { + // Using `.enumerate()` would just obfuscate the nested for loops verifying transposes. + #![allow(clippy::needless_range_loop)] + + use std::{ + array, + cmp::min, + fmt::Debug, + iter::repeat_with, + ops::{BitAnd, Not, Shl, Shr}, + }; + + use rand::{ + distributions::{Distribution, Standard}, + thread_rng, Rng, + }; + + use super::*; + use crate::{ff::ArrayAccess, secret_sharing::Vectorizable}; + + fn random_array() -> [T; N] + where + Standard: Distribution, + { + let mut rng = thread_rng(); + array::from_fn(|_| rng.gen()) + } + + trait ByteConversion { + type Bytes; + fn into_bytes(self) -> Self::Bytes; + fn from_bytes(bytes: Self::Bytes) -> Self; + } + + impl ByteConversion for [u8; 8] { + type Bytes = Self; + + fn into_bytes(self) -> Self::Bytes { + self + } + + fn from_bytes(bytes: Self::Bytes) -> Self { + bytes + } + } + + macro_rules! impl_byte_conversion { + ([$word:ty; $n_words:expr], [u8; $n_bytes:expr]) => { + impl ByteConversion for [$word; $n_words] { + type Bytes = [u8; $n_bytes]; + + fn into_bytes(self) -> Self::Bytes { + self.into_iter() + .flat_map(<$word>::to_le_bytes) + .collect::>() + .try_into() + .unwrap() + } + + fn from_bytes(bytes: Self::Bytes) -> Self { + const BYTES: usize = $n_bytes / $n_words; + bytes + .chunks_exact(BYTES) + .map(|slice| { + <$word>::from_le_bytes(<[u8; BYTES]>::try_from(slice).unwrap()) + }) + .collect::>() + .try_into() + .unwrap() + } + } + }; + } + + impl_byte_conversion!([u16; 16], [u8; 32]); + impl_byte_conversion!([u32; 32], [u8; 128]); + impl_byte_conversion!([u64; 64], [u8; 512]); + + fn test_transpose_array< + T, // Matrix integer type (e.g. u16 for 16x16) + const N: usize, // Matrix dimension + const NB: usize, // Matrix byte array size + >( + t_impl: fn(&[u8; NB]) -> [u8; NB], + ) where + T: Copy + + Debug + + Default + + PartialEq + + Not + + Shl + + Shr + + BitAnd, + [T; N]: Copy + Debug + PartialEq<[T; N]> + ByteConversion, + { + let zero = T::default(); + let one = !zero >> (N - 1); + + // Identity + let m: [u8; NB] = <[T; N]>::into_bytes(array::from_fn(|i| one << i)); + let m_t = t_impl(&m); + assert_eq!(m_t, m); + + // Anti-diagonal + let m: [u8; NB] = <[T; N]>::into_bytes(array::from_fn(|i| one << (N - 1 - i))); + let m_t = t_impl(&m); + assert_eq!(m_t, m); + + // Lower triangular + let m: [u8; NB] = <[T; N]>::into_bytes(array::from_fn(|i| !zero >> (N - 1 - i))); + let m_t = t_impl(&m); + assert_eq!(<[T; N]>::from_bytes(m_t), array::from_fn(|i| !zero << i)); + + // Random + let m: [u8; NB] = random_array(); + let m_t = t_impl(&m); + let m = <[T; N]>::from_bytes(m); + let m_t = <[T; N]>::from_bytes(m_t); + + for i in 0..N { + for j in 0..N { + assert_eq!((m_t[i] >> j) & one, (m[j] >> i) & one); + } + } + } + + #[test] + fn transpose_8x8() { + test_transpose_array::(super::transpose_8x8); + } + + #[test] + fn transpose_16x16() { + test_transpose_array::(super::transpose_16x16); + } + + // The order of type parameters matches the implementation macro: BA, BA, , + pub(super) fn test_transpose_ba_to_ba< + DR, // Destination row type + SR, // Source row type + const SM: usize, // Source rows (== dest cols) + const DM: usize, // Destination rows (== source cols) + >() + where + SR: PartialEq + SharedValue + ArrayAccess, + DR: PartialEq + SharedValue + ArrayAccess, + [DR; DM]: for<'a> TransposeFrom<&'a [SR; SM]>, + Standard: Distribution, + { + let t_impl = |src| { + let mut dst = [DR::ZERO; DM]; + dst.transpose_from(src); + dst + }; + + let step = min(SM, DM); + let m = array::from_fn(|i| { + let mut v = SR::ZERO; + for j in ((i % DM)..DM).step_by(step) { + v.set(j, Boolean::TRUE); + } + v + }); + let m_t = t_impl(&m); + assert_eq!( + m_t, + array::from_fn(|i| { + let mut v = DR::ZERO; + for j in ((i % SM)..SM).step_by(step) { + v.set(j, Boolean::TRUE); + } + v + }) + ); + + let mut rng = thread_rng(); + let m = repeat_with(|| rng.gen()).take(SM).collect::>(); + let m_t = t_impl(<&[SR; SM]>::try_from(m.as_slice()).unwrap()); + + for i in 0..DM { + for j in 0..SM { + assert_eq!(m_t[i].get(j), m[j].get(i)); + } + } + } + + // The order of type parameters matches the implementation macro: BA, , + pub(super) fn test_transpose_shares_bool_to_ba< + DR, // Destination row type + const SM: usize, // Source rows (== dest cols) + const DM: usize, // Destination rows (== source cols) + >() + where + Boolean: Vectorizable, + >::Array: ArrayAccess, + DR: SharedValue + ArrayAccess, + [AdditiveShare; DM]: for<'a> TransposeFrom<&'a [AdditiveShare; SM]>, + { + let t_impl = |src| { + let mut dst = [AdditiveShare::::ZERO; DM]; + dst.transpose_from(src); + dst + }; + + let step = min(SM, DM); + let m = array::from_fn(|i| { + let mut left = vec![Boolean::FALSE; DM]; + let mut right = vec![Boolean::FALSE; DM]; + for j in ((i % DM)..DM).step_by(step) { + let b = Boolean::from(j % 2 != 0); + left[j] = b; + right[j] = !b; + } + AdditiveShare::new_arr( + >::Array::from_iter(left), + >::Array::from_iter(right), + ) + }); + let m_t = t_impl(&m); + assert_eq!( + m_t, + array::from_fn(|i| { + let mut v = AdditiveShare::::ZERO; + for j in ((i % SM)..SM).step_by(step) { + let b = Boolean::from(j % 2 != 0); + v.set(j, AdditiveShare::new(b, !b)); + } + v + }) + ); + + let mut left_rng = thread_rng(); + let mut right_rng = thread_rng(); + let m = repeat_with(|| AdditiveShare::from_fns(|_| left_rng.gen(), |_| right_rng.gen())) + .take(SM) + .collect::>(); + let m_t = t_impl(<&[AdditiveShare; SM]>::try_from(m.as_slice()).unwrap()); + + for i in 0..DM { + for j in 0..SM { + assert_eq!( + m_t[i].get(j).unwrap().left(), + m[j].left_arr().get(i).unwrap() + ); + assert_eq!( + m_t[i].get(j).unwrap().right(), + m[j].right_arr().get(i).unwrap() + ); + } + } + } + + pub(super) fn test_transpose_shares_bool_to_bool< + const SM: usize, // Source rows (== dest cols) + const DM: usize, // Destination rows (== source cols) + >() + where + Boolean: Vectorizable, + >::Array: ArrayAccess, + Boolean: Vectorizable, + >::Array: ArrayAccess, + [AdditiveShare; DM]: + for<'a> TransposeFrom<&'a [AdditiveShare; SM]>, + { + let t_impl = |src| { + let mut dst = [AdditiveShare::::ZERO; DM]; + dst.transpose_from(src); + dst + }; + + let step = min(SM, DM); + let m = array::from_fn(|i| { + let mut left = vec![Boolean::FALSE; DM]; + let mut right = vec![Boolean::FALSE; DM]; + for j in ((i % DM)..DM).step_by(step) { + let b = Boolean::from(j % 2 != 0); + left[j] = b; + right[j] = !b; + } + AdditiveShare::new_arr( + >::Array::from_iter(left), + >::Array::from_iter(right), + ) + }); + let m_t = t_impl(&m); + assert_eq!( + m_t, + array::from_fn(|i| { + let mut left = vec![Boolean::FALSE; SM]; + let mut right = vec![Boolean::FALSE; SM]; + for j in ((i % SM)..SM).step_by(step) { + let b = Boolean::from(j % 2 != 0); + left[j] = b; + right[j] = !b; + } + AdditiveShare::new_arr( + >::Array::from_iter(left), + >::Array::from_iter(right), + ) + }) + ); + + let mut left_rng = thread_rng(); + let mut right_rng = thread_rng(); + let m = repeat_with(|| AdditiveShare::from_fns(|_| left_rng.gen(), |_| right_rng.gen())) + .take(SM) + .collect::>(); + let m_t = t_impl(<&[AdditiveShare; SM]>::try_from(m.as_slice()).unwrap()); + + for i in 0..DM { + for j in 0..SM { + assert_eq!( + m_t[i].left_arr().get(j).unwrap(), + m[j].left_arr().get(i).unwrap() + ); + assert_eq!( + m_t[i].right_arr().get(j).unwrap(), + m[j].right_arr().get(i).unwrap() + ); + } + } + } +} From 46f1a3a85ac5c15e09868c78171c21858b8b77ac Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 18 Mar 2024 17:04:49 -0700 Subject: [PATCH 2/3] PR feedback --- ipa-core/benches/transpose.rs | 5 +- .../src/secret_sharing/vector/transpose.rs | 151 +++++++++++++----- 2 files changed, 116 insertions(+), 40 deletions(-) diff --git a/ipa-core/benches/transpose.rs b/ipa-core/benches/transpose.rs index 9350e9ccc..e0b872e93 100644 --- a/ipa-core/benches/transpose.rs +++ b/ipa-core/benches/transpose.rs @@ -16,6 +16,7 @@ use std::{array, iter::repeat_with, time::Duration}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; use ipa_core::{ + error::UnwrapInfallible, ff::boolean_array::BA64, secret_sharing::{ vector::{transpose_16x16, transpose_8x8}, @@ -75,7 +76,7 @@ fn bench_8x8(c: &mut Criterion) { cols: 8, iters: 100, }, - transpose_8x8, + |m| transpose_8x8(m), ); } @@ -101,7 +102,7 @@ fn bench_64x64(c: &mut Criterion) { }, |src| { let mut dst = array::from_fn(|_| BA64::ZERO); - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }, ); diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 2ee1f9757..8b4145671 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -36,9 +36,12 @@ // This rule throws false positives on "MxN". #![allow(clippy::doc_markdown)] -use std::array; +#[cfg(any(all(test, unit_test), feature = "enable-benches"))] +use std::borrow::Borrow; +use std::{array, convert::Infallible}; use crate::{ + error::{LengthError, UnwrapInfallible}, ff::{ boolean::Boolean, boolean_array::{BA16, BA256, BA64}, @@ -51,15 +54,25 @@ use crate::{ /// Trait for overwriting a value with the transpose of a source value. pub trait TransposeFrom { - fn transpose_from(&mut self, src: T); - - fn transposed_from(src: T) -> Self + type Error; + + /// Overwrite `self` with the transpose of `src`. + /// + /// # Errors + /// If the size of the source and destination are not compatible. + fn transpose_from(&mut self, src: T) -> Result<(), Self::Error>; + + /// Fills a new `Self` with the transpose of `src`. + /// + /// # Errors + /// If the size of the source and destination are not compatible. + fn transposed_from(src: T) -> Result where Self: Default, { let mut dst = Self::default(); - dst.transpose_from(src); - dst + dst.transpose_from(src)?; + Ok(dst) } } @@ -67,10 +80,9 @@ pub trait TransposeFrom { // // From Hacker's Delight (2nd edition), Figure 7-6. #[cfg(any(all(test, unit_test), feature = "enable-benches"))] -#[allow(clippy::trivially_copy_pass_by_ref)] // Keeps interface consistent with other functions. #[inline] -pub fn transpose_8x8(x: &[u8; 8]) -> [u8; 8] { - let mut x = u64::from_le_bytes(*x); +pub fn transpose_8x8>(x: B) -> [u8; 8] { + let mut x = u64::from_le_bytes(*x.borrow()); x = x & 0xaa55_aa55_aa55_aa55 | (x & 0x00aa_00aa_00aa_00aa) << 7 @@ -143,14 +155,20 @@ pub fn transpose_16x16(src: &[u8; 32]) -> [u8; 32] { // Degenerate transposes. impl TransposeFrom> for Vec> { - fn transpose_from(&mut self, src: AdditiveShare) { + type Error = Infallible; + + fn transpose_from(&mut self, src: AdditiveShare) -> Result<(), Infallible> { *self = vec![src]; + Ok(()) } } impl TransposeFrom>> for Vec { - fn transpose_from(&mut self, src: Vec>) { + type Error = Infallible; + + fn transpose_from(&mut self, src: Vec>) -> Result<(), Infallible> { *self = vec![src.iter().map(Boolean::from_array).collect::()]; + Ok(()) } } @@ -182,7 +200,9 @@ fn do_transpose_16 [u8; 32], DF: FnMut(usize, usize, [u8 macro_rules! impl_transpose_ba_to_ba { ($dst_row:ty, $src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { impl TransposeFrom<&[$src_row; $src_rows]> for [$dst_row; $src_cols] { - fn transpose_from(&mut self, src: &[$src_row; $src_rows]) { + type Error = Infallible; + + fn transpose_from(&mut self, src: &[$src_row; $src_rows]) -> Result<(), Infallible> { do_transpose_16( $src_rows / 16, $src_cols / 16, @@ -202,6 +222,7 @@ macro_rules! impl_transpose_ba_to_ba { } }, ); + Ok(()) } } @@ -212,19 +233,27 @@ macro_rules! impl_transpose_ba_to_ba { } impl TransposeFrom<&BitDecomposed<$src_row>> for Vec<$dst_row> { - fn transpose_from(&mut self, src: &BitDecomposed<$src_row>) { + type Error = LengthError; + + fn transpose_from(&mut self, src: &BitDecomposed<$src_row>) -> Result<(), LengthError> { self.resize($src_cols, <$dst_row>::ZERO); - let src = <&[$src_row; $src_rows]>::try_from(&**src).unwrap(); + let src = <&[$src_row; $src_rows]>::try_from(&**src).map_err(|_| LengthError { + expected: $src_rows, + actual: src.len(), + })?; let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); + Ok(()) } } impl TransposeFrom<&[$src_row; $src_rows]> for Vec<$dst_row> { - fn transpose_from(&mut self, src: &[$src_row; $src_rows]) { + type Error = Infallible; + + fn transpose_from(&mut self, src: &[$src_row; $src_rows]) -> Result<(), Infallible> { self.resize($src_cols, <$dst_row>::ZERO); let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src); + dst.transpose_from(src) } } }; @@ -244,7 +273,12 @@ macro_rules! impl_transpose_shares_bool_to_ba { impl TransposeFrom<&[AdditiveShare; $src_rows]> for [AdditiveShare<$dst_row>; $src_cols] { - fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { // Transpose left share do_transpose_16( $src_rows / 16, @@ -287,6 +321,7 @@ macro_rules! impl_transpose_shares_bool_to_ba { } }, ); + Ok(()) } } @@ -299,13 +334,22 @@ macro_rules! impl_transpose_shares_bool_to_ba { impl TransposeFrom<&BitDecomposed>> for Vec> { - fn transpose_from(&mut self, src: &BitDecomposed>) { + type Error = LengthError; + + fn transpose_from( + &mut self, + src: &BitDecomposed>, + ) -> Result<(), LengthError> { self.resize($src_cols, AdditiveShare::<$dst_row>::ZERO); - let src = - <&[AdditiveShare; $src_rows]>::try_from(&**src).unwrap(); + let src = <&[AdditiveShare; $src_rows]>::try_from(&**src) + .map_err(|_| LengthError { + expected: $src_rows, + actual: src.len(), + })?; let dst = <&mut [AdditiveShare<$dst_row>; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); + Ok(()) } } }; @@ -327,7 +371,12 @@ macro_rules! impl_transpose_shares_ba_fn_to_bool { impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> for [AdditiveShare; $src_cols] { - fn transpose_from(&mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &dyn Fn(usize) -> AdditiveShare<$src_row>, + ) -> Result<(), Infallible> { // Transpose left share do_transpose_16( $src_rows / 16, @@ -368,18 +417,24 @@ macro_rules! impl_transpose_shares_ba_fn_to_bool { } }, ); + Ok(()) } } impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> for BitDecomposed> { - fn transpose_from(&mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &dyn Fn(usize) -> AdditiveShare<$src_row>, + ) -> Result<(), Infallible> { self.resize($src_cols, AdditiveShare::::ZERO); let dst = <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) .unwrap(); - dst.transpose_from(src); + dst.transpose_from(src) } } }; @@ -397,7 +452,12 @@ macro_rules! impl_transpose_shares_bool_to_bool { impl TransposeFrom<&[AdditiveShare; $src_rows]> for [AdditiveShare; $src_cols] { - fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + type Error = Infallible; + + fn transpose_from( + &mut self, + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { // Transpose left share do_transpose_16( $src_rows / 16, @@ -438,6 +498,7 @@ macro_rules! impl_transpose_shares_bool_to_bool { } }, ); + Ok(()) } } @@ -450,21 +511,34 @@ macro_rules! impl_transpose_shares_bool_to_bool { impl TransposeFrom<&[AdditiveShare]> for BitDecomposed> { - fn transpose_from(&mut self, src: &[AdditiveShare]) { - let src = <&[AdditiveShare; $src_rows]>::try_from(src).unwrap(); - self.transpose_from(src); + type Error = LengthError; + fn transpose_from( + &mut self, + src: &[AdditiveShare], + ) -> Result<(), LengthError> { + let src = <&[AdditiveShare; $src_rows]>::try_from(src) + .map_err(|_| LengthError { + expected: $src_rows, + actual: src.len(), + })?; + self.transpose_from(src).unwrap_infallible(); + Ok(()) } } impl TransposeFrom<&[AdditiveShare; $src_rows]> for BitDecomposed> { - fn transpose_from(&mut self, src: &[AdditiveShare; $src_rows]) { + type Error = Infallible; + fn transpose_from( + &mut self, + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { self.resize($src_cols, AdditiveShare::::ZERO); let dst = <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) .unwrap(); - dst.transpose_from(src); + dst.transpose_from(src) } } }; @@ -603,7 +677,7 @@ mod tests { #[test] fn transpose_8x8() { - test_transpose_array::(super::transpose_8x8); + test_transpose_array::(|m| super::transpose_8x8(m)); } #[test] @@ -621,12 +695,12 @@ mod tests { where SR: PartialEq + SharedValue + ArrayAccess, DR: PartialEq + SharedValue + ArrayAccess, - [DR; DM]: for<'a> TransposeFrom<&'a [SR; SM]>, + [DR; DM]: for<'a> TransposeFrom<&'a [SR; SM], Error = Infallible>, Standard: Distribution, { let t_impl = |src| { let mut dst = [DR::ZERO; DM]; - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }; @@ -671,11 +745,12 @@ mod tests { Boolean: Vectorizable, >::Array: ArrayAccess, DR: SharedValue + ArrayAccess, - [AdditiveShare; DM]: for<'a> TransposeFrom<&'a [AdditiveShare; SM]>, + [AdditiveShare; DM]: + for<'a> TransposeFrom<&'a [AdditiveShare; SM], Error = Infallible>, { let t_impl = |src| { let mut dst = [AdditiveShare::::ZERO; DM]; - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }; @@ -737,11 +812,11 @@ mod tests { Boolean: Vectorizable, >::Array: ArrayAccess, [AdditiveShare; DM]: - for<'a> TransposeFrom<&'a [AdditiveShare; SM]>, + for<'a> TransposeFrom<&'a [AdditiveShare; SM], Error = Infallible>, { let t_impl = |src| { let mut dst = [AdditiveShare::::ZERO; DM]; - dst.transpose_from(src); + dst.transpose_from(src).unwrap_infallible(); dst }; From a4a2aa205ac6baec70c83cd5559602361338b78f Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 18 Mar 2024 17:49:24 -0700 Subject: [PATCH 3/3] Fix redundant import warning (that may only be detected by nightly?) --- ipa-core/src/secret_sharing/vector/transpose.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 8b4145671..020cdb9a6 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -554,7 +554,6 @@ mod tests { #![allow(clippy::needless_range_loop)] use std::{ - array, cmp::min, fmt::Debug, iter::repeat_with,