diff --git a/src/protocol/aggregation/mod.rs b/src/protocol/aggregation/mod.rs index d824ec3aa..c1fd6a8e0 100644 --- a/src/protocol/aggregation/mod.rs +++ b/src/protocol/aggregation/mod.rs @@ -6,7 +6,7 @@ pub use input::SparseAggregateInputRow; use ipa_macros::step; use strum::AsRefStr; -use super::{context::Context, sort::check_everything, step::BitOpStep, RecordId}; +use super::{context::Context, sort::bitwise_to_onehot, step::BitOpStep, RecordId}; use crate::{ error::Error, ff::{Field, GaloisField, Gf2, PrimeField, Serializable}, @@ -71,9 +71,13 @@ where // convert the input from `[Z2]^u` into `[Zp]^u` let (converted_value_bits, converted_breakdown_key_bits) = ( - upgrade_bit_shares(ctx.narrow(&Step::ConvertValueBits), contributions, CV::BITS), upgrade_bit_shares( - ctx.narrow(&Step::ConvertBreakdownKeyBits), + &ctx.narrow(&Step::ConvertValueBits), + contributions, + CV::BITS, + ), + upgrade_bit_shares( + &ctx.narrow(&Step::ConvertBreakdownKeyBits), breakdowns, BK::BITS, ), @@ -109,9 +113,6 @@ where C: UpgradedContext, S: LinearSecretSharing + BasicProtocols + Serializable + 'static, { - // TODO: use exactsizestream trait - // debug_assert!(contribution_values.len() == breakdown_keys.len()); - let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks); // Generate N streams for each bucket specified by the `num_buckets`. @@ -126,7 +127,7 @@ where let eq_ctx = &equality_check_ctx; let c = ctx.clone(); async move { - let equality_checks = check_everything(eq_ctx.clone(), i, &bk?).await?; + let equality_checks = bitwise_to_onehot(eq_ctx.clone(), i, &bk?).await?; equality_bits_times_value(&c, equality_checks, num_buckets, v?, i).await } }), @@ -173,7 +174,7 @@ where } fn upgrade_bit_shares<'a, F, C, S, I, G>( - ctx: C, + ctx: &C, input_rows: I, num_bits: u32, ) -> impl Stream, Error>> + 'a diff --git a/src/protocol/attribution/aggregate_credit.rs b/src/protocol/attribution/aggregate_credit.rs index fc052256b..f661474a3 100644 --- a/src/protocol/attribution/aggregate_credit.rs +++ b/src/protocol/attribution/aggregate_credit.rs @@ -10,7 +10,7 @@ use crate::{ protocol::{ context::{UpgradableContext, UpgradedContext, Validator}, modulus_conversion::convert_bits, - sort::{check_everything, generate_permutation::ShuffledPermutationWrapper}, + sort::{bitwise_to_onehot, generate_permutation::ShuffledPermutationWrapper}, step::BitOpStep, BasicProtocols, RecordId, }, @@ -114,7 +114,7 @@ where let ceq = &equality_check_context; let cmul = &check_times_credit_context; async move { - let equality_checks = check_everything(ceq.clone(), i, &bk?).await?; + let equality_checks = bitwise_to_onehot(ceq.clone(), i, &bk?).await?; ceq.try_join(equality_checks.into_iter().take(to_take).enumerate().map( |(check_idx, check)| { let step = BitOpStep::from(check_idx); diff --git a/src/protocol/sort/mod.rs b/src/protocol/sort/mod.rs index 48fc98e3c..840fc986f 100644 --- a/src/protocol/sort/mod.rs +++ b/src/protocol/sort/mod.rs @@ -91,7 +91,7 @@ pub(crate) enum ReshareStep { /// /// # Errors /// If any multiplication fails, or if the record is too long (e.g. more than 64 multiplications required) -pub async fn check_everything( +pub async fn bitwise_to_onehot( ctx: C, record_idx: usize, record: &[S], diff --git a/src/protocol/sort/multi_bit_permutation.rs b/src/protocol/sort/multi_bit_permutation.rs index 472ebaa1d..8b9ccb90b 100644 --- a/src/protocol/sort/multi_bit_permutation.rs +++ b/src/protocol/sort/multi_bit_permutation.rs @@ -4,7 +4,7 @@ use crate::{ error::Error, ff::PrimeField, protocol::{ - basics::SumOfProducts, context::UpgradedContext, sort::check_everything, BasicProtocols, + basics::SumOfProducts, context::UpgradedContext, sort::bitwise_to_onehot, BasicProtocols, RecordId, }, secret_sharing::{ @@ -66,7 +66,7 @@ where .iter() .zip(repeat(ctx.set_total_records(num_records))) .enumerate() - .map(|(idx, (record, ctx))| check_everything(ctx, idx, record)), + .map(|(idx, (record, ctx))| bitwise_to_onehot(ctx, idx, record)), ) .await?; @@ -117,7 +117,7 @@ mod tests { ff::{Field, Fp31}, protocol::{ context::{Context, UpgradableContext, Validator}, - sort::check_everything, + sort::bitwise_to_onehot, }, secret_sharing::{BitDecomposed, SharedValue}, seq_join::SeqJoin, @@ -170,7 +170,7 @@ mod tests { let ctx = ctx.set_total_records(num_records); let mut equality_check_futures = Vec::with_capacity(num_records); for (i, record) in m_shares.iter().enumerate() { - equality_check_futures.push(check_everything(ctx.clone(), i, record)); + equality_check_futures.push(bitwise_to_onehot(ctx.clone(), i, record)); } ctx.try_join(equality_check_futures).await.unwrap() })