diff --git a/src/protocol/prf_sharding/mod.rs b/src/protocol/prf_sharding/mod.rs index 07d966f41..5e92d27ad 100644 --- a/src/protocol/prf_sharding/mod.rs +++ b/src/protocol/prf_sharding/mod.rs @@ -1,9 +1,8 @@ -use ipa_macros::step; use std::iter::repeat; -use strum::AsRefStr; use futures_util::future::try_join_all; -use metrics_util::registry::Storage; +use ipa_macros::step; +use strum::AsRefStr; use super::step::BitOpStep; use crate::{ @@ -11,7 +10,7 @@ use crate::{ ff::{Field, GaloisField, Gf2}, protocol::{ basics::{SecureMul, ShareKnownValue}, - context::{Context, UpgradableContext, UpgradedContext, Validator}, + context::{UpgradableContext, UpgradedContext, Validator}, BasicProtocols, RecordId, }, repeat64str, @@ -74,32 +73,17 @@ pub(crate) enum Step { ComputedCappedAttributedTriggerValueJustSaturatedCase, } -fn compute_histogram_of_users_with_row_count( - input_rows: &[PrfShardedIpaInputRow], -) -> Vec -where - BK: GaloisField, - TV: GaloisField, -{ - let (_, _, hist) = input_rows.iter().fold( - (0, 0, vec![]), - |(last_prf, rows_for_user, mut histogram), input_row| { - if last_prf == input_row.prf_of_match_key { - if rows_for_user >= histogram.len() { - histogram.push(0); - } - histogram[rows_for_user] += 1; - (input_row.prf_of_match_key, rows_for_user + 1, histogram) - } else { - if histogram.is_empty() { - histogram.push(0); - } - histogram[0] += 1; - (input_row.prf_of_match_key, 1, histogram) +fn compute_histogram_of_users_with_row_count(rows_chunked_by_user: &[Vec]) -> Vec { + let mut output = vec![]; + for user_rows in rows_chunked_by_user { + for j in 0..user_rows.len() { + if j >= output.len() { + output.push(0); } - }, - ); - hist + output[j] += 1; + } + } + output } fn set_up_contexts(root_ctx: C, histogram: Vec) -> Vec @@ -120,6 +104,35 @@ where context_per_row_depth } +fn chunk_rows_by_user( + input_rows: Vec>, +) -> Vec>> +where + BK: GaloisField, + TV: GaloisField, +{ + let mut rows_for_user = vec![]; + + let mut rows_chunked_by_user = vec![]; + for row in input_rows { + if rows_for_user.is_empty() { + rows_for_user.push(row); + } else { + if row.prf_of_match_key == rows_for_user[0].prf_of_match_key { + rows_for_user.push(row); + } else { + rows_chunked_by_user.push(rows_for_user); + rows_for_user = vec![row]; + } + } + } + if !rows_for_user.is_empty() { + rows_chunked_by_user.push(rows_for_user); + } + + rows_chunked_by_user +} + /// Sub-protocol of the PRF-sharded IPA Protocol /// /// After the computation of the per-user PRF, addition of dummy records and shuffling, @@ -139,7 +152,7 @@ where /// Propagates errors from multiplications pub async fn attribution_and_capping( sh_ctx: C, - input_rows: &[PrfShardedIpaInputRow], + input_rows: Vec>, num_breakdown_key_bits: usize, num_trigger_value_bits: usize, num_saturating_sum_bits: usize, @@ -154,66 +167,90 @@ where assert!(num_trigger_value_bits > 0); assert!(num_breakdown_key_bits > 0); + let rows_chunked_by_user = chunk_rows_by_user(input_rows); + let histogram = compute_histogram_of_users_with_row_count(&rows_chunked_by_user); + let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); // TODO: fix num total records to be not a hard-coded constant, but variable per step // based on the histogram of how many users have how many records a piece let binary_m_ctx = binary_validator.context(); - let histogram = compute_histogram_of_users_with_row_count(input_rows); let ctx_for_row_number = set_up_contexts(binary_m_ctx.clone(), histogram); - let mut output = vec![]; + let mut futures = Vec::with_capacity(rows_chunked_by_user.len()); + let mut num_users_who_encountered_row_depth = vec![]; + for rows_for_user in rows_chunked_by_user { + for i in 0..rows_for_user.len() { + if i >= num_users_who_encountered_row_depth.len() { + num_users_who_encountered_row_depth.push(0); + } + num_users_who_encountered_row_depth[i] += 1; + } + + futures.push(evaluate_per_user_attribution_circuit( + &ctx_for_row_number, + num_users_who_encountered_row_depth + .iter() + .take(rows_for_user.len()) + .map(|x| RecordId(x - 1)) + .collect(), + rows_for_user, + num_breakdown_key_bits, + num_trigger_value_bits, + num_saturating_sum_bits, + )); + } + let outputs_chunked_by_user = try_join_all(futures).await?; + Ok(outputs_chunked_by_user + .into_iter() + .flatten() + .collect::>()) +} - assert!(!input_rows.is_empty()); - let first_row = &input_rows[0]; - let mut prev_prf = first_row.prf_of_match_key; +async fn evaluate_per_user_attribution_circuit( + ctx_for_row_number: &[C], + record_id_for_each_depth: Vec, + rows_for_user: Vec>, + num_breakdown_key_bits: usize, + num_trigger_value_bits: usize, + num_saturating_sum_bits: usize, +) -> Result, Error> +where + C: UpgradedContext>, + BK: GaloisField, + TV: GaloisField, +{ + assert!(!rows_for_user.is_empty()); + if rows_for_user.len() == 1 { + return Ok(vec![]); + } + let first_row = &rows_for_user[0]; let mut prev_row_inputs = initialize_new_device_attribution_variables( - Replicated::share_known_value(&binary_m_ctx, Gf2::ONE), + Replicated::share_known_value(&ctx_for_row_number[0], Gf2::ONE), first_row, num_breakdown_key_bits, num_trigger_value_bits, num_saturating_sum_bits, ); - let mut i: usize = 1; - let mut num_users_who_encountered_row_depth = vec![]; - let mut row_for_user = 0; - while i < input_rows.len() { - let cur_row = &input_rows[i]; - if prev_prf == cur_row.prf_of_match_key { - if row_for_user >= num_users_who_encountered_row_depth.len() { - num_users_who_encountered_row_depth.push(0); - } - let ctx_for_this_row_depth = ctx_for_row_number[row_for_user].clone(); - // Do some actual computation - let (inputs_required_for_next_row, capped_attribution_outputs) = - compute_row_with_previous( - ctx_for_this_row_depth, - RecordId(num_users_who_encountered_row_depth[row_for_user]), - cur_row, - &prev_row_inputs, - num_breakdown_key_bits, - num_trigger_value_bits, - num_saturating_sum_bits, - ) - .await?; - output.push(capped_attribution_outputs); - prev_row_inputs = inputs_required_for_next_row; - num_users_who_encountered_row_depth[row_for_user] += 1; + let mut output = Vec::with_capacity(rows_for_user.len() - 1); + for (i, row) in rows_for_user.iter().skip(1).enumerate() { + let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0 + let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0 + + let (inputs_required_for_next_row, capped_attribution_outputs) = compute_row_with_previous( + ctx_for_this_row_depth, + record_id_for_this_row_depth, + row, + &prev_row_inputs, + num_breakdown_key_bits, + num_trigger_value_bits, + num_saturating_sum_bits, + ) + .await?; - row_for_user += 1; - } else { - prev_prf = cur_row.prf_of_match_key; - prev_row_inputs = initialize_new_device_attribution_variables( - Replicated::share_known_value(&binary_m_ctx, Gf2::ONE), - cur_row, - num_breakdown_key_bits, - num_trigger_value_bits, - num_saturating_sum_bits, - ); - row_for_user = 0; - } - i += 1; + output.push(capped_attribution_outputs); + prev_row_inputs = inputs_required_for_next_row; } Ok(output) @@ -686,6 +723,34 @@ pub mod tests { attributed_breakdown_key: 12, capped_attributed_trigger_value: 5, }, + PreAggregationTestOutput { + attributed_breakdown_key: 20, + capped_attributed_trigger_value: 7, + }, + PreAggregationTestOutput { + attributed_breakdown_key: 18, + capped_attributed_trigger_value: 0, + }, + PreAggregationTestOutput { + attributed_breakdown_key: 12, + capped_attributed_trigger_value: 0, + }, + PreAggregationTestOutput { + attributed_breakdown_key: 12, + capped_attributed_trigger_value: 7, + }, + PreAggregationTestOutput { + attributed_breakdown_key: 12, + capped_attributed_trigger_value: 7, + }, + PreAggregationTestOutput { + attributed_breakdown_key: 12, + capped_attributed_trigger_value: 7, + }, + PreAggregationTestOutput { + attributed_breakdown_key: 12, + capped_attributed_trigger_value: 4, + }, ]; const NUM_BREAKDOWN_KEY_BITS: usize = 5; const NUM_TRIGGER_VALUE_BITS: usize = 3; @@ -695,6 +760,7 @@ pub mod tests { let world = TestWorld::default(); let records: Vec> = vec![ + /* First User */ PreShardedAndSortedOPRFTestInput { prf_of_match_key: 123, is_trigger_bit: Gf2::ZERO, @@ -719,11 +785,12 @@ pub mod tests { breakdown_key: Gf8Bit::truncate_from(0_u8), trigger_value: Gf8Bit::truncate_from(3_u8), }, + /* Second User */ PreShardedAndSortedOPRFTestInput { prf_of_match_key: 234, is_trigger_bit: Gf2::ZERO, breakdown_key: Gf8Bit::truncate_from(12_u8), - trigger_value: Gf8Bit::truncate_from(3_u8), + trigger_value: Gf8Bit::truncate_from(0_u8), }, PreShardedAndSortedOPRFTestInput { prf_of_match_key: 234, @@ -731,13 +798,62 @@ pub mod tests { breakdown_key: Gf8Bit::truncate_from(0_u8), trigger_value: Gf8Bit::truncate_from(5_u8), }, + /* Third User */ + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ZERO, + breakdown_key: Gf8Bit::truncate_from(20_u8), + trigger_value: Gf8Bit::truncate_from(0_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ONE, + breakdown_key: Gf8Bit::truncate_from(0_u8), + trigger_value: Gf8Bit::truncate_from(7_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ZERO, + breakdown_key: Gf8Bit::truncate_from(18_u8), + trigger_value: Gf8Bit::truncate_from(0_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ZERO, + breakdown_key: Gf8Bit::truncate_from(12_u8), + trigger_value: Gf8Bit::truncate_from(0_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ONE, + breakdown_key: Gf8Bit::truncate_from(0_u8), + trigger_value: Gf8Bit::truncate_from(7_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ONE, + breakdown_key: Gf8Bit::truncate_from(0_u8), + trigger_value: Gf8Bit::truncate_from(7_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ONE, + breakdown_key: Gf8Bit::truncate_from(0_u8), + trigger_value: Gf8Bit::truncate_from(7_u8), + }, + PreShardedAndSortedOPRFTestInput { + prf_of_match_key: 345, + is_trigger_bit: Gf2::ONE, + breakdown_key: Gf8Bit::truncate_from(0_u8), + trigger_value: Gf8Bit::truncate_from(7_u8), + }, ]; let result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { attribution_and_capping( ctx, - input_rows.as_slice(), + input_rows, NUM_BREAKDOWN_KEY_BITS, NUM_TRIGGER_VALUE_BITS, NUM_SATURATING_SUM_BITS,