diff --git a/src/ff/galois_field.rs b/src/ff/galois_field.rs
index 6d2e44ffc..703789ac8 100644
--- a/src/ff/galois_field.rs
+++ b/src/ff/galois_field.rs
@@ -560,6 +560,16 @@ bit_array_impl!(
     0b1_0001_1011_u128
 );
 
+bit_array_impl!(
+    bit_array_3,
+    Gf3Bit,
+    U8_1,
+    3,
+    bitarr!(const u8, Lsb0; 1, 0, 0),
+    // x^3 + x + 1
+    0b1_011_u128
+);
+
 bit_array_impl!(
     bit_array_1,
     Gf2,
diff --git a/src/ff/mod.rs b/src/ff/mod.rs
index 453876175..21b0b70d3 100644
--- a/src/ff/mod.rs
+++ b/src/ff/mod.rs
@@ -9,7 +9,7 @@ mod prime_field;
 use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
 
 pub use field::{Field, FieldType};
-pub use galois_field::{GaloisField, Gf2, Gf32Bit, Gf40Bit, Gf8Bit};
+pub use galois_field::{GaloisField, Gf2, Gf32Bit, Gf3Bit, Gf40Bit, Gf8Bit};
 use generic_array::{ArrayLength, GenericArray};
 #[cfg(any(test, feature = "weak-field"))]
 pub use prime_field::Fp31;
diff --git a/src/helpers/transport/query.rs b/src/helpers/transport/query.rs
index 92443885a..8f77a3f9f 100644
--- a/src/helpers/transport/query.rs
+++ b/src/helpers/transport/query.rs
@@ -348,7 +348,7 @@ impl Default for SparseAggregateQueryConfig {
     fn default() -> Self {
         Self {
             contribution_bits: ContributionBits::default(),
-            num_contributions: 1,
+            num_contributions: 8,
         }
     }
 }
diff --git a/src/protocol/aggregation/mod.rs b/src/protocol/aggregation/mod.rs
index a9b955eb9..c1fd6a8e0 100644
--- a/src/protocol/aggregation/mod.rs
+++ b/src/protocol/aggregation/mod.rs
@@ -1,10 +1,12 @@
 mod input;
 
-use futures::{stream::iter as stream_iter, TryStreamExt};
+use futures::{stream::iter as stream_iter, Stream, TryStreamExt};
+use futures_util::StreamExt;
 pub use input::SparseAggregateInputRow;
 use ipa_macros::step;
 use strum::AsRefStr;
 
+use super::{context::Context, sort::bitwise_to_onehot, step::BitOpStep, RecordId};
 use crate::{
     error::Error,
     ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
@@ -21,21 +23,32 @@ use crate::{
         },
         BitDecomposed, Linear as LinearSecretSharing,
     },
+    seq_join::seq_join,
 };
 
 #[step]
 pub(crate) enum Step {
     Validator,
     ConvertValueBits,
+    ConvertBreakdownKeyBits,
+    ComputeEqualityChecks,
+    CheckTimesValue,
 }
 
-/// Binary-share aggregation protocol.
+/// Binary-share aggregation protocol for a sparse breakdown key vector input.
+/// It takes a tuple of two vectors, `contribution_values` and `breakdown_keys`,
+/// and aggregates each value into the histogram bucket specified by the
+/// corresponding breakdown key. Since breakdown keys are secret shared, we
+/// create a vector of Z2 shares for each record indicating which bucket the
+/// value should be aggregated into. The output is a vector of Zp shares: a
+/// histogram of the aggregated values.
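+///
+/// For example, with 4 buckets, breakdown keys `[2, 0, 2]` and contribution
+/// values `[5, 3, 2]` (written here in the clear), the reconstructed output
+/// is the histogram `[3, 0, 7, 0]`.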
 ///
 /// # Errors
 /// Propagates errors from multiplications
-pub async fn aggregate<'a, C, S, SB, F, CV, BK>(
+pub async fn sparse_aggregate<'a, C, S, SB, F, CV, BK>(
     sh_ctx: C,
     input_rows: &[SparseAggregateInputRow<CV, BK>],
+    num_buckets: usize,
 ) -> Result<Vec<Replicated<F>>, Error>
 where
     C: UpgradableContext,
@@ -52,139 +65,229 @@ where
     BK: GaloisField,
 {
     let validator = sh_ctx.narrow(&Step::Validator).validator::<F>();
+    let ctx = validator.context().set_total_records(input_rows.len());
+    let contributions = input_rows.iter().map(|row| &row.contribution_value);
+    let breakdowns = input_rows.iter().map(|row| &row.breakdown_key);
 
-    let (gf2_value_bits, _gf2_breakdown_keys) = (
-        get_gf2_value_bits(input_rows),
-        get_gf2_breakdown_key_bits(input_rows),
+    // convert the input from `[Z2]^u` into `[Zp]^u`
+    let (converted_value_bits, converted_breakdown_key_bits) = (
+        upgrade_bit_shares(
+            &ctx.narrow(&Step::ConvertValueBits),
+            contributions,
+            CV::BITS,
+        ),
+        upgrade_bit_shares(
+            &ctx.narrow(&Step::ConvertBreakdownKeyBits),
+            breakdowns,
+            BK::BITS,
+        ),
     );
 
-    // TODO(taikiy):
-    // 1. slice the buckets into N streams and send them to the aggregation protocol
-    // 2. collect the results and return them
+    let output = sparse_aggregate_values_per_bucket(
+        ctx,
+        converted_value_bits,
+        converted_breakdown_key_bits,
+        num_buckets,
+    )
+    .await?;
 
-    let output = aggregate_values(validator.context(), gf2_value_bits).await?;
-
-    validator.validate(vec![output]).await
+    validator.validate(output).await
 }
 
-/// Performs a set of aggregation protocols on binary shared values.
 /// This protocol assumes that devices and/or browsers have applied per-user
 /// capping.
 ///
 /// # Errors
 /// propagates errors from multiplications
-#[tracing::instrument(name = "simple_aggregate_values", skip_all)]
-pub async fn aggregate_values<F, C, S>(
+#[tracing::instrument(name = "aggregate_values_per_bucket", skip_all)]
+pub async fn sparse_aggregate_values_per_bucket<F, I1, I2, C, S>(
     ctx: C,
-    contribution_value_bits_gf2: Vec<BitDecomposed<Replicated<Gf2>>>,
-) -> Result<S, Error>
+    contribution_values: I1,
+    breakdown_keys: I2,
+    num_buckets: usize,
+) -> Result<Vec<S>, Error>
 where
     F: PrimeField,
+    I1: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
+    I2: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
     C: UpgradedContext<F, Share = S>,
     S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
 {
-    let record_count = contribution_value_bits_gf2.len();
-    let bits = contribution_value_bits_gf2[0].len();
-
-    // mod-convert for later validation
-    let convert_ctx = ctx
-        .narrow(&Step::ConvertValueBits)
-        .set_total_records(record_count);
-    let converted_contribution_values = convert_bits(
-        convert_ctx,
-        stream_iter(contribution_value_bits_gf2),
-        0..u32::try_from(bits).unwrap(),
-    );
+    let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks);
 
-    let aggregate = converted_contribution_values
-        .try_fold(S::ZERO, |mut acc, row| async move {
-            acc += &row.to_additive_sharing_in_large_field();
+    // Generate one stream per bucket, `num_buckets` in total.
+    // Each stream is a pipeline of contribution values multiplied by the "equality bit". An
+    // equality bit is a share of 1 if the breakdown key matches the bucket, or a share of 0
+    // otherwise.
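+    // For example, with 4 buckets, a record with breakdown key 2 and value 5
+    // has the one-hot equality bits [0, 0, 1, 0], so its per-bucket
+    // contribution is [0, 0, 5, 0].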
+    let streams = seq_join(
+        ctx.active_work(),
+        breakdown_keys
+            .zip(contribution_values)
+            .enumerate()
+            .map(|(i, (bk, v))| {
+                let eq_ctx = &equality_check_ctx;
+                let c = ctx.clone();
+                async move {
+                    let equality_checks = bitwise_to_onehot(eq_ctx.clone(), i, &bk?).await?;
+                    equality_bits_times_value(&c, equality_checks, num_buckets, v?, i).await
+                }
+            }),
+    );
+    // for each bucket stream, sum up the contribution values
+    streams
+        .try_fold(vec![S::ZERO; num_buckets], |mut acc, bucket| async move {
+            for (i, b) in bucket.into_iter().enumerate() {
+                acc[i] += &b;
+            }
             Ok(acc)
         })
-        .await?;
-
-    Ok(aggregate)
+        .await
 }
 
-fn get_gf2_value_bits<CV, BK>(
-    input_rows: &[SparseAggregateInputRow<CV, BK>],
-) -> Vec<BitDecomposed<Replicated<Gf2>>>
+async fn equality_bits_times_value<F, C, S>(
+    ctx: &C,
+    check_bits: BitDecomposed<S>,
+    num_buckets: usize,
+    value_bits: BitDecomposed<S>,
+    record_id: usize,
+) -> Result<Vec<S>, Error>
 where
-    CV: GaloisField,
-    BK: GaloisField,
+    F: PrimeField,
+    C: UpgradedContext<F, Share = S>,
+    S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
 {
-    input_rows
-        .iter()
-        .map(|row| {
-            BitDecomposed::decompose(CV::BITS, |i| {
-                Replicated::new(
-                    Gf2::truncate_from(row.contribution_value.left()[i]),
-                    Gf2::truncate_from(row.contribution_value.right()[i]),
-                )
-            })
-        })
-        .collect::<Vec<_>>()
+    let check_times_value_ctx = ctx.narrow(&Step::CheckTimesValue);
+
+    ctx.try_join(
+        check_bits
+            .into_iter()
+            .take(num_buckets)
+            .enumerate()
+            .map(|(check_idx, check)| {
+                let step = BitOpStep::from(check_idx);
+                let c = check_times_value_ctx.narrow(&step);
+                let record_id = RecordId::from(record_id);
+                let v = value_bits.to_additive_sharing_in_large_field();
+                async move { check.multiply(&v, c, record_id).await }
+            }),
+    )
+    .await
 }
 
-fn get_gf2_breakdown_key_bits<CV, BK>(
-    input_rows: &[SparseAggregateInputRow<CV, BK>],
-) -> Vec<BitDecomposed<Replicated<Gf2>>>
+fn upgrade_bit_shares<'a, F, C, S, I, G>(
+    ctx: &C,
+    input_rows: I,
+    num_bits: u32,
+) -> impl Stream<Item = Result<BitDecomposed<S>, Error>> + 'a
 where
-    CV: GaloisField,
-    BK: GaloisField,
+    F: PrimeField,
+    C: UpgradedContext<F, Share = S> + 'a,
+    S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
+    I: Iterator<Item = &'a Replicated<G>> + Send + 'a,
+    G: GaloisField,
 {
-    input_rows
-        .iter()
-        .map(|row| {
-            BitDecomposed::decompose(BK::BITS, |i| {
-                Replicated::new(
-                    Gf2::truncate_from(row.breakdown_key.left()[i]),
-                    Gf2::truncate_from(row.breakdown_key.right()[i]),
-                )
-            })
+    let gf2_bits = input_rows.map(move |row| {
+        BitDecomposed::decompose(num_bits, |idx| {
+            Replicated::new(
+                Gf2::truncate_from(row.left()[idx]),
+                Gf2::truncate_from(row.right()[idx]),
+            )
         })
-        .collect::<Vec<_>>()
+    });
+
+    convert_bits(
+        ctx.narrow(&Step::ConvertValueBits),
+        stream_iter(gf2_bits),
+        0..num_bits,
+    )
 }
 
 #[cfg(all(test, unit_test))]
 mod tests {
-    use super::aggregate_values;
+    use super::sparse_aggregate;
     use crate::{
-        ff::{Fp32BitPrime, Gf2},
-        protocol::context::{UpgradableContext, Validator},
-        secret_sharing::BitDecomposed,
+        ff::{Field, Fp32BitPrime, GaloisField, Gf3Bit, Gf8Bit},
+        protocol::aggregation::SparseAggregateInputRow,
+        secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue},
         test_fixture::{Reconstruct, Runner, TestWorld},
     };
 
+    fn create_input_vec<BK, CV>(
+        input: &[(Replicated<BK>, Replicated<CV>)],
+    ) -> Vec<SparseAggregateInputRow<CV, BK>>
+    where
+        BK: GaloisField,
+        CV: GaloisField,
+    {
+        input
+            .iter()
+            .map(|x| SparseAggregateInputRow {
+                breakdown_key: x.0.clone(),
+                contribution_value: x.1.clone(),
+            })
+            .collect::<Vec<_>>()
+    }
+
     #[tokio::test]
     pub async fn aggregate() {
-        const CONTRIBUTION_BITS: u32 = 8;
-        const EXPECTED: u128 = 36;
+        type BK = Gf3Bit;
+        type CV = Gf8Bit;
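+
+        // With 3-bit breakdown keys there are 2^3 = 8 buckets; EXPECTED is the
+        // per-bucket sum of the contribution values in INPUT below.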
+        const EXPECTED: &[u128] = &[28, 0, 0, 6, 1, 0, 0, 8];
+        const NUM_BUCKETS: usize = 1 << BK::BITS;
+        const INPUT: &[(u32, u32)] = &[
+            // (breakdown_key, contribution_value)
+            (0, 0),
+            (0, 0),
+            (0, 18),
+            (0, 0),
+            (0, 0),
+            (3, 5),
+            (0, 0),
+            (4, 1),
+            (0, 0),
+            (0, 0),
+            (7, 2),
+            (0, 0),
+            (0, 0),
+            (0, 0),
+            (0, 10),
+            (3, 1),
+            (0, 0),
+            (7, 6),
+            (0, 0),
+        ];
 
-        const INPUT: &[u32] = &[0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 10, 0, 0, 6, 0];
+        let bitwise_input = INPUT
+            .iter()
+            .map(|(bk, value)| (BK::truncate_from(*bk), CV::truncate_from(*value)));
 
         let world = TestWorld::default();
         let result = world
-            .semi_honest(
-                INPUT.iter().map(|&value| {
-                    BitDecomposed::decompose(CONTRIBUTION_BITS, |i| {
-                        Gf2::try_from((u128::from(value) >> i) & 1).unwrap()
-                    })
-                }),
-                |ctx, shares| async move {
-                    let validator = ctx.validator::<Fp32BitPrime>();
-                    aggregate_values(
-                        validator.context(), // note: not upgrading any inputs, so semi-honest only.
-                        shares,
-                    )
-                    .await
-                    .unwrap()
-                },
-            )
+            .semi_honest(bitwise_input.clone(), |ctx, shares| async move {
+                sparse_aggregate::<_, _, _, Fp32BitPrime, CV, BK>(
+                    ctx,
+                    &create_input_vec(&shares),
+                    NUM_BUCKETS,
+                )
+                .await
+                .unwrap()
+            })
             .await
             .reconstruct();
         assert_eq!(result, EXPECTED);
-    }
 
-    //TODO(taikiy): add malicious test
+        let result = world
+            .malicious(bitwise_input.clone(), |ctx, shares| async move {
+                sparse_aggregate::<_, _, _, Fp32BitPrime, CV, BK>(
+                    ctx,
+                    &create_input_vec(&shares),
+                    NUM_BUCKETS,
+                )
+                .await
+                .unwrap()
+            })
+            .await
+            .reconstruct();
+        assert_eq!(result, EXPECTED);
+    }
 }
diff --git a/src/protocol/attribution/aggregate_credit.rs b/src/protocol/attribution/aggregate_credit.rs
index fc052256b..f661474a3 100644
--- a/src/protocol/attribution/aggregate_credit.rs
+++ b/src/protocol/attribution/aggregate_credit.rs
@@ -10,7 +10,7 @@ use crate::{
     protocol::{
         context::{UpgradableContext, UpgradedContext, Validator},
         modulus_conversion::convert_bits,
-        sort::{check_everything, generate_permutation::ShuffledPermutationWrapper},
+        sort::{bitwise_to_onehot, generate_permutation::ShuffledPermutationWrapper},
         step::BitOpStep,
         BasicProtocols, RecordId,
     },
@@ -114,7 +114,7 @@ where
             let ceq = &equality_check_context;
             let cmul = &check_times_credit_context;
             async move {
-                let equality_checks = check_everything(ceq.clone(), i, &bk?).await?;
+                let equality_checks = bitwise_to_onehot(ceq.clone(), i, &bk?).await?;
                 ceq.try_join(equality_checks.into_iter().take(to_take).enumerate().map(
                     |(check_idx, check)| {
                         let step = BitOpStep::from(check_idx);
diff --git a/src/protocol/sort/mod.rs b/src/protocol/sort/mod.rs
index 48fc98e3c..840fc986f 100644
--- a/src/protocol/sort/mod.rs
+++ b/src/protocol/sort/mod.rs
@@ -91,7 +91,7 @@ pub(crate) enum ReshareStep {
 ///
 /// # Errors
 /// If any multiplication fails, or if the record is too long (e.g. more than 64 multiplications required)
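+/// For example, a 3-bit bitwise sharing of the value 5 maps to a one-hot
+/// vector of 2^3 = 8 shared bits: a share of 1 at index 5 and shares of 0
+/// elsewhere.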
-pub async fn check_everything<F, C, S>(
+pub async fn bitwise_to_onehot<F, C, S>(
     ctx: C,
     record_idx: usize,
     record: &[S],
diff --git a/src/protocol/sort/multi_bit_permutation.rs b/src/protocol/sort/multi_bit_permutation.rs
index 472ebaa1d..8b9ccb90b 100644
--- a/src/protocol/sort/multi_bit_permutation.rs
+++ b/src/protocol/sort/multi_bit_permutation.rs
@@ -4,7 +4,7 @@ use crate::{
     error::Error,
     ff::PrimeField,
     protocol::{
-        basics::SumOfProducts, context::UpgradedContext, sort::check_everything, BasicProtocols,
+        basics::SumOfProducts, context::UpgradedContext, sort::bitwise_to_onehot, BasicProtocols,
         RecordId,
     },
     secret_sharing::{
@@ -66,7 +66,7 @@ where
         .iter()
         .zip(repeat(ctx.set_total_records(num_records)))
         .enumerate()
-        .map(|(idx, (record, ctx))| check_everything(ctx, idx, record)),
+        .map(|(idx, (record, ctx))| bitwise_to_onehot(ctx, idx, record)),
     )
     .await?;
 
@@ -117,7 +117,7 @@ mod tests {
         ff::{Field, Fp31},
         protocol::{
             context::{Context, UpgradableContext, Validator},
-            sort::check_everything,
+            sort::bitwise_to_onehot,
         },
         secret_sharing::{BitDecomposed, SharedValue},
         seq_join::SeqJoin,
@@ -170,7 +170,7 @@ mod tests {
             let ctx = ctx.set_total_records(num_records);
             let mut equality_check_futures = Vec::with_capacity(num_records);
             for (i, record) in m_shares.iter().enumerate() {
-                equality_check_futures.push(check_everything(ctx.clone(), i, record));
+                equality_check_futures.push(bitwise_to_onehot(ctx.clone(), i, record));
             }
             ctx.try_join(equality_check_futures).await.unwrap()
         })
diff --git a/src/query/runner/aggregate.rs b/src/query/runner/aggregate.rs
index c58ba7839..f25abff10 100644
--- a/src/query/runner/aggregate.rs
+++ b/src/query/runner/aggregate.rs
@@ -12,7 +12,7 @@ use crate::{
     },
     hpke::{KeyPair, KeyRegistry},
     protocol::{
-        aggregation::{aggregate, SparseAggregateInputRow},
+        aggregation::{sparse_aggregate, SparseAggregateInputRow},
         basics::{Reshare, ShareKnownValue},
         context::{UpgradableContext, UpgradedContext},
         BasicProtocols, BreakdownKey, RecordId,
@@ -25,7 +25,7 @@ use crate::{
 };
 
 pub struct SparseAggregateQuery<F, C, S> {
-    _config: SparseAggregateQueryConfig,
+    config: SparseAggregateQueryConfig,
     _key_registry: Arc<KeyRegistry<KeyPair>>,
     phantom_data: PhantomData<(F, C, S)>,
 }
@@ -36,7 +36,7 @@ impl<F, C, S> SparseAggregateQuery<F, C, S> {
         key_registry: Arc<KeyRegistry<KeyPair>>,
     ) -> Self {
         Self {
-            _config: config,
+            config,
             _key_registry: key_registry,
             phantom_data: PhantomData,
         }
@@ -70,7 +70,7 @@ where
         input_stream: BodyStream,
     ) -> Result<Vec<Replicated<F>>, Error> {
         let Self {
-            _config,
+            config,
             _key_registry,
             phantom_data: _,
         } = self;
@@ -92,6 +92,11 @@ where
             v
         };
 
-        aggregate(ctx, input.as_slice()).await
+        sparse_aggregate(
+            ctx,
+            input.as_slice(),
+            usize::try_from(config.num_contributions).unwrap(),
+        )
+        .await
     }
 }
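
// Illustrative sketch (not part of the patch): a plaintext model of what
// `sparse_aggregate` computes once all shares are reconstructed. The function
// name and plain-integer types here are hypothetical, chosen only to show the
// arithmetic; the real protocol operates on secret-shared Z2/Zp values.
fn sparse_aggregate_in_the_clear(
    breakdown_keys: &[usize],
    contribution_values: &[u128],
    num_buckets: usize,
) -> Vec<u128> {
    // Each record adds its contribution value to the single histogram bucket
    // selected by its breakdown key.
    let mut histogram = vec![0u128; num_buckets];
    for (&bk, &cv) in breakdown_keys.iter().zip(contribution_values) {
        histogram[bk] += cv;
    }
    histogram
}

// E.g., the unit test above: 8 buckets, keys/values (0,18), (3,5), (4,1),
// (7,2), (0,10), (3,1), (7,6) and zeros elsewhere reconstruct to
// [28, 0, 0, 6, 1, 0, 0, 8].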