From 4f1f605b539fe8752c9332dd3653ba2f9b660998 Mon Sep 17 00:00:00 2001
From: Ben Savage
Date: Wed, 4 Oct 2023 00:48:47 +0800
Subject: [PATCH 1/9] Feature vector label vector dot product

---
 .../prf_sharding/feature_label_dot_product.rs | 469 ++++++++++++++++++
 src/protocol/prf_sharding/mod.rs              |   3 +
 2 files changed, 472 insertions(+)
 create mode 100644 src/protocol/prf_sharding/feature_label_dot_product.rs

diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs
new file mode 100644
index 000000000..bdea62dad
--- /dev/null
+++ b/src/protocol/prf_sharding/feature_label_dot_product.rs
@@ -0,0 +1,469 @@
+use futures::{stream::iter as stream_iter, TryStreamExt};
+use futures_util::future::try_join;
+use ipa_macros::Step;
+
+use crate::{
+    error::Error,
+    ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
+    protocol::{
+        basics::{SecureMul, ShareKnownValue},
+        boolean::or::or,
+        context::{Context, UpgradableContext, UpgradedContext, Validator},
+        modulus_conversion::convert_bits,
+        step::BitOpStep,
+        RecordId,
+    },
+    secret_sharing::{
+        replicated::{
+            malicious::ExtendableField, semi_honest::AdditiveShare as Replicated,
+            ReplicatedSecretSharing,
+        },
+        BitDecomposed, Linear as LinearSecretSharing,
+    },
+    seq_join::seq_try_join_all,
+};
+
+pub struct PrfShardedIpaInputRow<FV: GaloisField> {
+    prf_of_match_key: u64,
+    is_trigger_bit: Replicated<Gf2>,
+    feature_vector: Replicated<FV>,
+}
+
+struct InputsRequiredFromPrevRow {
+    ever_encountered_a_trigger_event: Replicated<Gf2>,
+    is_saturated: Replicated<Gf2>,
+}
+
+impl InputsRequiredFromPrevRow {
+    ///
+    /// This function contains the main logic for the per-user attribution circuit.
+    /// Multiple rows of data about a single user are processed in order, from newest to oldest.
+    ///
+    /// Summary:
+    /// - Last-touch attribution
+    ///     - Every source event which has a subsequent trigger event receives attribution
+    /// - Per-user capping
+    ///     - A cumulative count of "source events receiving attribution" is maintained
+    ///     - Bitwise addition is used, and a single bit indicates if the sum is "saturated"
+    ///     - The only available values for the "cap" are powers of 2 (i.e. 1, 2, 4, 8, 16, 32, ...)
+    ///     - Prior to saturation, feature vectors of source events receiving attribution contribute to the dot product
+    ///     - All subsequent rows contribute zero
+    /// - Outputs
+    ///     - If a user has `N` input rows, they will generate `N-1` output rows. (The first row cannot possibly contribute any value to the output.)
+    ///     - Each output row is a vector: either the feature vector or zeroes.
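+    ///
+    /// A plaintext sketch of one row transition, for intuition only. The `next_row`
+    /// helper below is hypothetical (not part of this protocol); the real circuit in
+    /// this function computes the same logic over `Gf2` secret shares using `or` and
+    /// `multiply`:
+    ///
+    /// ```ignore
+    /// // State carried between rows: (ever_encountered_a_trigger_event, is_saturated).
+    /// // Note that this version keeps a single saturation bit, i.e. a cap of 1.
+    /// fn next_row(state: (bool, bool), is_trigger: bool, fv: u32) -> ((bool, bool), u32) {
+    ///     let (ever_trigger, saturated) = state;
+    ///     let attributed = !is_trigger && ever_trigger; // source event with a later trigger
+    ///     let capped = attributed && !saturated;        // contributes only before saturation
+    ///     let next_state = (ever_trigger || is_trigger, saturated || attributed);
+    ///     (next_state, if capped { fv } else { 0 })
+    /// }
+    /// ```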
+    pub async fn compute_row_with_previous<C, FV>(
+        &mut self,
+        ctx: C,
+        record_id: RecordId,
+        input_row: &PrfShardedIpaInputRow<FV>,
+    ) -> Result<BitDecomposed<Replicated<Gf2>>, Error>
+    where
+        C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
+        FV: GaloisField,
+    {
+        let share_of_one = Replicated::share_known_value(&ctx, Gf2::ONE);
+        let is_source_event = &share_of_one - &input_row.is_trigger_bit;
+
+        let (ever_encountered_a_trigger_event, did_source_get_attributed) = try_join(
+            or(
+                ctx.narrow(&Step::EverEncounteredTriggerEvent),
+                record_id,
+                &input_row.is_trigger_bit,
+                &self.ever_encountered_a_trigger_event,
+            ),
+            is_source_event.multiply(
+                &self.ever_encountered_a_trigger_event,
+                ctx.narrow(&Step::DidSourceReceiveAttribution),
+                record_id,
+            ),
+        )
+        .await?;
+
+        let (updated_is_saturated, capped_label) = try_join(
+            or(
+                ctx.narrow(&Step::ComputeSaturatingSum),
+                record_id,
+                &self.is_saturated,
+                &did_source_get_attributed,
+            ),
+            did_source_get_attributed.multiply(
+                &(share_of_one - &self.is_saturated),
+                ctx.narrow(&Step::IsAttributedSourceAndPrevRowNotSaturated),
+                record_id,
+            ),
+        )
+        .await?;
+
+        let unbitpacked_feature_vector = BitDecomposed::decompose(FV::BITS, |i| {
+            input_row.feature_vector.map(|v| Gf2::truncate_from(v[i]))
+        });
+
+        let capped_attributed_feature_vector = compute_capped_feature_vector(
+            ctx.narrow(&Step::ComputedCappedFeatureVector),
+            record_id,
+            &capped_label,
+            &unbitpacked_feature_vector,
+        )
+        .await?;
+
+        self.ever_encountered_a_trigger_event = ever_encountered_a_trigger_event;
+        self.is_saturated = updated_is_saturated;
+
+        Ok(capped_attributed_feature_vector)
+    }
+}
+
+#[derive(Step)]
+pub enum UserNthRowStep {
+    #[dynamic]
+    Row(usize),
+}
+
+impl From<usize> for UserNthRowStep {
+    fn from(v: usize) -> Self {
+        Self::Row(v)
+    }
+}
+
+#[derive(Step)]
+pub(crate) enum Step {
+    BinaryValidator,
+    PrimeFieldValidator,
+    EverEncounteredTriggerEvent,
+    DidSourceReceiveAttribution,
+    ComputeSaturatingSum,
+    IsAttributedSourceAndPrevRowNotSaturated,
+    ComputedCappedFeatureVector,
+    ModulusConvertFeatureVectorBits,
+}
+
+fn compute_histogram_of_users_with_row_count<T>(rows_chunked_by_user: &[Vec<T>]) -> Vec<usize> {
+    let mut output = vec![];
+    for user_rows in rows_chunked_by_user {
+        for j in 0..user_rows.len() {
+            if j >= output.len() {
+                output.push(0);
+            }
+            output[j] += 1;
+        }
+    }
+    output
+}
+
+fn set_up_contexts<C>(root_ctx: &C, histogram: &[usize]) -> Vec<C>
+where
+    C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
+{
+    let mut context_per_row_depth = Vec::with_capacity(histogram.len());
+    for (row_number, num_users_having_that_row_number) in histogram.iter().enumerate() {
+        if row_number == 0 {
+            // no multiplications needed for each user's row 0; no context needed
+        } else {
+            let ctx_for_row_number = root_ctx
+                .narrow(&UserNthRowStep::from(row_number))
+                .set_total_records(*num_users_having_that_row_number);
+            context_per_row_depth.push(ctx_for_row_number);
+        }
+    }
+    context_per_row_depth
+}
+
+fn chunk_rows_by_user<FV>(
+    input_rows: Vec<PrfShardedIpaInputRow<FV>>,
+) -> Vec<Vec<PrfShardedIpaInputRow<FV>>>
+where
+    FV: GaloisField,
+{
+    let mut rows_for_user: Vec<PrfShardedIpaInputRow<FV>> = vec![];
+
+    let mut rows_chunked_by_user = vec![];
+    for row in input_rows {
+        if rows_for_user.is_empty() || row.prf_of_match_key == rows_for_user[0].prf_of_match_key {
+            rows_for_user.push(row);
+        } else {
+            rows_chunked_by_user.push(rows_for_user);
+            rows_for_user = vec![row];
+        }
+    }
+    if !rows_for_user.is_empty() {
+        rows_chunked_by_user.push(rows_for_user);
+    }
+
+    rows_chunked_by_user
+}
+
+/// Sub-protocol of the PRF-sharded IPA Protocol
+///
+/// After the computation of the per-user PRF, addition of dummy records and shuffling,
+/// the PRF column can be revealed. After that, all of the records corresponding to a single
+/// device can be processed together.
+///
+/// This circuit expects to receive records from multiple users,
+/// but with all of the records from a given user adjacent to one another, and in time order.
+///
+/// This circuit will compute attribution, and per-user capping.
+///
+/// The output of this circuit is the input to the next stage: Aggregation.
+///
+/// # Errors
+/// Propagates errors from multiplications
+/// # Panics
+/// Propagates errors from multiplications
+pub async fn compute_feature_label_dot_product<C, FV, F, S>(
+    sh_ctx: C,
+    input_rows: Vec<PrfShardedIpaInputRow<FV>>,
+) -> Result<Vec<S>, Error>
+where
+    C: UpgradableContext,
+    C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
+    C::UpgradedContext<F>: UpgradedContext<F, Share = S>,
+    S: LinearSecretSharing<F> + Serializable + SecureMul<C::UpgradedContext<F>>,
+    FV: GaloisField,
+    F: PrimeField + ExtendableField,
+{
+    assert!(FV::BITS > 0);
+
+    let rows_chunked_by_user = chunk_rows_by_user(input_rows);
+    let histogram = compute_histogram_of_users_with_row_count(&rows_chunked_by_user);
+    let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
+    let binary_m_ctx = binary_validator.context();
+    let mut num_users_who_encountered_row_depth = Vec::with_capacity(histogram.len());
+    let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram);
+    let mut futures = Vec::with_capacity(rows_chunked_by_user.len());
+    for rows_for_user in rows_chunked_by_user {
+        for i in 0..rows_for_user.len() {
+            if i >= num_users_who_encountered_row_depth.len() {
+                num_users_who_encountered_row_depth.push(0);
+            }
+            num_users_who_encountered_row_depth[i] += 1;
+        }
+
+        futures.push(evaluate_per_user_attribution_circuit(
+            &ctx_for_row_number,
+            num_users_who_encountered_row_depth
+                .iter()
+                .take(rows_for_user.len())
+                .map(|x| RecordId(x - 1))
+                .collect(),
+            rows_for_user,
+        ));
+    }
+    let outputs_chunked_by_user = seq_try_join_all(sh_ctx.active_work(), futures)
+        .await?
+        .into_iter()
+        .flatten()
+        .collect::<Vec<BitDecomposed<Replicated<Gf2>>>>();
+
+    let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::<F>();
+    let prime_field_ctx = prime_field_validator.context();
+
+    // modulus convert feature vector bits
+    let converted_feature_vector_bits = convert_bits(
+        prime_field_ctx
+            .narrow(&Step::ModulusConvertFeatureVectorBits)
+            .set_total_records(outputs_chunked_by_user.len()),
+        stream_iter(outputs_chunked_by_user),
+        0..FV::BITS,
+    );
+
+    converted_feature_vector_bits
+        .try_fold(
+            vec![S::ZERO; 1 << FV::BITS],
+            |mut running_sums, row_contribution| async move {
+                for (i, contribution) in row_contribution.iter().enumerate() {
+                    running_sums[i] += contribution;
+                }
+                Ok(running_sums)
+            },
+        )
+        .await
+}
+
+async fn evaluate_per_user_attribution_circuit<C, FV>(
+    ctx_for_row_number: &[C],
+    record_id_for_each_depth: Vec<RecordId>,
+    rows_for_user: Vec<PrfShardedIpaInputRow<FV>>,
+) -> Result<Vec<BitDecomposed<Replicated<Gf2>>>, Error>
+where
+    C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
+    FV: GaloisField,
+{
+    assert!(!rows_for_user.is_empty());
+    if rows_for_user.len() == 1 {
+        return Ok(Vec::new());
+    }
+    let first_row = &rows_for_user[0];
+    let mut prev_row_inputs = initialize_new_device_attribution_variables(first_row);
+
+    let mut output = Vec::with_capacity(rows_for_user.len() - 1);
+    for (i, row) in rows_for_user.iter().skip(1).enumerate() {
+        let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0
+        let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0
+
+        let capped_attribution_outputs = prev_row_inputs
+            .compute_row_with_previous(ctx_for_this_row_depth, record_id_for_this_row_depth, row)
+            .await?;
+
+        output.push(capped_attribution_outputs);
+    }
+
+    Ok(output)
+}
+
+///
+/// Upon encountering the first row of data from a new user (as distinguished by a different OPRF of the match key),
+/// this function encapsulates the variables that must be initialized. No communication is required for this first row.
+///
+fn initialize_new_device_attribution_variables<FV>(
+    input_row: &PrfShardedIpaInputRow<FV>,
+) -> InputsRequiredFromPrevRow
+where
+    FV: GaloisField,
+{
+    InputsRequiredFromPrevRow {
+        ever_encountered_a_trigger_event: input_row.is_trigger_bit.clone(),
+        is_saturated: Replicated::ZERO,
+    }
+}
+
+async fn compute_capped_feature_vector<C>(
+    ctx: C,
+    record_id: RecordId,
+    capped_label: &Replicated<Gf2>,
+    feature_vector: &BitDecomposed<Replicated<Gf2>>,
+) -> Result<BitDecomposed<Replicated<Gf2>>, Error>
+where
+    C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
+{
+    Ok(BitDecomposed::new(
+        ctx.parallel_join(feature_vector.iter().enumerate().map(|(i, bit)| {
+            let c1 = ctx.narrow(&BitOpStep::from(i));
+            async move { capped_label.multiply(bit, c1, record_id).await }
+        }))
+        .await?,
+    ))
+}
+
+#[cfg(all(test, unit_test))]
+pub mod tests {
+    use crate::{
+        ff::{Field, Fp32BitPrime, GaloisField, Gf2, Gf32Bit},
+        protocol::prf_sharding::feature_label_dot_product::{
+            compute_feature_label_dot_product, PrfShardedIpaInputRow,
+        },
+        rand::Rng,
+        secret_sharing::{
+            replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue,
+        },
+        test_executor::run,
+        test_fixture::{Reconstruct, Runner, TestWorld},
+    };
+
+    struct PreShardedAndSortedOPRFTestInput<FV: GaloisField> {
+        prf_of_match_key: u64,
+        is_trigger_bit: Gf2,
+        feature_vector: FV,
+    }
+
+    fn test_input(
+        prf_of_match_key: u64,
+        is_trigger: bool,
+        feature_vector: u32,
+    ) -> PreShardedAndSortedOPRFTestInput<Gf32Bit> {
+        let is_trigger_bit = if is_trigger { Gf2::ONE } else { Gf2::ZERO };
+
+        PreShardedAndSortedOPRFTestInput {
+            prf_of_match_key,
+            is_trigger_bit,
+            feature_vector: Gf32Bit::truncate_from(feature_vector),
+        }
+    }
+
+    impl<FV> IntoShares<PrfShardedIpaInputRow<FV>> for PreShardedAndSortedOPRFTestInput<FV>
+    where
+        FV: GaloisField + IntoShares<Replicated<FV>>,
+    {
+        fn share_with<R: Rng>(self, rng: &mut R) -> [PrfShardedIpaInputRow<FV>; 3] {
+            let PreShardedAndSortedOPRFTestInput {
+                prf_of_match_key,
+                is_trigger_bit,
+                feature_vector,
+            } = self;
+
+            let [is_trigger_bit0, is_trigger_bit1, is_trigger_bit2] =
+                is_trigger_bit.share_with(rng);
+            let [feature_vector0, feature_vector1, feature_vector2] =
+                feature_vector.share_with(rng);
+
+            [
+                PrfShardedIpaInputRow {
+                    prf_of_match_key,
+                    is_trigger_bit: is_trigger_bit0,
+                    feature_vector: feature_vector0,
+                },
+                PrfShardedIpaInputRow {
+                    prf_of_match_key,
+                    is_trigger_bit: is_trigger_bit1,
+                    feature_vector: feature_vector1,
+                },
+                PrfShardedIpaInputRow {
+                    prf_of_match_key,
+                    is_trigger_bit: is_trigger_bit2,
+                    feature_vector: feature_vector2,
+                },
+            ]
+        }
+    }
+
+    #[test]
+    fn semi_honest() {
+        run(|| async move {
+            let world = TestWorld::default();
+
+            let records: Vec<PreShardedAndSortedOPRFTestInput<Gf32Bit>> = vec![
+                /* First User */
+                test_input(123, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
+                test_input(123, false, 0b1101_0100_1111_0001_0111_0010_1010_1011), // this source DOES receive attribution
+                test_input(123, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
+                test_input(123, false, 0b0110_1101_0001_0100_1011_0100_1010_1001), // this source does not receive attribution (capped)
+                /* Second User */
+                test_input(234, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
+                test_input(234, false, 0b0001_1010_0011_0111_0110_0010_1111_0000), // this source DOES receive attribution
+                /* Third User */
+                test_input(345, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
+                test_input(345, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
+                test_input(345, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
+                test_input(345, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger
test_input(345, false, 0b0111_0101_0001_0000_0111_0100_0101_0011), // this source DOES receive attribution + test_input(345, false, 0b1001_1000_1011_1101_0100_0110_0001_0100), // this source does not receive attribution (capped) + test_input(345, true, 0b0000_0000_0000_0000_0000_0000_0000_0000), // trigger + test_input(345, false, 0b1000_1001_0100_0011_0111_0010_0000_1101), // this source does not receive attribution (capped) + ]; + + let expected: [u128; 32] = [ + // 1101_0100_1111_0001_0111_0010_1010_1011 + // 0001_1010_0011_0111_0110_0010_1111_0000 + // + 0111_0101_0001_0000_0111_0100_0101_0011 + // ------------------------------------------- + // 1213_1211_1123_0112_0332_0120_2222_1022 + 1, 2, 1, 3, 1, 2, 1, 1, 1, 1, 2, 3, 0, 1, 1, 2, 0, 3, 3, 2, 0, 1, 2, 0, 2, 2, 2, 2, + 1, 0, 2, 2, + ]; + + let result: Vec<_> = world + .semi_honest(records.into_iter(), |ctx, input_rows| async move { + compute_feature_label_dot_product::< + _, + Gf32Bit, + Fp32BitPrime, + Replicated, + >(ctx, input_rows) + .await + .unwrap() + }) + .await + .reconstruct(); + assert_eq!(result, &expected); + }); + } +} diff --git a/src/protocol/prf_sharding/mod.rs b/src/protocol/prf_sharding/mod.rs index 0e0b55304..fdd3a6440 100644 --- a/src/protocol/prf_sharding/mod.rs +++ b/src/protocol/prf_sharding/mod.rs @@ -27,6 +27,9 @@ use crate::{ }, seq_join::{seq_join, seq_try_join_all}, }; + +pub mod feature_label_dot_product; + pub struct PrfShardedIpaInputRow { prf_of_match_key: u64, is_trigger_bit: Replicated, From 094eb28346c1e5fc6d22fe09bca59f4cc6111679 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Wed, 4 Oct 2023 21:40:17 +0800 Subject: [PATCH 2/9] It doesn't really compile yet --- Cargo.toml | 1 + .../prf_sharding/feature_label_dot_product.rs | 43 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e264d503f..a72bfcfbf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -95,6 +95,7 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } typenum = "1.16" # hpke is pinned to it x25519-dalek = "2.0.0-pre.0" +stream-flatten-iters = "0.2.0" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index bdea62dad..66dca764e 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -1,7 +1,9 @@ use futures::{stream::iter as stream_iter, TryStreamExt}; -use futures_util::future::try_join; +use futures_util::{future::try_join, StreamExt}; use ipa_macros::Step; +use stream_flatten_iters::StreamExt as _; + use crate::{ error::Error, ff::{Field, GaloisField, Gf2, PrimeField, Serializable}, @@ -20,7 +22,7 @@ use crate::{ }, BitDecomposed, Linear as LinearSecretSharing, }, - seq_join::seq_try_join_all, + seq_join::seq_join, }; pub struct PrfShardedIpaInputRow { @@ -223,36 +225,34 @@ where { assert!(FV::BITS > 0); + let mut num_outputs = input_rows.len(); let rows_chunked_by_user = chunk_rows_by_user(input_rows); + num_outputs -= rows_chunked_by_user.len(); let histogram = compute_histogram_of_users_with_row_count(&rows_chunked_by_user); let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); let binary_m_ctx = binary_validator.context(); - let mut num_users_who_encountered_row_depth = Vec::with_capacity(histogram.len()); + let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()]; let ctx_for_row_number 
= set_up_contexts(&binary_m_ctx, &histogram); let mut futures = Vec::with_capacity(rows_chunked_by_user.len()); for rows_for_user in rows_chunked_by_user { - for i in 0..rows_for_user.len() { - if i >= num_users_who_encountered_row_depth.len() { - num_users_who_encountered_row_depth.push(0); - } - num_users_who_encountered_row_depth[i] += 1; - } - + let num_user_rows = rows_for_user.len(); futures.push(evaluate_per_user_attribution_circuit( &ctx_for_row_number, num_users_who_encountered_row_depth .iter() - .take(rows_for_user.len()) - .map(|x| RecordId(x - 1)) + .take(num_user_rows) + .map(|x| RecordId(*x)) .collect(), rows_for_user, )); + for i in 0..num_user_rows { + num_users_who_encountered_row_depth[i] += 1; + } } - let outputs_chunked_by_user = seq_try_join_all(sh_ctx.active_work(), futures) - .await? - .into_iter() - .flatten() - .collect::>>>(); + + let flattenned_stream = seq_join(sh_ctx.active_work(), stream_iter(futures)) + .map(|x| x.unwrap().into_iter()) + .flatten_iters(); let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::(); let prime_field_ctx = prime_field_validator.context(); @@ -261,14 +261,14 @@ where let converted_feature_vector_bits = convert_bits( prime_field_ctx .narrow(&Step::ModulusConvertFeatureVectorBits) - .set_total_records(outputs_chunked_by_user.len()), - stream_iter(outputs_chunked_by_user), + .set_total_records(num_outputs), + flattenned_stream, 0..FV::BITS, ); converted_feature_vector_bits .try_fold( - vec![S::ZERO; 1 << FV::BITS], + vec![S::ZERO; usize::try_from(FV::BITS).unwrap()], |mut running_sums, row_contribution| async move { for (i, contribution) in row_contribution.iter().enumerate() { running_sums[i] += contribution; @@ -440,7 +440,7 @@ pub mod tests { test_input(345, false, 0b1000_1001_0100_0011_0111_0010_0000_1101), // this source does not receive attribution (capped) ]; - let expected: [u128; 32] = [ + let mut expected: [u128; 32] = [ // 1101_0100_1111_0001_0111_0010_1010_1011 // 0001_1010_0011_0111_0110_0010_1111_0000 // + 0111_0101_0001_0000_0111_0100_0101_0011 @@ -449,6 +449,7 @@ pub mod tests { 1, 2, 1, 3, 1, 2, 1, 1, 1, 1, 2, 3, 0, 1, 1, 2, 0, 3, 3, 2, 0, 1, 2, 0, 2, 2, 2, 2, 1, 0, 2, 2, ]; + expected.reverse(); // convert to little-endian order let result: Vec<_> = world .semi_honest(records.into_iter(), |ctx, input_rows| async move { From 0106fb3526a2c3aaa8a8465b55272968a1d47020 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 5 Oct 2023 08:26:08 +0800 Subject: [PATCH 3/9] Now it compiles. 
thank you martin --- .../prf_sharding/feature_label_dot_product.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index 66dca764e..3fb35835d 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -2,6 +2,8 @@ use futures::{stream::iter as stream_iter, TryStreamExt}; use futures_util::{future::try_join, StreamExt}; use ipa_macros::Step; +use std::iter::zip; + use stream_flatten_iters::StreamExt as _; use crate::{ @@ -237,7 +239,7 @@ where for rows_for_user in rows_chunked_by_user { let num_user_rows = rows_for_user.len(); futures.push(evaluate_per_user_attribution_circuit( - &ctx_for_row_number, + ctx_for_row_number[..rows_for_user.len() - 1].to_owned(), num_users_who_encountered_row_depth .iter() .take(num_user_rows) @@ -280,7 +282,7 @@ where } async fn evaluate_per_user_attribution_circuit( - ctx_for_row_number: &[C], + ctx_for_row_number: Vec, record_id_for_each_depth: Vec, rows_for_user: Vec>, ) -> Result>>, Error> @@ -296,12 +298,15 @@ where let mut prev_row_inputs = initialize_new_device_attribution_variables(first_row); let mut output = Vec::with_capacity(rows_for_user.len() - 1); - for (i, row) in rows_for_user.iter().skip(1).enumerate() { - let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0 + // skip the first row as it requires no multiplications + // no context was created for the first row + for (i, (row, ctx)) in + zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate() + { let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0 let capped_attribution_outputs = prev_row_inputs - .compute_row_with_previous(ctx_for_this_row_depth, record_id_for_this_row_depth, row) + .compute_row_with_previous(ctx, record_id_for_this_row_depth, row) .await?; output.push(capped_attribution_outputs); From 346240966a00f57016bbe974b5b0addc559f3466 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Fri, 6 Oct 2023 10:28:18 +0800 Subject: [PATCH 4/9] issue with unpin --- .../prf_sharding/feature_label_dot_product.rs | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index 3fb35835d..32615d1cc 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -1,5 +1,5 @@ use futures::{stream::iter as stream_iter, TryStreamExt}; -use futures_util::{future::try_join, StreamExt}; +use futures_util::{future::try_join, stream, StreamExt}; use ipa_macros::Step; use std::iter::zip; @@ -235,24 +235,36 @@ where let binary_m_ctx = binary_validator.context(); let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()]; let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram); - let mut futures = Vec::with_capacity(rows_chunked_by_user.len()); - for rows_for_user in rows_chunked_by_user { - let num_user_rows = rows_for_user.len(); - futures.push(evaluate_per_user_attribution_circuit( - ctx_for_row_number[..rows_for_user.len() - 1].to_owned(), - num_users_who_encountered_row_depth - .iter() - .take(num_user_rows) - .map(|x| RecordId(*x)) - .collect(), - rows_for_user, - )); - for i in 0..num_user_rows { - num_users_who_encountered_row_depth[i] += 1; - } - } + let 
stream_of_per_user_circuits = stream::unfold( + ( + num_users_who_encountered_row_depth, + ctx_for_row_number, + stream_iter(rows_chunked_by_user), + ), + |state| async move { + let (mut count_by_row_depth, contexts, s) = state; + if let Some(rows_for_user) = s.next().await { + let num_user_rows = rows_for_user.len(); + let yielded = evaluate_per_user_attribution_circuit( + contexts[..num_user_rows - 1].to_owned(), + num_users_who_encountered_row_depth + .iter() + .take(num_user_rows) + .map(|x| RecordId(*x)) + .collect(), + rows_for_user, + ); + for i in 0..num_user_rows { + count_by_row_depth[i] += 1; + } + Some((yielded, (count_by_row_depth, contexts, s))) + } else { + None + } + }, + ); - let flattenned_stream = seq_join(sh_ctx.active_work(), stream_iter(futures)) + let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) .map(|x| x.unwrap().into_iter()) .flatten_iters(); From c3c6d96c4d72074a205baf739ab022c7e8ded463 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 9 Oct 2023 14:30:17 +0800 Subject: [PATCH 5/9] It finally works now --- .../prf_sharding/feature_label_dot_product.rs | 133 +++++++----------- 1 file changed, 50 insertions(+), 83 deletions(-) diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index 32615d1cc..4f313e8e9 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -1,8 +1,8 @@ use futures::{stream::iter as stream_iter, TryStreamExt}; -use futures_util::{future::try_join, stream, StreamExt}; +use futures_util::{future::try_join, stream::unfold, StreamExt}; use ipa_macros::Step; -use std::iter::zip; +use std::{iter::zip, pin::pin}; use stream_flatten_iters::StreamExt as _; @@ -141,19 +141,6 @@ pub(crate) enum Step { ModulusConvertFeatureVectorBits, } -fn compute_histogram_of_users_with_row_count(rows_chunked_by_user: &[Vec]) -> Vec { - let mut output = vec![]; - for user_rows in rows_chunked_by_user { - for j in 0..user_rows.len() { - if j >= output.len() { - output.push(0); - } - output[j] += 1; - } - } - output -} - fn set_up_contexts(root_ctx: &C, histogram: &[usize]) -> Vec where C: UpgradedContext>, @@ -172,30 +159,6 @@ where context_per_row_depth } -fn chunk_rows_by_user( - input_rows: Vec>, -) -> Vec>> -where - FV: GaloisField, -{ - let mut rows_for_user: Vec> = vec![]; - - let mut rows_chunked_by_user = vec![]; - for row in input_rows { - if rows_for_user.is_empty() || row.prf_of_match_key == rows_for_user[0].prf_of_match_key { - rows_for_user.push(row); - } else { - rows_chunked_by_user.push(rows_for_user); - rows_for_user = vec![row]; - } - } - if !rows_for_user.is_empty() { - rows_chunked_by_user.push(rows_for_user); - } - - rows_chunked_by_user -} - /// Sub-protocol of the PRF-sharded IPA Protocol /// /// After the computation of the per-user PRF, addition of dummy records and shuffling, @@ -216,6 +179,7 @@ where pub async fn compute_feature_label_dot_product( sh_ctx: C, input_rows: Vec>, + histogram: &[usize], ) -> Result, Error> where C: UpgradableContext, @@ -227,42 +191,40 @@ where { assert!(FV::BITS > 0); - let mut num_outputs = input_rows.len(); - let rows_chunked_by_user = chunk_rows_by_user(input_rows); - num_outputs -= rows_chunked_by_user.len(); - let histogram = compute_histogram_of_users_with_row_count(&rows_chunked_by_user); + let num_outputs = input_rows.len() - histogram[0]; + let mut input_stream = stream_iter(input_rows); + let first_row = 
input_stream.next().await.unwrap(); + let rows_chunked_by_user = unfold(Some((input_stream, first_row)), |state| async move { + if state.is_none() { + return None; + } + let (mut s, last_row) = state.unwrap(); + let last_row_prf = last_row.prf_of_match_key; + let mut current_chunk = vec![last_row]; + while let Some(row) = s.next().await { + if row.prf_of_match_key == last_row_prf { + current_chunk.push(row); + } else { + return Some((current_chunk, Some((s, row)))); + } + } + Some((current_chunk, None)) + }); + let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); let binary_m_ctx = binary_validator.context(); let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()]; let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram); - let stream_of_per_user_circuits = stream::unfold( - ( - num_users_who_encountered_row_depth, - ctx_for_row_number, - stream_iter(rows_chunked_by_user), - ), - |state| async move { - let (mut count_by_row_depth, contexts, s) = state; - if let Some(rows_for_user) = s.next().await { - let num_user_rows = rows_for_user.len(); - let yielded = evaluate_per_user_attribution_circuit( - contexts[..num_user_rows - 1].to_owned(), - num_users_who_encountered_row_depth - .iter() - .take(num_user_rows) - .map(|x| RecordId(*x)) - .collect(), - rows_for_user, - ); - for i in 0..num_user_rows { - count_by_row_depth[i] += 1; - } - Some((yielded, (count_by_row_depth, contexts, s))) - } else { - None - } - }, - ); + let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| { + let num_user_rows = rows_for_user.len(); + let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); + let record_ids = num_users_who_encountered_row_depth[..num_user_rows].to_owned(); + + for i in 0..rows_for_user.len() { + num_users_who_encountered_row_depth[i] += 1; + } + async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) } + })); let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) .map(|x| x.unwrap().into_iter()) @@ -295,7 +257,7 @@ where async fn evaluate_per_user_attribution_circuit( ctx_for_row_number: Vec, - record_id_for_each_depth: Vec, + record_id_for_each_depth: Vec, rows_for_user: Vec>, ) -> Result>>, Error> where @@ -315,7 +277,7 @@ where for (i, (row, ctx)) in zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate() { - let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0 + let record_id_for_this_row_depth = RecordId(record_id_for_each_depth[i + 1]); // skip row 0 let capped_attribution_outputs = prev_row_inputs .compute_row_with_previous(ctx, record_id_for_this_row_depth, row) @@ -468,16 +430,21 @@ pub mod tests { ]; expected.reverse(); // convert to little-endian order - let result: Vec<_> = world - .semi_honest(records.into_iter(), |ctx, input_rows| async move { - compute_feature_label_dot_product::< - _, - Gf32Bit, - Fp32BitPrime, - Replicated, - >(ctx, input_rows) - .await - .unwrap() + let histogram = vec![3, 3, 2, 2, 1, 1, 1, 1]; + + let result: Vec = world + .semi_honest(records.into_iter(), |ctx, input_rows| { + let h = histogram.as_slice(); + async move { + compute_feature_label_dot_product::< + _, + Gf32Bit, + Fp32BitPrime, + Replicated, + >(ctx, input_rows, h) + .await + .unwrap() + } }) .await .reconstruct(); From 9f6b0946b93d2782126680a5c88de96d8ef4a317 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 9 Oct 2023 15:16:30 +0800 Subject: [PATCH 6/9] Cleaned up and converted everything 
to streams

---
 .../prf_sharding/feature_label_dot_product.rs | 101 ++++++++++++------
 1 file changed, 71 insertions(+), 30 deletions(-)

diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs
index 4f313e8e9..d35cc6951 100644
--- a/src/protocol/prf_sharding/feature_label_dot_product.rs
+++ b/src/protocol/prf_sharding/feature_label_dot_product.rs
@@ -1,5 +1,5 @@
 use futures::{stream::iter as stream_iter, TryStreamExt};
-use futures_util::{future::try_join, stream::unfold, StreamExt};
+use futures_util::{future::try_join, stream::unfold, Stream, StreamExt};
 use ipa_macros::Step;
@@ -159,6 +159,34 @@ where
     context_per_row_depth
 }
 
+///
+/// Takes an input stream of `PrfShardedIpaInputRow` which is assumed to have all records with a given PRF adjacent,
+/// and converts it into a stream of vectors of `PrfShardedIpaInputRow` having the same PRF.
+///
+fn chunk_rows_by_user<FV, IS>(
+    input_stream: IS,
+    first_row: PrfShardedIpaInputRow<FV>,
+) -> impl Stream<Item = Vec<PrfShardedIpaInputRow<FV>>>
+where
+    FV: GaloisField,
+    IS: Stream<Item = PrfShardedIpaInputRow<FV>> + Unpin,
+{
+    unfold(Some((input_stream, first_row)), |state| async move {
+        state.as_ref()?;
+        let (mut s, last_row) = state.unwrap();
+        let last_row_prf = last_row.prf_of_match_key;
+        let mut current_chunk = vec![last_row];
+        while let Some(row) = s.next().await {
+            if row.prf_of_match_key == last_row_prf {
+                current_chunk.push(row);
+            } else {
+                return Some((current_chunk, Some((s, row))));
+            }
+        }
+        Some((current_chunk, None))
+    })
+}
+
 /// Sub-protocol of the PRF-sharded IPA Protocol
 ///
 /// After the computation of the per-user PRF, addition of dummy records and shuffling,
 /// the PRF column can be revealed. After that, all of the records corresponding to a single
 /// device can be processed together.
 ///
 /// This circuit expects to receive records from multiple users,
-/// but with all of the records from a given user adjacent to one another, and in time order.
+/// but with all of the records from a given user adjacent to one another, and in reverse time order (most recent event comes first).
 ///
 /// This circuit will compute attribution, and per-user capping.
 ///
-/// The output of this circuit is the input to the next stage: Aggregation.
+/// After those steps, source events to which trigger events were attributed will contribute their feature vectors to an aggregate.
+///
+/// The aggregate is just the sum of all the feature vectors of source events which received attribution.
+///
+/// This is useful for performing logistic regression: `https://github.com/patcg-individual-drafts/ipa/blob/main/logistic_regression.md`
+///
+/// Due to limitations in our infra, it's necessary to set the total number of records each channel will ever need to process.
+/// The number of records each channel processes is a function of the distribution of the number of records per user.
+/// Rather than calculate this histogram within this function (challenging to do while streaming), at present the caller must pass it in.
+///
+/// The count at a given index indicates the number of users having at least that many rows of data.
+/// +/// Example: +/// If the input is from 3 users, +/// - the first having 2 records +/// - the second having 4 records +/// - the third having 6 records +/// Then the histogram that should be provided is: +/// - [3, 3, 2, 2, 1, 1] /// /// # Errors /// Propagates errors from multiplications /// # Panics /// Propagates errors from multiplications +#[allow(clippy::async_yields_async)] pub async fn compute_feature_label_dot_product( sh_ctx: C, input_rows: Vec>, @@ -191,49 +238,42 @@ where { assert!(FV::BITS > 0); + // Get the validator and context to use for Gf2 multiplication operations + let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); + let binary_m_ctx = binary_validator.context(); + + // Get the validator and context to use for `Z_p` operations (modulus conversion) + let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::(); + let prime_field_ctx = prime_field_validator.context(); + + // Tricky hacks to work around the limitations of our current infrastructure let num_outputs = input_rows.len() - histogram[0]; + let mut record_id_for_row_depth = vec![0_u32; histogram.len()]; + let ctx_for_row_number = set_up_contexts(&binary_m_ctx, histogram); + + // Chunk the incoming stream of records into stream of vectors of records with the same PRF let mut input_stream = stream_iter(input_rows); let first_row = input_stream.next().await.unwrap(); - let rows_chunked_by_user = unfold(Some((input_stream, first_row)), |state| async move { - if state.is_none() { - return None; - } - let (mut s, last_row) = state.unwrap(); - let last_row_prf = last_row.prf_of_match_key; - let mut current_chunk = vec![last_row]; - while let Some(row) = s.next().await { - if row.prf_of_match_key == last_row_prf { - current_chunk.push(row); - } else { - return Some((current_chunk, Some((s, row)))); - } - } - Some((current_chunk, None)) - }); + let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row); - let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); - let binary_m_ctx = binary_validator.context(); - let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()]; - let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram); + // Convert to a stream of async futures that represent the result of executing the per-user circuit let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| { let num_user_rows = rows_for_user.len(); let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); - let record_ids = num_users_who_encountered_row_depth[..num_user_rows].to_owned(); + let record_ids = record_id_for_row_depth[..num_user_rows].to_owned(); - for i in 0..rows_for_user.len() { - num_users_who_encountered_row_depth[i] += 1; + for count in record_id_for_row_depth.iter_mut().take(rows_for_user.len()) { + *count += 1; } async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) } })); + // Execute all of the async futures (sequentially), and flatten the result let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) .map(|x| x.unwrap().into_iter()) .flatten_iters(); - let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::(); - let prime_field_ctx = prime_field_validator.context(); - - // modulus convert feature vector bits + // modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p` let converted_feature_vector_bits = convert_bits( prime_field_ctx 
.narrow(&Step::ModulusConvertFeatureVectorBits) @@ -242,6 +282,7 @@ where 0..FV::BITS, ); + // Sum up all the vectors converted_feature_vector_bits .try_fold( vec![S::ZERO; usize::try_from(FV::BITS).unwrap()], From 3ccef7bf308a1bfdf6d47385c0190eca5a5f2422 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 9 Oct 2023 15:47:20 +0800 Subject: [PATCH 7/9] clippy imports thing --- src/protocol/prf_sharding/feature_label_dot_product.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index d35cc6951..a326c02ef 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -1,9 +1,8 @@ +use std::{iter::zip, pin::pin}; + use futures::{stream::iter as stream_iter, TryStreamExt}; use futures_util::{future::try_join, stream::unfold, Stream, StreamExt}; use ipa_macros::Step; - -use std::{iter::zip, pin::pin}; - use stream_flatten_iters::StreamExt as _; use crate::{ From 7fd070c0f45beb8a9b3325534d1f6ae0cc1e5144 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Mon, 9 Oct 2023 16:59:19 +0800 Subject: [PATCH 8/9] Config out the compact gate stuff --- src/protocol/prf_sharding/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/protocol/prf_sharding/mod.rs b/src/protocol/prf_sharding/mod.rs index fdd3a6440..4c950251c 100644 --- a/src/protocol/prf_sharding/mod.rs +++ b/src/protocol/prf_sharding/mod.rs @@ -28,6 +28,7 @@ use crate::{ seq_join::{seq_join, seq_try_join_all}, }; +#[cfg(feature = "descriptive-gate")] pub mod feature_label_dot_product; pub struct PrfShardedIpaInputRow { From 7ca5813e69750419944466b2daf1004db6f48c52 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 10 Oct 2023 15:34:56 +0800 Subject: [PATCH 9/9] comments from Alex --- Cargo.toml | 1 - .../prf_sharding/feature_label_dot_product.rs | 24 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a72bfcfbf..e264d503f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -95,7 +95,6 @@ tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } typenum = "1.16" # hpke is pinned to it x25519-dalek = "2.0.0-pre.0" -stream-flatten-iters = "0.2.0" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index a326c02ef..0dfc28f3a 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -3,7 +3,6 @@ use std::{iter::zip, pin::pin}; use futures::{stream::iter as stream_iter, TryStreamExt}; use futures_util::{future::try_join, stream::unfold, Stream, StreamExt}; use ipa_macros::Step; -use stream_flatten_iters::StreamExt as _; use crate::{ error::Error, @@ -171,8 +170,7 @@ where IS: Stream> + Unpin, { unfold(Some((input_stream, first_row)), |state| async move { - state.as_ref()?; - let (mut s, last_row) = state.unwrap(); + let (mut s, last_row) = state?; let last_row_prf = last_row.prf_of_match_key; let mut current_chunk = vec![last_row]; while let Some(row) = s.next().await { @@ -221,7 +219,6 @@ where /// Propagates errors from multiplications /// # Panics /// Propagates errors from multiplications -#[allow(clippy::async_yields_async)] pub async fn compute_feature_label_dot_product( sh_ctx: C, input_rows: Vec>, @@ -252,7 +249,11 @@ where // Chunk the 
incoming stream of records into stream of vectors of records with the same PRF let mut input_stream = stream_iter(input_rows); - let first_row = input_stream.next().await.unwrap(); + let first_row = input_stream.next().await; + if first_row.is_none() { + return Ok(vec![]); + } + let first_row = first_row.unwrap(); let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row); // Convert to a stream of async futures that represent the result of executing the per-user circuit @@ -261,16 +262,19 @@ where let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); let record_ids = record_id_for_row_depth[..num_user_rows].to_owned(); - for count in record_id_for_row_depth.iter_mut().take(rows_for_user.len()) { + for count in &mut record_id_for_row_depth[..num_user_rows] { *count += 1; } - async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) } + #[allow(clippy::async_yields_async)] + // this is ok, because seq join wants a stream of futures + async move { + evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) + } })); // Execute all of the async futures (sequentially), and flatten the result let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) - .map(|x| x.unwrap().into_iter()) - .flatten_iters(); + .flat_map(|x| stream_iter(x.unwrap())); // modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p` let converted_feature_vector_bits = convert_bits( @@ -317,7 +321,7 @@ where for (i, (row, ctx)) in zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate() { - let record_id_for_this_row_depth = RecordId(record_id_for_each_depth[i + 1]); // skip row 0 + let record_id_for_this_row_depth = RecordId::from(record_id_for_each_depth[i + 1]); // skip row 0 let capped_attribution_outputs = prev_row_inputs .compute_row_with_previous(ctx, record_id_for_this_row_depth, row)
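
The histogram parameter that patch 5 introduces (and patch 6 documents) shifts work onto the caller: for input that already satisfies this circuit's contract, with all of a user's records adjacent, it can be derived in a single pass over the revealed PRF column. A minimal sketch; `histogram_from_sorted_prfs` is a hypothetical helper for illustration, not a function from these patches:

    // Hypothetical helper (not part of this patch series): derive the per-row-depth
    // histogram that compute_feature_label_dot_product expects, from a PRF column in
    // which each user's rows are adjacent. histogram[d] counts the users having at
    // least d+1 rows.
    fn histogram_from_sorted_prfs(prfs: &[u64]) -> Vec<usize> {
        let mut histogram = Vec::new();
        let mut depth = 0;
        let mut prev = None;
        for &prf in prfs {
            depth = if prev == Some(prf) { depth + 1 } else { 0 };
            if depth >= histogram.len() {
                histogram.push(0);
            }
            histogram[depth] += 1;
            prev = Some(prf);
        }
        histogram
    }

On the test data above (PRF 123 appearing 4 times, 234 twice, and 345 eight times), this yields [3, 3, 2, 2, 1, 1, 1, 1], matching the histogram hard-coded in the `semi_honest` test.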