From 5e23eb0a9ff151da92d30dcfe87ae9924f67c75b Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 25 Jan 2024 11:23:12 -0800 Subject: [PATCH 1/3] Clean up `Field` trait * Remove `Field` trait impl from Boolean arrays. * Remove u128 conversions from `Fp25519`. * Add `select` variant of `if_else`. `select` is bus multiplexer with a single-bit control input. `if_else` is a vectorizable multiplexer, with condition input the same width as the data inputs. * Remove `Field` (or replace with `SharedValue`) in various trait bounds. * Implement `Vectorizable` and related traits for more boolean arrays. Fixes #812 --- ipa-core/src/ff/boolean.rs | 16 ---- ipa-core/src/ff/boolean_array.rs | 72 +++------------ ipa-core/src/ff/ec_prime_field.rs | 50 +++-------- ipa-core/src/ff/field.rs | 6 +- ipa-core/src/protocol/basics/check_zero.rs | 4 +- ipa-core/src/protocol/basics/if_else.rs | 62 ++++++++++++- ipa-core/src/protocol/basics/mod.rs | 4 +- ipa-core/src/protocol/basics/mul/mod.rs | 60 +++++++++++++ .../boolean_ops/addition_sequential.rs | 19 ++-- .../comparison_and_subtraction_sequential.rs | 25 +++--- .../boolean_ops/share_conversion_aby.rs | 4 +- ipa-core/src/protocol/ipa_prf/mod.rs | 20 +++-- .../prf_sharding/feature_label_dot_product.rs | 18 ++-- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 90 ++++++++++--------- ipa-core/src/protocol/ipa_prf/quicksort.rs | 19 ++-- ipa-core/src/protocol/prss/crypto.rs | 2 +- ipa-core/src/secret_sharing/array.rs | 8 +- ipa-core/src/secret_sharing/mod.rs | 75 ++++++++++++++-- ipa-core/src/test_fixture/input/sharing.rs | 8 +- 19 files changed, 327 insertions(+), 235 deletions(-) diff --git a/ipa-core/src/ff/boolean.rs b/ipa-core/src/ff/boolean.rs index 954d9806f..04bb22551 100644 --- a/ipa-core/src/ff/boolean.rs +++ b/ipa-core/src/ff/boolean.rs @@ -197,22 +197,6 @@ impl FromRandomU128 for Boolean { } } -impl Vectorizable<64> for Boolean { - type Array = crate::ff::boolean_array::BA64; -} - -impl FieldVectorizable<64> for Boolean { - type ArrayAlias = crate::ff::boolean_array::BA64; -} - -impl Vectorizable<256> for Boolean { - type Array = crate::ff::boolean_array::BA256; -} - -impl FieldVectorizable<256> for Boolean { - type ArrayAlias = crate::ff::boolean_array::BA256; -} - #[cfg(all(test, unit_test))] mod test { use generic_array::GenericArray; diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index 3078f5856..ba43f05ea 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -1,5 +1,5 @@ use bitvec::{ - prelude::{bitarr, BitArr, Lsb0}, + prelude::{BitArr, Lsb0}, slice::Iter, }; use generic_array::GenericArray; @@ -9,7 +9,7 @@ use crate::{ error::LengthError, ff::{boolean::Boolean, ArrayAccess, ArrayBuilder, Field, Serializable, U128Conversions}, protocol::prss::{FromRandom, FromRandomU128}, - secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, + secret_sharing::{Block, SharedValue, StdArray, Vectorizable}, }; /// The implementation below cannot be constrained without breaking Rust's @@ -78,64 +78,11 @@ where } } -/// A value of ONE has a one in the first element of the bit array, followed by `$bits-1` zeros. -/// This macro uses a bit of recursive repetition to produce those zeros. -/// -/// The longest call is 8 bits, which involves `2(n+1)` macro expansions in addition to `bitarr!`. -macro_rules! bitarr_one { - - // The binary value of `$bits-1` is expanded in MSB order for each of the values we care about. - // e.g., 20 =(-1)=> 19 =(binary)=> 0b10011 =(expand)=> 1 0 0 1 1 - - (2) => { bitarr_one!(1) }; - (3) => { bitarr_one!(1 0) }; - (4) => { bitarr_one!(1 1) }; - (5) => { bitarr_one!(1 0 0) }; - (6) => { bitarr_one!(1 0 1) }; - (7) => { bitarr_one!(1 1 0) }; - (8) => { bitarr_one!(1 1 1) }; - (20) => { bitarr_one!(1 0 0 1 1) }; - (32) => { bitarr_one!(1 1 1 1 1) }; - (64) => { bitarr_one!(1 1 1 1 1 1) }; - (112) => { bitarr_one!(1 1 0 1 1 1 1) }; - (256) => { bitarr_one!(1 1 1 1 1 1 1 1) }; - - // Incrementally convert 1 or 0 into `[0,]` or `[]` as needed for the recursion step. - // This also reverses the bit order so that the MSB comes last, as needed for recursion. - - // This passes a value back once the conversion is done. - ($([$($x:tt)*])*) => { bitarr_one!(@r $([$($x)*])*) }; - // This converts one 1 into `[0,]`. - ($([$($x:tt)*])* 1 $($y:tt)*) => { bitarr_one!([0,] $([$($x)*])* $($y)*) }; - // This converts one 0 into `[]`. - ($([$($x:tt)*])* 0 $($y:tt)*) => { bitarr_one!([] $([$($x)*])* $($y)*) }; - - // Recursion step. - - // This is where recursion ends with a `BitArray`. - (@r [$($x:tt)*]) => { bitarr![const u8, Lsb0; 1, $($x)*] }; - // This is the recursion workhorse. It takes a list of lists. The outer lists are bracketed. - // The inner lists contain any form that can be repeated and concatenated, which probably - // means comma-separated values with a trailing comma. - // The first value is repeated once. - // The second value is repeated twice and merged into the first value. - // The third and subsequent values are repeated twice and shifted along one place. - // One-valued bits are represented as `[0,]`, zero-valued bits as `[]`. - (@r [$($x:tt)*] [$($y:tt)*] $([$($z:tt)*])*) => { bitarr_one!(@r [$($x)* $($y)* $($y)*] $([$($z)* $($z)*])*) }; -} - // Macro for boolean arrays <= 128 bits. macro_rules! boolean_array_impl_small { ($modname:ident, $name:ident, $bits:tt, $deser_type:tt) => { boolean_array_impl!($modname, $name, $bits, $deser_type); - // TODO(812): remove this impl; BAs are not field elements. - impl Field for $name { - const NAME: &'static str = stringify!($name); - - const ONE: Self = Self(bitarr_one!($bits)); - } - impl U128Conversions for $name { fn truncate_from>(v: T) -> Self { let v = v.into(); @@ -193,10 +140,6 @@ macro_rules! boolean_array_impl_small { Self::truncate_from(src) } } - - impl FieldVectorizable<1> for $name { - type ArrayAlias = StdArray<$name, 1>; - } }; } @@ -237,6 +180,8 @@ macro_rules! impl_serializable_trait { #[cfg(all(test, unit_test))] mod fallible_serialization_tests { + use rand::{thread_rng, Rng}; + use super::*; /// [`https://github.com/private-attribution/ipa/issues/911`] @@ -252,6 +197,8 @@ macro_rules! impl_serializable_trait { "Padding only makes sense for lengths that are not multiples of 8." ); + let mut rng = thread_rng(); + let mut non_zero_padding = $name::ZERO.0; non_zero_padding.set($bits, true); assert_eq!( @@ -262,12 +209,13 @@ macro_rules! impl_serializable_trait { let min_value = $name::ZERO.0; deserialize(min_value).unwrap(); - let one = $name::ONE.0; - deserialize(one).unwrap(); - let mut max_value = $name::ZERO.0; max_value[..$bits].fill(true); deserialize(max_value).unwrap(); + + let mut rnd_value = $name::ZERO.0; + rnd_value[..$bits].fill_with(|_| rng.gen()); + deserialize(rnd_value).unwrap(); } } }; diff --git a/ipa-core/src/ff/ec_prime_field.rs b/ipa-core/src/ff/ec_prime_field.rs index 147defaad..dad5ca0f4 100644 --- a/ipa-core/src/ff/ec_prime_field.rs +++ b/ipa-core/src/ff/ec_prime_field.rs @@ -2,14 +2,12 @@ use std::convert::Infallible; use curve25519_dalek::scalar::Scalar; use generic_array::GenericArray; -use hkdf::Hkdf; -use sha2::Sha256; -use typenum::U32; +use typenum::{U2, U32}; use crate::{ - ff::{boolean_array::BA256, Field, Serializable, U128Conversions}, + ff::{boolean_array::BA256, Field, Serializable}, impl_shared_value_common, - protocol::prss::FromRandomU128, + protocol::prss::FromRandom, secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, }; @@ -193,43 +191,17 @@ impl Field for Fp25519 { const ONE: Fp25519 = Fp25519::ONE; } -// TODO(812): remove these impls -impl U128Conversions for Fp25519 { - ///both following methods are based on hashing and do not allow to actually convert elements in Fp25519 - /// from or into u128. However it is sufficient to generate random elements in Fp25519 - fn as_u128(&self) -> u128 { - unimplemented!() - } +impl FromRandom for Fp25519 { + type SourceLength = U2; - ///PRSS uses `truncate_from function`, we need to expand the u128 using a PRG (Sha256) to a [u8;32] - fn truncate_from>(_v: T) -> Self { - unimplemented!() - } -} - -impl FromRandomU128 for Fp25519 { - fn from_random_u128(v: u128) -> Self { - let hk = Hkdf::::new(None, &v.to_le_bytes()); - let mut okm = [0u8; 32]; - //error invalid length from expand only happens when okm is very large - hk.expand(&[], &mut okm).unwrap(); - Fp25519::deserialize_infallible(&okm.into()) - } -} - -///implement `TryFrom` since required by Field -impl TryFrom for Fp25519 { - type Error = crate::error::Error; - - fn try_from(v: u128) -> Result { - let mut bits = [0u8; 32]; - bits[..].copy_from_slice(&v.to_le_bytes()); - let f: Fp25519 = Fp25519::ONE; - f.serialize((&mut bits).into()); - Ok(f) + fn from_random(src: GenericArray) -> Self { + let mut src_bytes = [0u8; 32]; + src_bytes[0..16].copy_from_slice(&src[0].to_le_bytes()); + src_bytes[16..32].copy_from_slice(&src[1].to_le_bytes()); + // Reduces mod order + Fp25519::deserialize_infallible(<&GenericArray>::from(&src_bytes)) } } -// TODO(812): end remove impls #[cfg(all(test, unit_test))] mod test { diff --git a/ipa-core/src/ff/field.rs b/ipa-core/src/ff/field.rs index 88d41087b..a7e9932ff 100644 --- a/ipa-core/src/ff/field.rs +++ b/ipa-core/src/ff/field.rs @@ -6,8 +6,7 @@ use std::{ use typenum::{U1, U4}; use crate::{ - error, - protocol::prss::FromRandomU128, + protocol::prss::FromRandom, secret_sharing::{Block, FieldVectorizable, SharedValue, Vectorizable}, }; @@ -26,8 +25,7 @@ pub trait Field: SharedValue + Mul + MulAssign - + FromRandomU128 - + TryFrom + + FromRandom + Into + Vectorizable<1> + FieldVectorizable<1, ArrayAlias = >::Array> diff --git a/ipa-core/src/protocol/basics/check_zero.rs b/ipa-core/src/protocol/basics/check_zero.rs index 192533d85..97d9f3723 100644 --- a/ipa-core/src/protocol/basics/check_zero.rs +++ b/ipa-core/src/protocol/basics/check_zero.rs @@ -6,7 +6,7 @@ use crate::{ protocol::{ basics::{reveal::Reveal, SecureMul}, context::Context, - prss::SharedRandomness, + prss::{FromRandom, SharedRandomness}, RecordId, }, secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, @@ -47,7 +47,7 @@ pub(crate) enum Step { /// ## Errors /// Lots of things may go wrong here, from timeouts to bad output. They will be signalled /// back via the error response -pub async fn check_zero( +pub async fn check_zero( ctx: C, record_id: RecordId, v: &Replicated, diff --git a/ipa-core/src/protocol/basics/if_else.rs b/ipa-core/src/protocol/basics/if_else.rs index 3fe486983..0a2416f9f 100644 --- a/ipa-core/src/protocol/basics/if_else.rs +++ b/ipa-core/src/protocol/basics/if_else.rs @@ -1,11 +1,23 @@ use crate::{ error::Error, - ff::Field, - protocol::{basics::SecureMul, context::Context, RecordId}, - secret_sharing::{Linear as LinearSecretSharing, LinearRefOps}, + ff::{boolean::Boolean, Field}, + protocol::{ + basics::{mul::BooleanArrayMul, SecureMul}, + context::Context, + RecordId, + }, + secret_sharing::{replicated::semi_honest::AdditiveShare, LinearRefOps}, }; +/// Multiplexer. +/// /// Returns `true_value` if `condition` is a share of 1, else `false_value`. +/// If the arguments are vectors, all must have the same dimension and the +/// operation is performed element-wise. +/// +/// Each `condition` must be a share of either 0 or 1. +/// Each `true_value` and `false_value` may be any type supporting multiplication. +/// /// # Errors /// If the protocol fails to execute. pub async fn if_else( @@ -18,7 +30,7 @@ pub async fn if_else( where F: Field, C: Context, - S: LinearSecretSharing + SecureMul, + S: SecureMul, for<'a> &'a S: LinearRefOps<'a, S, F>, { // If `condition` is a share of 1 (true), then @@ -34,3 +46,45 @@ where .multiply(&(true_value - false_value), ctx, record_id) .await?) } + +/// Wide multiplexer. +/// +/// Returns `true_value` if `condition` is a share of 1, else `false_value`. +/// `condition` must be a single shared value. `true_value` and `false_value` +/// may be vectors, in which case one or the other is selected in its entirety, +/// depending on `condition`. +/// +/// `condition` must be a share of either 0 or 1. +/// `true_value` and `false_value` may be any type supporting multiplication, +/// vectors of a type supporting multiplication, or a type convertible to +/// one of those. +/// +/// # Errors +/// If the protocol fails to execute. +pub async fn select( + ctx: C, + record_id: RecordId, + condition: &AdditiveShare, + true_value: &B, + false_value: &B, +) -> Result +where + C: Context, + B: Clone + BooleanArrayMul, +{ + let false_value = false_value.clone().into(); + let true_value = true_value.clone().into(); + let condition = B::expand(condition).into(); + // If `condition` is a share of 1 (true), then + // false_value + condition * (true_value - false_value) + // = false_value + true_value - false_value + // = true_value + // + // If `condition` is a share of 0 (false), then + // false_value + condition * (true_value - false_value) + // = false_value + 0 + // = false_value + let product = B::multiply(ctx, record_id, &condition, &(true_value - &false_value)).await?; + + Ok((false_value + &product).into()) +} diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index f4db7fc9b..2b4065dd1 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -10,8 +10,8 @@ pub mod sum_of_product; #[cfg(feature = "descriptive-gate")] pub use check_zero::check_zero; -pub use if_else::if_else; -pub use mul::{MultiplyZeroPositions, SecureMul, ZeroPositions}; +pub use if_else::{if_else, select}; +pub use mul::{BooleanArrayMul, MultiplyZeroPositions, SecureMul, ZeroPositions}; pub use reshare::Reshare; pub use reveal::Reveal; pub use share_known_value::ShareKnownValue; diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index acab715c2..443982e00 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -1,8 +1,16 @@ +use std::ops::{Add, Sub}; + use async_trait::async_trait; use crate::{ error::Error, + ff::{ + boolean::Boolean, + boolean_array::{BA20, BA3, BA32, BA5, BA8}, + Expand, + }, protocol::{context::Context, RecordId}, + secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, }; #[cfg(feature = "descriptive-gate")] @@ -39,3 +47,55 @@ pub trait SecureMul: Send + Sync + Sized { where C: 'fut; } + +use semi_honest::multiply as semi_honest_mul; + +// The BooleanArrayMul trait is implemented for types like `Replicated`. It hides the `N` +// const parameter so that implementations parameterized with a Boolean array type parameter (e.g. +// breakdown key type BK is BA8) can invoke vectorized multiply. Without this trait, those +// implementations would need to specify the `N` const parameter, which is tricky, because you +// can't supply an expression involving a type parameter (BK::BITS) as a const parameter. +#[async_trait] +pub trait BooleanArrayMul: + Expand> + From + Into +{ + type Vectorized: Send + + for<'a> Add<&'a Self::Vectorized, Output = Self::Vectorized> + + for<'a> Sub<&'a Self::Vectorized, Output = Self::Vectorized>; + + async fn multiply<'fut, C>( + ctx: C, + record_id: RecordId, + a: &'fut Self::Vectorized, + b: &'fut Self::Vectorized, + ) -> Result + where + C: Context + 'fut; +} + +macro_rules! boolean_array_mul { + ($dim:expr, $vec:ty) => { + #[async_trait] + impl BooleanArrayMul for Replicated<$vec> { + type Vectorized = Replicated; + + async fn multiply<'fut, C>( + ctx: C, + record_id: RecordId, + a: &'fut Self::Vectorized, + b: &'fut Self::Vectorized, + ) -> Result + where + C: Context + 'fut, + { + semi_honest_mul(ctx, record_id, a, b, ZeroPositions::NONE).await + } + } + }; +} + +boolean_array_mul!(3, BA3); +boolean_array_mul!(5, BA5); +boolean_array_mul!(8, BA8); +boolean_array_mul!(20, BA20); +boolean_array_mul!(32, BA32); diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs index e4b937790..dc27e6ef7 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/addition_sequential.rs @@ -1,6 +1,8 @@ #[cfg(all(test, unit_test))] use ipa_macros::Step; +#[cfg(all(test, unit_test))] +use crate::secret_sharing::{FieldSimd, FieldVectorizable}; use crate::{ error::Error, ff::{ArrayAccess, CustomArray, Field}, @@ -46,19 +48,20 @@ where /// # Errors /// propagates errors from multiply #[cfg(all(test, unit_test))] -pub async fn integer_sat_add( +pub async fn integer_sat_add( ctx: C, record_id: RecordId, x: &AdditiveShare, y: &AdditiveShare, ) -> Result, Error> where + F: Field + FieldSimd + FieldVectorizable, C: Context, - S: CustomArray + Field, - S::Element: Field, + S: SharedValue + CustomArray, + AdditiveShare: From> + Into>, { use crate::{ff::Expand, protocol::basics::if_else}; - let mut carry = AdditiveShare::::ZERO; + let mut carry = AdditiveShare::::ZERO; let result = addition_circuit( ctx.narrow(&Step::SaturatedAddition), record_id, @@ -66,10 +69,11 @@ where y, &mut carry, ) - .await?; + .await? + .into(); // expand carry bit to array - let carry_array = AdditiveShare::::expand(&carry); + let carry_array = AdditiveShare::::expand(&carry).into(); // if carry_array==1 then {carry_array} else {result}: if_else( @@ -80,6 +84,7 @@ where &result, ) .await + .map(Into::into) } /// addition using bit adder @@ -237,7 +242,7 @@ mod test { let result = world .semi_honest((x_ba64, y_ba64), |ctx, x_y| async move { - integer_sat_add::<_, BA64>( + integer_sat_add::<_, _, _, 64>( ctx.set_total_records(1), protocol::RecordId(0), &x_y.0, diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs index 82b68e0e3..21b86aed5 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/comparison_and_subtraction_sequential.rs @@ -9,6 +9,8 @@ use std::{borrow::Borrow, iter::repeat, ops::Not}; #[cfg(all(test, unit_test))] use ipa_macros::Step; +#[cfg(all(test, unit_test))] +use crate::secret_sharing::FieldVectorizable; use crate::{ error::Error, ff::{ArrayAccessRef, ArrayBuild, ArrayBuilder, Field}, @@ -113,22 +115,22 @@ where /// # Errors /// propagates errors from multiply #[cfg(all(test, unit_test))] -pub async fn integer_sat_sub( +pub async fn integer_sat_sub( ctx: C, record_id: RecordId, x: &AdditiveShare, y: &AdditiveShare, ) -> Result, Error> where + F: Field + FieldSimd + FieldVectorizable, C: Context, - S::Element: Field, - S: SharedValue + CustomArray + Expand, - AdditiveShare: SecureMul - + ArrayAccessRef> - + ArrayBuild>, - AdditiveShare: SecureMul + Not>, + S: SharedValue + CustomArray, + AdditiveShare: + ArrayAccessRef> + ArrayBuild>, + AdditiveShare: SecureMul + Not>, + AdditiveShare: From> + Into>, { - let mut carry = AdditiveShare::::share_known_value(&ctx, S::Element::ONE); + let mut carry = AdditiveShare::::share_known_value(&ctx, F::ONE); let result = subtraction_circuit( ctx.narrow(&Step::SaturatedSubtraction), record_id, @@ -136,14 +138,17 @@ where y, &mut carry, ) - .await?; + .await? + .into(); // carry computes carry=(x>=y) // if carry==0 {all 0 array, i.e. Array[carry]} else {result}: // compute (1-carry)*Array[carry]+carry*result =carry*result AdditiveShare::::expand(&carry) + .into() .multiply(&result, ctx.narrow(&Step::MultiplyWithCarry), record_id) .await + .map(Into::into) } /// subtraction using bit subtractor @@ -569,7 +574,7 @@ mod test { let result = world .semi_honest(records.into_iter(), |ctx, x_y| async move { - integer_sat_sub::<_, BA64>( + integer_sat_sub::<_, _, _, 64>( ctx.set_total_records(1), protocol::RecordId(0), &x_y[0], diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index 95fcf4417..9ac7ca68a 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -6,7 +6,7 @@ use crate::{ error::Error, ff::{ boolean::Boolean, boolean_array::BA256, ec_prime_field::Fp25519, ArrayAccess, CustomArray, - Expand, Field, + Expand, }, helpers::Role, protocol::{ @@ -99,7 +99,7 @@ pub async fn convert_to_fp25519( ) -> Result, Error> where C: Context, - B: SharedValue + CustomArray + Field, + B: SharedValue + CustomArray, { // generate sh_r = (0, 0, sh_r) and sh_s = (sh_s, 0, 0) // the two highest bits are set to 0 to allow carries for two additions diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 8068a5684..2c241f9ad 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -8,10 +8,11 @@ use self::{quicksort::quicksort_ranges_by_key_insecure, shuffle::shuffle_inputs} use crate::{ error::{Error, UnwrapInfallible}, ff::{ - boolean::Boolean, boolean_array::BA64, CustomArray, Field, PrimeField, Serializable, + boolean::Boolean, boolean_array::BA64, CustomArray, PrimeField, Serializable, U128Conversions, }, protocol::{ + basics::BooleanArrayMul, context::{UpgradableContext, UpgradedContext}, ipa_prf::{ boolean_ops::convert_to_fp25519, @@ -171,10 +172,13 @@ where C: UpgradableContext, C::UpgradedContext: UpgradedContext>, C::UpgradedContext: UpgradedContext>, - BK: SharedValue + U128Conversions + CustomArray + Field, - TV: SharedValue + U128Conversions + CustomArray + Field, - TS: SharedValue + U128Conversions + CustomArray + Field, - SS: SharedValue + U128Conversions + CustomArray + Field, + BK: SharedValue + U128Conversions + CustomArray, + TV: SharedValue + U128Conversions + CustomArray, + TS: SharedValue + U128Conversions + CustomArray, + SS: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, F: PrimeField + ExtendableField, Replicated: Serializable, { @@ -212,9 +216,9 @@ where C: UpgradableContext, C::UpgradedContext: UpgradedContext>, C::UpgradedContext: UpgradedContext>, - BK: SharedValue + CustomArray + Field, - TV: SharedValue + CustomArray + Field, - TS: SharedValue + CustomArray + Field, + BK: SharedValue + CustomArray, + TV: SharedValue + CustomArray, + TS: SharedValue + CustomArray, F: PrimeField + ExtendableField, Replicated: Serializable, { diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs index c706395c4..8cb9bec5a 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/feature_label_dot_product.rs @@ -6,9 +6,9 @@ use ipa_macros::Step; use crate::{ error::Error, - ff::{boolean::Boolean, CustomArray, Expand, Field, PrimeField, Serializable}, + ff::{boolean::Boolean, CustomArray, Field, PrimeField, Serializable}, protocol::{ - basics::{if_else, SecureMul, ShareKnownValue}, + basics::{select, BooleanArrayMul, SecureMul, ShareKnownValue}, boolean::or::or, context::{Context, UpgradableContext, UpgradedContext, Validator}, modulus_conversion::convert_bits, @@ -57,7 +57,8 @@ impl InputsRequiredFromPrevRow { ) -> Result, Error> where C: UpgradedContext>, - FV: CustomArray + Field, + FV: SharedValue + CustomArray, + Replicated: BooleanArrayMul, { let share_of_one = Replicated::share_known_value(&ctx, Boolean::ONE); let is_source_event = &share_of_one - &input_row.is_trigger_bit; @@ -92,11 +93,10 @@ impl InputsRequiredFromPrevRow { ) .await?; - let capped_label_vector = Replicated::::expand(&capped_label); - let capped_attributed_feature_vector = if_else( + let capped_attributed_feature_vector = select( ctx.narrow(&Step::ComputedCappedFeatureVector), record_id, - &capped_label_vector, + &capped_label, &input_row.feature_vector, &Replicated::::ZERO, ) @@ -223,7 +223,8 @@ where C::UpgradedContext: UpgradedContext>, C::UpgradedContext: UpgradedContext, S: LinearSecretSharing + Serializable + SecureMul>, - FV: CustomArray + Field, + FV: SharedValue + CustomArray, + Replicated: BooleanArrayMul, F: PrimeField + ExtendableField, { assert!(::BITS > 0); @@ -300,7 +301,8 @@ async fn evaluate_per_user_attribution_circuit( ) -> Result>, Error> where C: UpgradedContext>, - FV: CustomArray + Field, + FV: SharedValue + CustomArray, + Replicated: BooleanArrayMul, { assert!(!rows_for_user.is_empty()); if rows_for_user.len() == 1 { diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 52a50b4d1..eb7a9d9e1 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -20,7 +20,7 @@ use crate::{ }, helpers::Role, protocol::{ - basics::{if_else, SecureMul, ShareKnownValue}, + basics::{select, BooleanArrayMul, SecureMul, ShareKnownValue}, boolean::or::or, context::{Context, UpgradableContext, UpgradedContext, Validator}, ipa_prf::boolean_ops::{ @@ -102,10 +102,13 @@ struct InputsRequiredFromPrevRow InputsRequiredFromPrevRow where - BK: SharedValue + U128Conversions + CustomArray + Field, - TV: SharedValue + U128Conversions + CustomArray + Field, - TS: SharedValue + U128Conversions + CustomArray + Field, - SS: SharedValue + U128Conversions + CustomArray + Field, + BK: SharedValue + U128Conversions + CustomArray, + TV: SharedValue + U128Conversions + CustomArray, + TS: SharedValue + U128Conversions + CustomArray, + SS: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, { /// /// This function contains the main logic for the per-user attribution circuit. @@ -466,10 +469,13 @@ where C::UpgradedContext: UpgradedContext>, C::UpgradedContext: UpgradedContext, S: LinearSecretSharing + Serializable + SecureMul>, - BK: SharedValue + U128Conversions + CustomArray + Field, - TV: SharedValue + U128Conversions + CustomArray + Field, - TS: SharedValue + U128Conversions + CustomArray + Field, - SS: SharedValue + U128Conversions + CustomArray + Field, + BK: SharedValue + U128Conversions + CustomArray, + TV: SharedValue + U128Conversions + CustomArray, + TS: SharedValue + U128Conversions + CustomArray, + SS: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, F: PrimeField + ExtendableField, { // Get the validator and context to use for Boolean multiplication operations @@ -572,10 +578,13 @@ async fn evaluate_per_user_attribution_circuit( ) -> Result>, Error> where C: Context, - BK: SharedValue + U128Conversions + CustomArray + Field, - TV: SharedValue + U128Conversions + CustomArray + Field, - TS: SharedValue + U128Conversions + CustomArray + Field, - SS: SharedValue + U128Conversions + CustomArray + Field, + BK: SharedValue + U128Conversions + CustomArray, + TV: SharedValue + U128Conversions + CustomArray, + TS: SharedValue + U128Conversions + CustomArray, + SS: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, + Replicated: BooleanArrayMul, { assert!(!rows_for_user.is_empty()); if rows_for_user.len() == 1 { @@ -644,14 +653,13 @@ async fn breakdown_key_of_most_recent_source_event( ) -> Result, Error> where C: Context, - BK: SharedValue + CustomArray + Field, + BK: SharedValue + CustomArray, + Replicated: BooleanArrayMul, { - let is_trigger_bit_array = Replicated::::expand(is_trigger_bit); - - if_else( + select( ctx, record_id, - &is_trigger_bit_array, + is_trigger_bit, prev_row_breakdown_key_bits, cur_row_breakdown_key_bits, ) @@ -670,17 +678,16 @@ async fn timestamp_of_most_recent_source_event( ) -> Result, Error> where C: Context, - TS: SharedValue + U128Conversions + CustomArray + Field, + TS: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, { match attribution_window_seconds { None => Ok(prev_row_timestamp_bits.clone()), Some(_) => { - let is_trigger_bit_array = Replicated::::expand(is_trigger_bit); - - if_else( + select( ctx, record_id, - &is_trigger_bit_array, + is_trigger_bit, prev_row_timestamp_bits, cur_row_timestamp_bits, ) @@ -711,8 +718,9 @@ async fn zero_out_trigger_value_unless_attributed( ) -> Result, Error> where C: Context, - TV: SharedValue + U128Conversions + CustomArray + Field, - TS: SharedValue + U128Conversions + CustomArray + Field, + TV: SharedValue + U128Conversions + CustomArray, + TS: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, { let (did_trigger_get_attributed, is_trigger_within_window) = try_join( is_trigger_bit.multiply( @@ -740,12 +748,10 @@ where did_trigger_get_attributed.clone() }; - let zero_out_flag_array = Replicated::::expand(&zero_out_flag); - - if_else( + select( ctx, record_id, - &zero_out_flag_array, + &zero_out_flag, trigger_value, &Replicated::::ZERO, ) @@ -765,7 +771,7 @@ async fn is_trigger_event_within_attribution_window( ) -> Result, Error> where C: Context, - TS: SharedValue + U128Conversions + CustomArray + Field, + TS: SharedValue + U128Conversions + CustomArray, { if let Some(attribution_window_seconds) = attribution_window_seconds { let time_delta_bits = integer_sub( @@ -821,29 +827,25 @@ async fn compute_capped_trigger_value( ) -> Result, Error> where C: Context, - TV: SharedValue + U128Conversions + CustomArray + Field, + TV: SharedValue + U128Conversions + CustomArray, + Replicated: BooleanArrayMul, { let narrowed_ctx1 = ctx.narrow(&Step::ComputedCappedAttributedTriggerValueNotSaturatedCase); let narrowed_ctx2 = ctx.narrow(&Step::ComputedCappedAttributedTriggerValueJustSaturatedCase); - let is_saturated_array = Replicated::::expand(is_saturated); - - let is_saturated_and_prev_row_not_saturated_array = - Replicated::::expand(is_saturated_and_prev_row_not_saturated); - - let attributed_trigger_value_or_zero = if_else( + let attributed_trigger_value_or_zero = select( narrowed_ctx1, record_id, - &is_saturated_array, + is_saturated, &Replicated::new(::ZERO, ::ZERO), attributed_trigger_value, ) .await?; - if_else( + select( narrowed_ctx2, record_id, - &is_saturated_and_prev_row_not_saturated_array, + is_saturated_and_prev_row_not_saturated, prev_row_diff_to_cap, &attributed_trigger_value_or_zero, ) @@ -885,7 +887,7 @@ pub mod tests { trigger_value: u8, ) -> PreShardedAndSortedOPRFTestInput where - BK: SharedValue + U128Conversions + Field, + BK: SharedValue + U128Conversions, { oprf_test_input_with_timestamp( prf_of_match_key, @@ -904,7 +906,7 @@ pub mod tests { timestamp: u32, ) -> PreShardedAndSortedOPRFTestInput where - BK: SharedValue + U128Conversions + Field, + BK: SharedValue + U128Conversions, { let is_trigger_bit = if is_trigger { Boolean::ONE @@ -981,8 +983,8 @@ pub mod tests { impl Reconstruct for [&CappedAttributionOutputs; 3] where - BK: SharedValue + U128Conversions + CustomArray + Field, - TV: SharedValue + U128Conversions + CustomArray + Field, + BK: SharedValue + U128Conversions + CustomArray, + TV: SharedValue + U128Conversions + CustomArray, { fn reconstruct(&self) -> PreAggregationTestOutputInDecimal { let [s0, s1, s2] = self; diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 3c9caaf42..dbc0ca189 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -155,7 +155,10 @@ where #[cfg(all(test, unit_test))] pub mod tests { - use std::cmp::Ordering; + use std::{ + cmp::Ordering, + iter::{repeat, repeat_with}, + }; use ipa_macros::Step; use rand::Rng; @@ -163,7 +166,7 @@ pub mod tests { use crate::{ ff::{ boolean_array::{BA20, BA64}, - Field, U128Conversions, + U128Conversions, }, protocol::{context::Context, ipa_prf::quicksort::quicksort_ranges_by_key_insecure}, rand::thread_rng, @@ -188,8 +191,7 @@ pub mod tests { for desc in bools { // generate vector of random values - let mut records: Vec = vec![::ONE; 20]; - records.iter_mut().for_each(|x| *x = rng.gen::()); + let records: Vec = repeat_with(|| rng.gen()).take(20).collect(); // convert expected into more readable format let mut expected: Vec = @@ -236,10 +238,8 @@ pub mod tests { let bools = vec![false, true]; for desc in bools { - // generate vector of random values - let element = rng.gen::(); - let mut records: Vec = vec![::ONE; 20]; - records.iter_mut().for_each(|x| *x = element); + // generate vector of 20 copies of same random value + let records: Vec = repeat(rng.gen()).take(20).collect(); // convert expected into more readable format let mut expected: Vec = @@ -334,8 +334,7 @@ pub mod tests { for desc in bools { // generate vector of random values - let mut records: Vec = vec![::ONE; 20]; - records.iter_mut().for_each(|x| *x = rng.gen::()); + let records: Vec = repeat_with(|| rng.gen()).take(20).collect(); // convert expected into more readable format let mut expected: Vec = diff --git a/ipa-core/src/protocol/prss/crypto.rs b/ipa-core/src/protocol/prss/crypto.rs index 25e68a1fe..a332b9bea 100644 --- a/ipa-core/src/protocol/prss/crypto.rs +++ b/ipa-core/src/protocol/prss/crypto.rs @@ -140,7 +140,7 @@ pub trait SharedRandomness { // Equivalent functionality could be obtained by defining an `Unreplicated` type that // implements `FromPrss`. #[must_use] - fn zero>(&self, index: I) -> V { + fn zero>(&self, index: I) -> V { let (l, r): (V, V) = self.generate(index); l - r } diff --git a/ipa-core/src/secret_sharing/array.rs b/ipa-core/src/secret_sharing/array.rs index effce79f1..f38e608c0 100644 --- a/ipa-core/src/secret_sharing/array.rs +++ b/ipa-core/src/secret_sharing/array.rs @@ -6,7 +6,7 @@ use std::{ }; use generic_array::{ArrayLength, GenericArray}; -use typenum::{U1, U32}; +use typenum::U32; use crate::{ error::LengthError, @@ -270,9 +270,9 @@ impl, const N: usize> Not for StdArray { } } -impl> FromRandom for StdArray { - type SourceLength = U1; - fn from_random(src: GenericArray) -> Self { +impl FromRandom for StdArray { + type SourceLength = ::SourceLength; + fn from_random(src: GenericArray) -> Self { Self([F::from_random(src)]) } } diff --git a/ipa-core/src/secret_sharing/mod.rs b/ipa-core/src/secret_sharing/mod.rs index a7ac163b9..7fa77a89f 100644 --- a/ipa-core/src/secret_sharing/mod.rs +++ b/ipa-core/src/secret_sharing/mod.rs @@ -22,8 +22,12 @@ //! a trait bound, something like `F: Field + FieldSimd`. //! //! The other traits are `Vectorizable` (for `SharedValue`s) and `FieldVectorizable`. These traits -//! are needed to work around a limitation in the rust type system. See the `FieldVectorizable` -//! documentation for details. +//! are needed to work around a limitation in the rust type system. In most cases, you do not need +//! to reference the `Vectorizable` or `FieldVectorizable` traits directly when implementing +//! protocols. Usually the vector type is hidden within `AdditiveShare`, but if you are writing a +//! vectorized low-level primitive, you may need to refer to it directly, as `>::Array`. It is even more rare to need to use `FieldVectorizable`; see its +//! documentation and the documentation of `FieldSimd` for details. //! //! We require that each supported vectorization configuration (i.e. combination of data type and //! width) be explicitly identified, by implementing the `Vectorizable` and/or `FieldVectorizable` @@ -64,14 +68,17 @@ use rand::{ distributions::{Distribution, Standard}, Rng, }; -#[cfg(any(test, feature = "test-fixture", feature = "cli"))] -use replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}; pub use scheme::{Bitwise, Linear, LinearRefOps, SecretSharing}; use crate::{ error::LengthError, - ff::{boolean::Boolean, AddSub, AddSubAssign, Field, Fp32BitPrime, Serializable}, + ff::{ + boolean::Boolean, + boolean_array::{BA20, BA256, BA3, BA32, BA5, BA64, BA8}, + AddSub, AddSubAssign, Field, Fp32BitPrime, Serializable, + }, protocol::prss::FromRandom, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, }; /// Operations supported for weak shared values. @@ -190,13 +197,35 @@ pub trait Vectorizable: Sized { /// `Array` so that references to the `Array` associated type do not require qualification /// with a trait name. /// 3. `F: Vectorizable`. This is implied by the previous two, because `FieldArray` -/// is a sub-trait of `SharedValueArray`. +/// is a sub-trait of `SharedValueArray`. (See the `FieldSimd` documentation for another +/// important consequence of this sub-trait relationship.) pub trait FieldVectorizable: SharedValue + Sized { type ArrayAlias: FieldArray; } +// Convenience alias to express a supported vectorization when writing protocols. +// +// Typically appears like this: `F: Field + FieldSimd`. +// // We could define a `SharedValueSimd` trait that is the analog of this for `SharedValue`s, but // there are not currently any protocols that need it. +// +// Because we have constrained the associated types Vectorizable::Array and +// FieldVectorizable::ArrayAlias to be equal, the type they refer to must satisfy the union of all +// trait bounds applicable to either. However, in some cases the compiler has trouble proving +// properties related to this. (See rust issues [41118] and [60471].) A typical workaround for +// problems of this sort is to redundantly list a trait bound on both associated types, but for us +// that is not necessary in most cases because `FieldArray` is a sub-trait of `SharedValueArray`. +// +// Another consequence of this limitation of the compiler is that if you write the bound `F: Field + +// FieldSimd + Vectorizable`, you will get the error ``type annotations needed: +// cannot satisfy `>::Array == >::ArrayAlias```. The compiler is not smart enough to +// coalesce the constraints and see that `S`, `::Array`, and `::ArrayAlias` must all to refer to the same type. +// +// [41118](https://github.com/rust-lang/rust/issues/41118) +// [60471](https://github.com/rust-lang/rust/issues/60471) pub trait FieldSimd: Field + Vectorizable>::ArrayAlias> + FieldVectorizable { @@ -211,9 +240,39 @@ impl FieldSimd<1> for F {} impl FieldSimd<32> for Fp32BitPrime {} -impl FieldSimd<64> for Boolean {} +macro_rules! boolean_vector { + ($dim:expr, $vec:ty) => { + impl Vectorizable<$dim> for Boolean { + type Array = $vec; + } + + impl FieldVectorizable<$dim> for Boolean { + type ArrayAlias = $vec; + } + + impl FieldSimd<$dim> for Boolean {} + + impl From> for AdditiveShare { + fn from(value: AdditiveShare<$vec>) -> Self { + AdditiveShare::new_arr(value.left(), value.right()) + } + } + + impl From> for AdditiveShare<$vec> { + fn from(value: AdditiveShare) -> Self { + AdditiveShare::new(*value.left_arr(), *value.right_arr()) + } + } + }; +} -impl FieldSimd<256> for Boolean {} +boolean_vector!(3, BA3); +boolean_vector!(5, BA5); +boolean_vector!(8, BA8); +boolean_vector!(20, BA20); +boolean_vector!(32, BA32); +boolean_vector!(64, BA64); +boolean_vector!(256, BA256); pub trait SharedValueArray: Clone diff --git a/ipa-core/src/test_fixture/input/sharing.rs b/ipa-core/src/test_fixture/input/sharing.rs index 1d674efff..eb706bd4f 100644 --- a/ipa-core/src/test_fixture/input/sharing.rs +++ b/ipa-core/src/test_fixture/input/sharing.rs @@ -1,7 +1,7 @@ use std::iter::{repeat, zip}; use crate::{ - ff::{boolean::Boolean, boolean_array::BA64, Field, U128Conversions}, + ff::{boolean::Boolean, boolean_array::BA64, U128Conversions}, protocol::ipa_prf::OPRFIPAInputRow, rand::Rng, report::{EventType, OprfReport}, @@ -67,9 +67,9 @@ where impl IntoShares> for TestRawDataRecord where - BK: SharedValue + Field + IntoShares>, - TV: SharedValue + Field + IntoShares>, - TS: SharedValue + Field + IntoShares>, + BK: SharedValue + U128Conversions + IntoShares>, + TV: SharedValue + U128Conversions + IntoShares>, + TS: SharedValue + U128Conversions + IntoShares>, { fn share_with(self, rng: &mut R) -> [OPRFIPAInputRow; 3] { let is_trigger = Replicated::new( From 38e2461ae200ccbe7a27a87e9fbac646b53d874a Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 21 Feb 2024 12:08:01 -0800 Subject: [PATCH 2/3] Eliminate async_trait for BooleanArrayMul --- ipa-core/src/protocol/basics/if_else.rs | 9 ++++-- ipa-core/src/protocol/basics/mul/mod.rs | 37 +++++++++++++++++++------ 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/ipa-core/src/protocol/basics/if_else.rs b/ipa-core/src/protocol/basics/if_else.rs index 0a2416f9f..ee89c19ea 100644 --- a/ipa-core/src/protocol/basics/if_else.rs +++ b/ipa-core/src/protocol/basics/if_else.rs @@ -2,7 +2,10 @@ use crate::{ error::Error, ff::{boolean::Boolean, Field}, protocol::{ - basics::{mul::BooleanArrayMul, SecureMul}, + basics::{ + mul::{boolean_array_multiply, BooleanArrayMul}, + SecureMul, + }, context::Context, RecordId, }, @@ -84,7 +87,9 @@ where // false_value + condition * (true_value - false_value) // = false_value + 0 // = false_value - let product = B::multiply(ctx, record_id, &condition, &(true_value - &false_value)).await?; + let product = + boolean_array_multiply::<_, B>(ctx, record_id, &condition, &(true_value - &false_value)) + .await?; Ok((false_value + &product).into()) } diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 443982e00..82b6e9420 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -1,4 +1,7 @@ -use std::ops::{Add, Sub}; +use std::{ + future::Future, + ops::{Add, Sub}, +}; use async_trait::async_trait; @@ -55,40 +58,56 @@ use semi_honest::multiply as semi_honest_mul; // breakdown key type BK is BA8) can invoke vectorized multiply. Without this trait, those // implementations would need to specify the `N` const parameter, which is tricky, because you // can't supply an expression involving a type parameter (BK::BITS) as a const parameter. -#[async_trait] pub trait BooleanArrayMul: Expand> + From + Into { type Vectorized: Send + + Sync + for<'a> Add<&'a Self::Vectorized, Output = Self::Vectorized> - + for<'a> Sub<&'a Self::Vectorized, Output = Self::Vectorized>; + + for<'a> Sub<&'a Self::Vectorized, Output = Self::Vectorized> + + 'static; - async fn multiply<'fut, C>( + fn multiply<'fut, C>( ctx: C, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, - ) -> Result + ) -> impl Future> + Send + 'fut where C: Context + 'fut; } +// Workaround for https://github.com/rust-lang/rust/issues/100013. Calling this wrapper function +// instead of `<_ as BooleanArrayMul>::multiply` seems to hide the BooleanArrayMul `impl Future` +// GAT. +pub fn boolean_array_multiply<'fut, C, B>( + ctx: C, + record_id: RecordId, + a: &'fut B::Vectorized, + b: &'fut B::Vectorized, +) -> impl Future> + Send + 'fut +where + C: Context + 'fut, + B: BooleanArrayMul, +{ + B::multiply(ctx, record_id, a, b) +} + macro_rules! boolean_array_mul { ($dim:expr, $vec:ty) => { - #[async_trait] impl BooleanArrayMul for Replicated<$vec> { type Vectorized = Replicated; - async fn multiply<'fut, C>( + fn multiply<'fut, C>( ctx: C, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, - ) -> Result + ) -> impl Future> + Send + 'fut where C: Context + 'fut, { - semi_honest_mul(ctx, record_id, a, b, ZeroPositions::NONE).await + semi_honest_mul(ctx, record_id, a, b, ZeroPositions::NONE) } } }; From fb191e0168a1d42c41678eb8840d7d7ab0bbc18a Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 23 Feb 2024 09:58:46 -0800 Subject: [PATCH 3/3] PR feedback --- ipa-core/src/ff/boolean_array.rs | 8 -------- ipa-core/src/protocol/basics/if_else.rs | 6 +++--- ipa-core/src/protocol/basics/mul/mod.rs | 9 ++++----- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/ipa-core/src/ff/boolean_array.rs b/ipa-core/src/ff/boolean_array.rs index ba43f05ea..78265ca27 100644 --- a/ipa-core/src/ff/boolean_array.rs +++ b/ipa-core/src/ff/boolean_array.rs @@ -180,8 +180,6 @@ macro_rules! impl_serializable_trait { #[cfg(all(test, unit_test))] mod fallible_serialization_tests { - use rand::{thread_rng, Rng}; - use super::*; /// [`https://github.com/private-attribution/ipa/issues/911`] @@ -197,8 +195,6 @@ macro_rules! impl_serializable_trait { "Padding only makes sense for lengths that are not multiples of 8." ); - let mut rng = thread_rng(); - let mut non_zero_padding = $name::ZERO.0; non_zero_padding.set($bits, true); assert_eq!( @@ -212,10 +208,6 @@ macro_rules! impl_serializable_trait { let mut max_value = $name::ZERO.0; max_value[..$bits].fill(true); deserialize(max_value).unwrap(); - - let mut rnd_value = $name::ZERO.0; - rnd_value[..$bits].fill_with(|_| rng.gen()); - deserialize(rnd_value).unwrap(); } } }; diff --git a/ipa-core/src/protocol/basics/if_else.rs b/ipa-core/src/protocol/basics/if_else.rs index ee89c19ea..33a467908 100644 --- a/ipa-core/src/protocol/basics/if_else.rs +++ b/ipa-core/src/protocol/basics/if_else.rs @@ -75,9 +75,9 @@ where C: Context, B: Clone + BooleanArrayMul, { - let false_value = false_value.clone().into(); - let true_value = true_value.clone().into(); - let condition = B::expand(condition).into(); + let false_value = B::Vectorized::from(false_value.clone()); + let true_value = B::Vectorized::from(true_value.clone()); + let condition = B::Vectorized::from(B::expand(condition)); // If `condition` is a share of 1 (true), then // false_value + condition * (true_value - false_value) // = false_value + true_value - false_value diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 82b6e9420..e77672b6c 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -58,13 +58,12 @@ use semi_honest::multiply as semi_honest_mul; // breakdown key type BK is BA8) can invoke vectorized multiply. Without this trait, those // implementations would need to specify the `N` const parameter, which is tricky, because you // can't supply an expression involving a type parameter (BK::BITS) as a const parameter. -pub trait BooleanArrayMul: - Expand> + From + Into -{ - type Vectorized: Send - + Sync +pub trait BooleanArrayMul: Expand> + From { + type Vectorized: From + for<'a> Add<&'a Self::Vectorized, Output = Self::Vectorized> + for<'a> Sub<&'a Self::Vectorized, Output = Self::Vectorized> + + Send + + Sync + 'static; fn multiply<'fut, C>(