From 978f02246f0ac37b97cdbcd9e79fad22e00b52b9 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 11 Apr 2024 10:40:19 -0400 Subject: [PATCH 1/2] Transposes for vectorized aggregation Add additional macros in transpose implementation, to reduce repetitive code --- ipa-core/src/secret_sharing/decomposed.rs | 4 + .../src/secret_sharing/vector/transpose.rs | 670 ++++++++++++------ 2 files changed, 444 insertions(+), 230 deletions(-) diff --git a/ipa-core/src/secret_sharing/decomposed.rs b/ipa-core/src/secret_sharing/decomposed.rs index e96d8e322..4830715a4 100644 --- a/ipa-core/src/secret_sharing/decomposed.rs +++ b/ipa-core/src/secret_sharing/decomposed.rs @@ -118,6 +118,10 @@ impl BitDecomposed { pub fn resize(&mut self, new_len: usize, value: S) { self.bits.resize(new_len, value); } + + pub fn truncate(&mut self, len: usize) { + self.bits.truncate(len); + } } impl FromPrss for BitDecomposed> diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index d17a21b69..59536cf4f 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -1,4 +1,4 @@ -//! Bit matrix transpose operations. +//! # Bit matrix transpose operations //! //! These are used to convert data to and from vectorized representation. //! @@ -36,16 +36,14 @@ // This rule throws false positives on "MxN". #![allow(clippy::doc_markdown)] -#[cfg(any(all(test, unit_test), feature = "enable-benches"))] -use std::borrow::Borrow; -use std::{array, convert::Infallible}; +use std::{array, borrow::Borrow, convert::Infallible, ops::Deref}; use crate::{ const_assert_eq, error::{LengthError, UnwrapInfallible}, ff::{ boolean::Boolean, - boolean_array::{BA256, BA64}, + boolean_array::{BA16, BA256, BA3, BA5, BA64, BA8}, ec_prime_field::Fp25519, }, protocol::ipa_prf::{MK_BITS, PRF_CHUNK}, @@ -102,7 +100,6 @@ pub trait TransposeFrom { /// 8x8 bit matrix transpose. // // From Hacker's Delight (2nd edition), Figure 7-6. -#[cfg(any(all(test, unit_test), feature = "enable-benches"))] #[inline] pub fn transpose_8x8>(x: B) -> [u8; 8] { let mut x = u64::from_le_bytes(*x.borrow()); @@ -209,25 +206,265 @@ impl<'a> TransposeFrom<&'a [StdArray; 256]> for Vec { // Matrix transpose helpers -/// Perform a larger transpose using an 16x16 kernel. +macro_rules! read_ba_left_8_pad { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident, $pad_value:expr) => { + $m[$k] = $src + .get(8 * $i + $k) + .unwrap_or($pad_value) + .left() + .as_raw_slice()[$j] + }; +} + +macro_rules! read_ba_right_8_pad { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident, $pad_value:expr) => { + $m[$k] = $src + .get(8 * $i + $k) + .unwrap_or($pad_value) + .right() + .as_raw_slice()[$j] + }; +} + +macro_rules! write_ba_left_8 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[8 * $i + $k].left_arr_mut().0[0].as_raw_mut_slice()[$j] = $m[$k] + }; +} + +macro_rules! write_ba_right_8 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[8 * $i + $k].right_arr_mut().0[0].as_raw_mut_slice()[$j] = $m[$k] + }; +} + +macro_rules! read_bool_left_8 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[$k] = $src[8 * $i + $k].left_arr().as_raw_slice()[$j] + }; +} + +macro_rules! read_bool_right_8 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[$k] = $src[8 * $i + $k].right_arr().as_raw_slice()[$j] + }; +} + +macro_rules! write_bool_left_8 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[8 * $i + $k].left_arr_mut().as_raw_mut_slice()[$j] = $m[$k] + }; +} + +macro_rules! write_bool_right_8 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[8 * $i + $k].right_arr_mut().as_raw_mut_slice()[$j] = $m[$k] + }; +} + +macro_rules! read_ba_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src[16 * $i + $k].as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! read_ba_left_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src[16 * $i + $k].left().as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! read_ba_right_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src[16 * $i + $k].right().as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! read_ba_fn_left_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src(16 * $i + $k).left().as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! read_ba_fn_right_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src(16 * $i + $k).right().as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! write_ba_16 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[16 * $i + $k].as_raw_mut_slice()[2 * $j..2 * ($j + 1)] + .copy_from_slice(&$m[2 * $k..2 * ($k + 1)]); + }; +} + +macro_rules! write_ba_left_16 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[16 * $i + $k].left_arr_mut().0[0].as_raw_mut_slice()[2 * $j..2 * ($j + 1)] + .copy_from_slice(&$m[2 * $k..2 * ($k + 1)]); + }; +} + +macro_rules! write_ba_right_16 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[16 * $i + $k].right_arr_mut().0[0].as_raw_mut_slice()[2 * $j..2 * ($j + 1)] + .copy_from_slice(&$m[2 * $k..2 * ($k + 1)]); + }; +} + +macro_rules! read_bool_left_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src[16 * $i + $k].left_arr().as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! read_bool_right_16 { + ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident) => { + $m[2 * $k..2 * ($k + 1)] + .copy_from_slice(&$src[16 * $i + $k].right_arr().as_raw_slice()[2 * $j..2 * ($j + 1)]) + }; +} + +macro_rules! write_bool_left_16 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[16 * $i + $k].left_arr_mut().as_raw_mut_slice()[2 * $j..2 * ($j + 1)] + .copy_from_slice(&$m[2 * $k..2 * ($k + 1)]) + }; +} + +macro_rules! write_bool_right_16 { + ($dst:ident, $m:ident, $i:ident, $j:ident, $k:ident) => { + $dst[16 * $i + $k].right_arr_mut().as_raw_mut_slice()[2 * $j..2 * ($j + 1)] + .copy_from_slice(&$m[2 * $k..2 * ($k + 1)]) + }; +} + +/// Implement a larger transpose using the 8x8 kernel. +/// +/// Matrix height and width must be multiples of 8. +macro_rules! impl_transpose_8 { + ($dst:ident, $src:ident, $src_rows:expr, $src_cols:expr, $read:ident, $write:ident $(,)?) => { + debug_assert!( + $src_rows % 8 == 0 && $src_cols % 8 == 0, + "This implementation requires that both dimensions are multiples of 8", + ); + + for i in 0..$src_rows / 8 { + for j in 0..$src_cols / 8 { + let mut m = [0u8; 8]; + for k in 0..8 { + $read!(m, $src, i, j, k); + } + let m_t = transpose_8x8(&m); + for k in 0..8 { + $write!($dst, m_t, j, i, k); + } + } + } + }; +} + +/// Implement a larger transpose using the 8x8 kernel. +/// +/// Matrix height and width do not need to be multiples of 8, however, the row stride in memory must +/// still be a multiple of 8 (i.e. whole bytes). +macro_rules! impl_transpose_8_pad { + ($dst:ident, $src:ident, $src_rows:expr, $src_cols:expr, $read:ident, $pad_value:expr, $write:ident $(,)?) => { + for i in 0..($src_rows + 7) / 8 { + for j in 0..($src_cols + 7) / 8 { + let mut m = [0u8; 8]; + for k in 0..8 { + $read!(m, $src, i, j, k, $pad_value); + } + let m_t = transpose_8x8(&m); + for k in 0..8 { + $write!($dst, m_t, j, i, k); + } + } + } + }; +} + +/// Implement a larger transpose using the 16x16 kernel. /// /// Matrix height and width must be multiples of 16. -#[inline] -fn do_transpose_16 [u8; 32], DF: FnMut(usize, usize, [u8; 32])>( - rows_div16: usize, - cols_div16: usize, - read_src: SF, - mut write_dst: DF, -) { - for i in 0..rows_div16 { - for j in 0..cols_div16 { - let m = read_src(i, j); - let m_t = transpose_16x16(&m); - write_dst(j, i, m_t); +macro_rules! impl_transpose_16 { + ($dst:ident, $src:ident, $src_rows:expr, $src_cols:expr, $read:ident, $write:ident $(,)?) => { + debug_assert!( + $src_rows % 16 == 0 && $src_cols % 16 == 0, + "This implementation requires that both dimensions are multiples of 16", + ); + + for i in 0..$src_rows / 16 { + for j in 0..$src_cols / 16 { + let mut m = [0u8; 32]; + for k in 0..16 { + $read!(m, $src, i, j, k); + } + let m_t = transpose_16x16(&m); + for k in 0..16 { + $write!($dst, m_t, j, i, k); + } + } } - } + }; +} + +macro_rules! transpose_shim_convert_src { + ($converted:ty, $expected_len:expr, $src:ident, LengthError) => { + <$converted>::try_from($src.deref()).map_err(|_| LengthError { + expected: $expected_len, + actual: $src.len(), + })? + }; + ($converted:ty, $expected_len:expr, $src:ident, Infallible) => { + $src + }; } +macro_rules! impl_transpose_shim { + ($src_ty:ty, $src_row:ty, $dst_ty:ty, $dst_row:ty, $src_rows:expr, $src_cols:expr, $error:tt $(,)?) => { + impl TransposeFrom<$src_ty> for $dst_ty { + type Error = $error; + fn transpose_from(&mut self, src: $src_ty) -> Result<(), Self::Error> { + self.resize($src_cols, <$dst_row>::ZERO); + let src = + transpose_shim_convert_src!(&[$src_row; $src_rows], $src_rows, src, $error); + let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); + dst.transpose_from(src).unwrap_infallible(); + Ok(()) + } + } + }; +} + +macro_rules! impl_transpose_shim_pad { + ($src_ty:ty, $src_row:ty, $dst_ty:ty, $dst_row:ty, $src_rows:expr, $src_cols:expr, $error:tt $(,)?) => { + impl TransposeFrom<$src_ty> for $dst_ty { + type Error = $error; + fn transpose_from(&mut self, src: $src_ty) -> Result<(), Self::Error> { + self.resize(($src_cols + 7) / 8 * 8, <$dst_row>::ZERO); + let src = + transpose_shim_convert_src!(&[$src_row; $src_rows], $src_rows, src, $error); + let dst = + <&mut [$dst_row; ($src_cols + 7) / 8 * 8]>::try_from(&mut **self).unwrap(); + dst.transpose_from(src).unwrap_infallible(); + self.truncate($src_cols); + Ok(()) + } + } + }; +} + +// Matrix transposes + /// Implement a transpose of a MxN bit matrix represented as `[BA{N}; {M}]` into a NxM bit matrix /// represented as `[BA{M}; {N}]`. /// @@ -235,30 +472,12 @@ fn do_transpose_16 [u8; 32], DF: FnMut(usize, usize, [u8 /// `impl_transpose_ba_to_ba!(BA16, BA64, 16, 64)`. Or to put it differently, write the MxN dimensions /// twice, first with BA in front, and then without. macro_rules! impl_transpose_ba_to_ba { - ($dst_row:ty, $src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + ($dst_row:ty, $src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident $(,)?) => { impl TransposeFrom<&[$src_row; $src_rows]> for [$dst_row; $src_cols] { type Error = Infallible; fn transpose_from(&mut self, src: &[$src_row; $src_rows]) -> Result<(), Infallible> { - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src[16 * i + k].as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].as_raw_mut_slice()[2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_ba_16, write_ba_16); Ok(()) } } @@ -269,30 +488,25 @@ macro_rules! impl_transpose_ba_to_ba { tests::test_transpose_ba_to_ba::<$dst_row, $src_row, $src_rows, $src_cols>(); } - impl TransposeFrom<&BitDecomposed<$src_row>> for Vec<$dst_row> { - type Error = LengthError; - - fn transpose_from(&mut self, src: &BitDecomposed<$src_row>) -> Result<(), LengthError> { - self.resize($src_cols, <$dst_row>::ZERO); - let src = <&[$src_row; $src_rows]>::try_from(&**src).map_err(|_| LengthError { - expected: $src_rows, - actual: src.len(), - })?; - let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src).unwrap_infallible(); - Ok(()) - } - } - - impl TransposeFrom<&[$src_row; $src_rows]> for Vec<$dst_row> { - type Error = Infallible; + impl_transpose_shim!( + &BitDecomposed<$src_row>, + $src_row, + Vec<$dst_row>, + $dst_row, + $src_rows, + $src_cols, + LengthError, + ); - fn transpose_from(&mut self, src: &[$src_row; $src_rows]) -> Result<(), Infallible> { - self.resize($src_cols, <$dst_row>::ZERO); - let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src) - } - } + impl_transpose_shim!( + &[$src_row; $src_rows], + $src_row, + Vec<$dst_row>, + $dst_row, + $src_rows, + $src_cols, + Infallible, + ); }; } @@ -317,48 +531,8 @@ macro_rules! impl_transpose_shares_bool_to_ba { &mut self, src: &[AdditiveShare; $src_rows], ) -> Result<(), Infallible> { - // Transpose left share - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src[16 * i + k].left_arr().as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].left_arr_mut().0[0].as_raw_mut_slice() - [2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); - // Transpose right share - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src[16 * i + k].right_arr().as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].right_arr_mut().0[0].as_raw_mut_slice() - [2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_bool_left_16, write_ba_left_16); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_bool_right_16, write_ba_right_16); Ok(()) } } @@ -369,32 +543,54 @@ macro_rules! impl_transpose_shares_bool_to_ba { tests::test_transpose_shares_bool_to_ba::<$dst_row, $src_rows, $src_cols>(); } - impl TransposeFrom<&BitDecomposed>> - for Vec> + impl_transpose_shim!( + &BitDecomposed>, AdditiveShare, + Vec>, AdditiveShare<$dst_row>, + $src_rows, $src_cols, + LengthError, + ); + }; +} + +// Usage: Share conversion output (r/s). M = Fp25519::BITS, N = PRF_CHUNK. +impl_transpose_shares_bool_to_ba!(BA256, 256, 64, test_transpose_shares_bool_to_ba_256x64); + +macro_rules! impl_transpose_shares_bool_to_ba_small { + ($dst_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + impl TransposeFrom<&[AdditiveShare; $src_rows]> + for [AdditiveShare<$dst_row>; $src_cols] { - type Error = LengthError; + type Error = Infallible; fn transpose_from( &mut self, - src: &BitDecomposed>, - ) -> Result<(), LengthError> { - self.resize($src_cols, AdditiveShare::<$dst_row>::ZERO); - let src = <&[AdditiveShare; $src_rows]>::try_from(&**src) - .map_err(|_| LengthError { - expected: $src_rows, - actual: src.len(), - })?; - let dst = - <&mut [AdditiveShare<$dst_row>; $src_cols]>::try_from(&mut **self).unwrap(); - dst.transpose_from(src).unwrap_infallible(); + src: &[AdditiveShare; $src_rows], + ) -> Result<(), Infallible> { + impl_transpose_8!(self, src, $src_rows, $src_cols, read_bool_left_8, write_ba_left_8); + impl_transpose_8!(self, src, $src_rows, $src_cols, read_bool_right_8, write_ba_right_8); Ok(()) } } + + #[cfg(all(test, unit_test))] + #[test] + fn $test_fn() { + tests::test_transpose_shares_bool_to_ba::<$dst_row, $src_rows, $src_cols>(); + } + + impl_transpose_shim!( + &BitDecomposed>, AdditiveShare, + Vec>, AdditiveShare<$dst_row>, + $src_rows, $src_cols, + LengthError, + ); }; } -// Usage: Share conversion output (r/s). M = Fp25519::BITS, N = PRF_CHUNK. -impl_transpose_shares_bool_to_ba!(BA256, 256, 64, test_transpose_shares_bool_to_ba_256x64); +// Usage: Aggregation output. M = HV bits, N = number of breakdowns. +impl_transpose_shares_bool_to_ba_small!(BA8, 8, 256, test_transpose_shares_bool_to_ba_8x256); +impl_transpose_shares_bool_to_ba!(BA16, 16, 256, test_transpose_shares_bool_to_ba_16x256); +impl_transpose_shares_bool_to_ba!(BA16, 16, 32, test_transpose_shares_bool_to_ba_16x32); /// Implement a transpose of a MxN matrix of secret-shared bits represented as /// `[AdditiveShare>; M]` into a NxM bit matrix represented as `[AdditiveShare; N]`. @@ -410,46 +606,8 @@ macro_rules! impl_transpose_shares_ba_to_bool { &mut self, src: &[AdditiveShare<$src_row>; $src_rows], ) -> Result<(), Infallible> { - // Transpose left share - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src[16 * i + k].left().as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].left_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); - // Transpose right share - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src[16 * i + k].right().as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].right_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_ba_left_16, write_bool_left_16); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_ba_right_16, write_bool_right_16); Ok(()) } } @@ -460,21 +618,19 @@ macro_rules! impl_transpose_shares_ba_to_bool { tests::test_transpose_shares_ba_to_bool::<$src_row, $src_rows, $src_cols>(); } - impl TransposeFrom<&[AdditiveShare<$src_row>; $src_rows]> - for BitDecomposed> - { - type Error = Infallible; - fn transpose_from( - &mut self, - src: &[AdditiveShare<$src_row>; $src_rows], - ) -> Result<(), Infallible> { - self.resize($src_cols, AdditiveShare::::ZERO); - let dst = - <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) - .unwrap(); - dst.transpose_from(src) - } - } + impl_transpose_shim!( + &[AdditiveShare<$src_row>; $src_rows], AdditiveShare<$src_row>, + BitDecomposed>, AdditiveShare, + $src_rows, $src_cols, + Infallible, + ); + + impl_transpose_shim!( + &Vec>, AdditiveShare<$src_row>, + BitDecomposed>, AdditiveShare, + $src_rows, $src_cols, + LengthError, + ); }; } @@ -500,46 +656,8 @@ macro_rules! impl_transpose_shares_ba_fn_to_bool { &mut self, src: &dyn Fn(usize) -> AdditiveShare<$src_row>, ) -> Result<(), Infallible> { - // Transpose left share - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src(16 * i + k).left().as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].left_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); - // Transpose right share - do_transpose_16( - $src_rows / 16, - $src_cols / 16, - |i, j| { - let mut d = [0u8; 32]; - for k in 0..16 { - d[2 * k..2 * (k + 1)].copy_from_slice( - &src(16 * i + k).right().as_raw_slice()[2 * j..2 * (j + 1)], - ); - } - d - }, - |i, j, d| { - for k in 0..16 { - self[16 * i + k].right_arr_mut().as_raw_mut_slice()[2 * j..2 * (j + 1)] - .copy_from_slice(&d[2 * k..2 * (k + 1)]); - } - }, - ); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_ba_fn_left_16, write_bool_left_16); + impl_transpose_16!(self, src, $src_rows, $src_cols, read_ba_fn_right_16, write_bool_right_16); Ok(()) } } @@ -550,28 +668,77 @@ macro_rules! impl_transpose_shares_ba_fn_to_bool { tests::test_transpose_shares_ba_fn_to_bool::<$src_row, $src_rows, $src_cols>(); } - impl TransposeFrom<&dyn Fn(usize) -> AdditiveShare<$src_row>> - for BitDecomposed> + impl_transpose_shim!( + &dyn Fn(usize) -> AdditiveShare<$src_row>, AdditiveShare<$src_row>, + BitDecomposed>, AdditiveShare, + $src_rows, $src_cols, + Infallible, + ); + }; +} + +// Usage: Share conversion input (compute_prf_for_inputs). M = PRF_CHUNK, N = MK_BITS. +// Note first macro argument is `BA{N}`, not `BA{M}`. +impl_transpose_shares_ba_fn_to_bool!(BA64, 64, 64, test_transpose_shares_ba_fn_to_bool_64x64); + +/// Implement a transpose of a MxN matrix of secret-shared bits represented as +/// `[AdditiveShare>; M]` into a NxM bit matrix represented as `[AdditiveShare; N]`. +/// +/// For MxN = 16x64, the invocation looks like `impl_transpose_shares_ba_to_bool_small!(BA64, 16, 64)`. +/// +/// This version uses the 8x8 transpose kernel and supports dimensions that are not a multiple of 8. +macro_rules! impl_transpose_shares_ba_to_bool_small { + ($src_row:ty, $src_rows:expr, $src_cols:expr, $test_fn:ident) => { + impl TransposeFrom<&[AdditiveShare<$src_row>; $src_rows]> + for [AdditiveShare; ($src_cols + 7) / 8 * 8] { type Error = Infallible; - fn transpose_from( &mut self, - src: &dyn Fn(usize) -> AdditiveShare<$src_row>, + src: &[AdditiveShare<$src_row>; $src_rows], ) -> Result<(), Infallible> { - self.resize($src_cols, AdditiveShare::::ZERO); - let dst = - <&mut [AdditiveShare; $src_cols]>::try_from(&mut **self) - .unwrap(); - dst.transpose_from(src) + impl_transpose_8_pad!( + self, src, + $src_rows, $src_cols, + read_ba_left_8_pad, &AdditiveShare::<$src_row>::ZERO, + write_bool_left_8, + ); + impl_transpose_8_pad!( + self, src, + $src_rows, $src_cols, + read_ba_right_8_pad, &AdditiveShare::<$src_row>::ZERO, + write_bool_right_8, + ); + Ok(()) } } + + #[cfg(all(test, unit_test))] + #[test] + fn $test_fn() { + tests::test_transpose_shares_ba_to_bool_small::<$src_row, $src_rows, $src_cols>(); + } + + impl_transpose_shim_pad!( + &[AdditiveShare<$src_row>; $src_rows], AdditiveShare<$src_row>, + BitDecomposed>, AdditiveShare, + $src_rows, $src_cols, + Infallible, + ); + + impl_transpose_shim_pad!( + &Vec>, AdditiveShare<$src_row>, + BitDecomposed>, AdditiveShare, + $src_rows, $src_cols, + LengthError, + ); }; } -// Usage: Share conversion input (compute_prf_for_inputs). M = PRF_CHUNK, N = MK_BITS. -// Note first macro argument is `BA{N}`, not `BA{M}`. -impl_transpose_shares_ba_fn_to_bool!(BA64, 64, 64, test_transpose_shares_ba_fn_to_bool_64x64); +// Usage: Aggregation input. M = AGG_CHUNK, N = BK or TV bits. +impl_transpose_shares_ba_to_bool_small!(BA8, 256, 8, test_transpose_shares_ba_to_bool_256x8); +impl_transpose_shares_ba_to_bool_small!(BA5, 256, 5, test_transpose_shares_ba_to_bool_256x5); +impl_transpose_shares_ba_to_bool_small!(BA3, 256, 3, test_transpose_shares_ba_to_bool_256x3); #[cfg(all(test, unit_test))] mod tests { @@ -725,7 +892,7 @@ mod tests { let mut left = vec![Boolean::FALSE; N]; let mut right = vec![Boolean::FALSE; N]; for j in ((i % N)..N).step_by(step) { - let b = Boolean::from(j % 2 != 0); + let b = Boolean::from((j / N) % 2 != 0); left[j] = b; right[j] = !b; } @@ -746,7 +913,7 @@ mod tests { let mut left = vec![Boolean::FALSE; N]; let mut right = vec![Boolean::FALSE; N]; for j in ((i % N)..N).step_by(step) { - let b = Boolean::from(j % 2 != 0); + let b = Boolean::from((j / N) % 2 != 0); left[j] = b; right[j] = !b; } @@ -845,6 +1012,49 @@ mod tests { ); } + // The order of type parameters matches the implementation macro: BA, , + pub(super) fn test_transpose_shares_ba_to_bool_small< + SR, // Source row type + const SM: usize, // Source rows (== dest cols) + const DM: usize, // Destination rows (== source cols) + >() + where + Boolean: Vectorizable, + >::Array: ArrayAccess, + SR: SharedValue + + ArrayAccess + + FromIterator + + Vectorizable<1, Array = StdArray>, + BitDecomposed>: + for<'a> TransposeFrom<&'a Vec>, Error = LengthError>, + Standard: Distribution, + { + let t_impl = |src| { + let mut dst = + BitDecomposed::try_from(vec![AdditiveShare::::ZERO; DM]).unwrap(); + dst.transpose_from(src).ok().unwrap(); + dst + }; + + let step = min(SM, DM); + let m = ba_shares_test_matrix::(step).to_vec(); + let m_t = t_impl(&m); + assert_eq!(&*m_t, &bool_shares_test_matrix::(step)); + + let mut left_rng = thread_rng(); + let mut right_rng = thread_rng(); + let m = repeat_with(|| AdditiveShare::from_fns(|_| left_rng.gen(), |_| right_rng.gen())) + .take(SM) + .collect::>(); + let m_t = t_impl(&m); + + #[rustfmt::skip] + verify_transpose(SM, DM, + |i, j| (m_t[i].left_arr().get(j).unwrap(), m_t[i].right_arr().get(j).unwrap()), + |i, j| (m[i].get(j).unwrap().left(), m[i].get(j).unwrap().right()), + ); + } + // The order of type parameters matches the implementation macro: BA, , pub(super) fn test_transpose_shares_ba_fn_to_bool< SR, // Source row type From aa1554d23e770655c411fafbc3f2ea98ba5d144d Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 17 Apr 2024 11:49:37 -0700 Subject: [PATCH 2/2] Additional comments --- .../src/secret_sharing/vector/transpose.rs | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 59536cf4f..45dc8acb5 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -206,6 +206,26 @@ impl<'a> TransposeFrom<&'a [StdArray; 256]> for Vec { // Matrix transpose helpers +// The read and write helpers are used with the `impl_transpose` macros to support a specific data +// type: +// 1. `impl_transpose` interprets the bit matrix in terms of 8x8 or 16x16 submatrices. It iterates +// over the entire matrix, with `i` and `j` serving as row and column submatrix indices, +// respectively. +// 2. `impl_transpose` invokes the selected `read_*` macro to transfer one submatrix +// from the position (i, j) in the source source to temporary storage. The `read_*` macro is +// invoked once for each row of the submatrix, with `k` as row index within the submatrix. +// 3. The submatrix is transposed. +// 4. `impl_transpose` invokes the selected `write_*` macro to write the transposed +// submatrix at position (j, i) in the destination. As when reading, `k` indexes rows +// within the submatrix. +// +// The `left` and `right` variants access data in the indicated share of a replicated sharing. The +// `ba` variants access data in an array of `BA{n}` or an array of `AdditiveShare`. The +// `bool` variants access data in an array of `AdditiveShare`. The `ba_fn_{left,right}` +// variants access data by calling a closure that returns `AdditiveShare`. The `_8_pad` +// variants support reading data from a source that may not have the full height of 8, by padding +// with zeros. + macro_rules! read_ba_left_8_pad { ($m:ident, $src:ident, $i:ident, $j:ident, $k:ident, $pad_value:expr) => { $m[$k] = $src @@ -417,6 +437,9 @@ macro_rules! impl_transpose_16 { }; } +// Helper for `impl_transpose_shim` that performs a `TryFrom` conversion for the source, +// if applicable. For example, a `Vec` dereferences to a slice, which then must be +// converted to an array using `TryFrom`. macro_rules! transpose_shim_convert_src { ($converted:ty, $expected_len:expr, $src:ident, LengthError) => { <$converted>::try_from($src.deref()).map_err(|_| LengthError { @@ -429,6 +452,8 @@ macro_rules! transpose_shim_convert_src { }; } +// Implement a transpose shim that adapts a transpose implementation for arrays into a transpose +// implementation for some other type like `BitDecomposed` or `Vec`. macro_rules! impl_transpose_shim { ($src_ty:ty, $src_row:ty, $dst_ty:ty, $dst_row:ty, $src_rows:expr, $src_cols:expr, $error:tt $(,)?) => { impl TransposeFrom<$src_ty> for $dst_ty { @@ -437,6 +462,7 @@ macro_rules! impl_transpose_shim { self.resize($src_cols, <$dst_row>::ZERO); let src = transpose_shim_convert_src!(&[$src_row; $src_rows], $src_rows, src, $error); + // This unwrap cannot fail, because we resized `self` to the proper size. let dst = <&mut [$dst_row; $src_cols]>::try_from(&mut **self).unwrap(); dst.transpose_from(src).unwrap_infallible(); Ok(()) @@ -445,7 +471,8 @@ macro_rules! impl_transpose_shim { }; } -macro_rules! impl_transpose_shim_pad { +// Variant of impl_transpose_shim that adjusts non-multiple-of-8 sizes to the next multiple of 8. +macro_rules! impl_transpose_shim_8_pad { ($src_ty:ty, $src_row:ty, $dst_ty:ty, $dst_row:ty, $src_rows:expr, $src_cols:expr, $error:tt $(,)?) => { impl TransposeFrom<$src_ty> for $dst_ty { type Error = $error; @@ -453,6 +480,7 @@ macro_rules! impl_transpose_shim_pad { self.resize(($src_cols + 7) / 8 * 8, <$dst_row>::ZERO); let src = transpose_shim_convert_src!(&[$src_row; $src_rows], $src_rows, src, $error); + // This unwrap cannot fail, because we resized `self` to the proper size. let dst = <&mut [$dst_row; ($src_cols + 7) / 8 * 8]>::try_from(&mut **self).unwrap(); dst.transpose_from(src).unwrap_infallible(); @@ -719,14 +747,14 @@ macro_rules! impl_transpose_shares_ba_to_bool_small { tests::test_transpose_shares_ba_to_bool_small::<$src_row, $src_rows, $src_cols>(); } - impl_transpose_shim_pad!( + impl_transpose_shim_8_pad!( &[AdditiveShare<$src_row>; $src_rows], AdditiveShare<$src_row>, BitDecomposed>, AdditiveShare, $src_rows, $src_cols, Infallible, ); - impl_transpose_shim_pad!( + impl_transpose_shim_8_pad!( &Vec>, AdditiveShare<$src_row>, BitDecomposed>, AdditiveShare, $src_rows, $src_cols,