diff --git a/ipa-core/src/protocol/basics/if_else.rs b/ipa-core/src/protocol/basics/if_else.rs index 0a2416f9f..ee89c19ea 100644 --- a/ipa-core/src/protocol/basics/if_else.rs +++ b/ipa-core/src/protocol/basics/if_else.rs @@ -2,7 +2,10 @@ use crate::{ error::Error, ff::{boolean::Boolean, Field}, protocol::{ - basics::{mul::BooleanArrayMul, SecureMul}, + basics::{ + mul::{boolean_array_multiply, BooleanArrayMul}, + SecureMul, + }, context::Context, RecordId, }, @@ -84,7 +87,9 @@ where // false_value + condition * (true_value - false_value) // = false_value + 0 // = false_value - let product = B::multiply(ctx, record_id, &condition, &(true_value - &false_value)).await?; + let product = + boolean_array_multiply::<_, B>(ctx, record_id, &condition, &(true_value - &false_value)) + .await?; Ok((false_value + &product).into()) } diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 443982e00..82b6e9420 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -1,4 +1,7 @@ -use std::ops::{Add, Sub}; +use std::{ + future::Future, + ops::{Add, Sub}, +}; use async_trait::async_trait; @@ -55,40 +58,56 @@ use semi_honest::multiply as semi_honest_mul; // breakdown key type BK is BA8) can invoke vectorized multiply. Without this trait, those // implementations would need to specify the `N` const parameter, which is tricky, because you // can't supply an expression involving a type parameter (BK::BITS) as a const parameter. -#[async_trait] pub trait BooleanArrayMul: Expand> + From + Into { type Vectorized: Send + + Sync + for<'a> Add<&'a Self::Vectorized, Output = Self::Vectorized> - + for<'a> Sub<&'a Self::Vectorized, Output = Self::Vectorized>; + + for<'a> Sub<&'a Self::Vectorized, Output = Self::Vectorized> + + 'static; - async fn multiply<'fut, C>( + fn multiply<'fut, C>( ctx: C, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, - ) -> Result + ) -> impl Future> + Send + 'fut where C: Context + 'fut; } +// Workaround for https://github.com/rust-lang/rust/issues/100013. Calling this wrapper function +// instead of `<_ as BooleanArrayMul>::multiply` seems to hide the BooleanArrayMul `impl Future` +// GAT. +pub fn boolean_array_multiply<'fut, C, B>( + ctx: C, + record_id: RecordId, + a: &'fut B::Vectorized, + b: &'fut B::Vectorized, +) -> impl Future> + Send + 'fut +where + C: Context + 'fut, + B: BooleanArrayMul, +{ + B::multiply(ctx, record_id, a, b) +} + macro_rules! boolean_array_mul { ($dim:expr, $vec:ty) => { - #[async_trait] impl BooleanArrayMul for Replicated<$vec> { type Vectorized = Replicated; - async fn multiply<'fut, C>( + fn multiply<'fut, C>( ctx: C, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, - ) -> Result + ) -> impl Future> + Send + 'fut where C: Context + 'fut, { - semi_honest_mul(ctx, record_id, a, b, ZeroPositions::NONE).await + semi_honest_mul(ctx, record_id, a, b, ZeroPositions::NONE) } } };