diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs
index 4f313e8e9..d35cc6951 100644
--- a/src/protocol/prf_sharding/feature_label_dot_product.rs
+++ b/src/protocol/prf_sharding/feature_label_dot_product.rs
@@ -1,5 +1,5 @@
 use futures::{stream::iter as stream_iter, TryStreamExt};
-use futures_util::{future::try_join, stream::unfold, StreamExt};
+use futures_util::{future::try_join, stream::unfold, Stream, StreamExt};
 use ipa_macros::Step;
 use std::{iter::zip, pin::pin};
 
@@ -159,6 +159,34 @@ where
     context_per_row_depth
 }
 
+///
+/// Takes an input stream of `PrfShardedIpaInputRow` which is assumed to have all records with a given PRF adjacent
+/// and converts it into a stream of vectors of `PrfShardedIpaInputRow` having the same PRF.
+///
+fn chunk_rows_by_user<FV, IS>(
+    input_stream: IS,
+    first_row: PrfShardedIpaInputRow<FV>,
+) -> impl Stream<Item = Vec<PrfShardedIpaInputRow<FV>>>
+where
+    FV: GaloisField,
+    IS: Stream<Item = PrfShardedIpaInputRow<FV>> + Unpin,
+{
+    unfold(Some((input_stream, first_row)), |state| async move {
+        state.as_ref()?;
+        let (mut s, last_row) = state.unwrap();
+        let last_row_prf = last_row.prf_of_match_key;
+        let mut current_chunk = vec![last_row];
+        while let Some(row) = s.next().await {
+            if row.prf_of_match_key == last_row_prf {
+                current_chunk.push(row);
+            } else {
+                return Some((current_chunk, Some((s, row))));
+            }
+        }
+        Some((current_chunk, None))
+    })
+}
+
 /// Sub-protocol of the PRF-sharded IPA Protocol
 ///
 /// After the computation of the per-user PRF, addition of dummy records and shuffling,
@@ -166,16 +194,35 @@ where
 /// device can be processed together.
 ///
 /// This circuit expects to receive records from multiple users,
-/// but with all of the records from a given user adjacent to one another, and in time order.
+/// but with all of the records from a given user adjacent to one another, and in reverse time order (most recent event comes first).
 ///
 /// This circuit will compute attribution, and per-user capping.
 ///
-/// The output of this circuit is the input to the next stage: Aggregation.
+/// After those steps, source events to which trigger events were attributed will contribute their feature vectors to an aggregate.
+///
+/// The aggregate is just the sum of all the feature vectors of source events which received attribution.
+///
+/// This is useful for performing logistic regression: `https://github.com/patcg-individual-drafts/ipa/blob/main/logistic_regression.md`
+///
+/// Due to a limitation in our infra, it's necessary to set the total number of records each channel will ever need to process.
+/// The number of records each channel processes is a function of the distribution of the number of records per user.
+/// Rather than calculate this histogram within this function (challenging to do while streaming), at present the caller must pass it in.
+///
+/// The count at a given index indicates the number of users having at least that many rows of data.
+///
+/// Example:
+/// If the input is from 3 users,
+///   - the first having 2 records
+///   - the second having 4 records
+///   - the third having 6 records
+/// Then the histogram that should be provided is:
+///   - [3, 3, 2, 2, 1, 1]
 ///
 /// # Errors
 /// Propagates errors from multiplications
 /// # Panics
 /// Propagates errors from multiplications
+#[allow(clippy::async_yields_async)]
 pub async fn compute_feature_label_dot_product<C, FV, F, S>(
     sh_ctx: C,
     input_rows: Vec<PrfShardedIpaInputRow<FV>>,
@@ -191,49 +238,42 @@ where
 {
     assert!(FV::BITS > 0);
 
+    // Get the validator and context to use for Gf2 multiplication operations
+    let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
+    let binary_m_ctx = binary_validator.context();
+
+    // Get the validator and context to use for `Z_p` operations (modulus conversion)
+    let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::<F>();
+    let prime_field_ctx = prime_field_validator.context();
+
+    // Tricky hacks to work around the limitations of our current infrastructure
     let num_outputs = input_rows.len() - histogram[0];
+    let mut record_id_for_row_depth = vec![0_u32; histogram.len()];
+    let ctx_for_row_number = set_up_contexts(&binary_m_ctx, histogram);
+
+    // Chunk the incoming stream of records into a stream of vectors of records with the same PRF
     let mut input_stream = stream_iter(input_rows);
     let first_row = input_stream.next().await.unwrap();
-    let rows_chunked_by_user = unfold(Some((input_stream, first_row)), |state| async move {
-        if state.is_none() {
-            return None;
-        }
-        let (mut s, last_row) = state.unwrap();
-        let last_row_prf = last_row.prf_of_match_key;
-        let mut current_chunk = vec![last_row];
-        while let Some(row) = s.next().await {
-            if row.prf_of_match_key == last_row_prf {
-                current_chunk.push(row);
-            } else {
-                return Some((current_chunk, Some((s, row))));
-            }
-        }
-        Some((current_chunk, None))
-    });
+    let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row);
 
-    let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
-    let binary_m_ctx = binary_validator.context();
-    let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()];
-    let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram);
+    // Convert to a stream of async futures that represent the result of executing the per-user circuit
     let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| {
         let num_user_rows = rows_for_user.len();
         let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();
-        let record_ids = num_users_who_encountered_row_depth[..num_user_rows].to_owned();
+        let record_ids = record_id_for_row_depth[..num_user_rows].to_owned();
 
-        for i in 0..rows_for_user.len() {
-            num_users_who_encountered_row_depth[i] += 1;
+        for count in record_id_for_row_depth.iter_mut().take(rows_for_user.len()) {
+            *count += 1;
         }
         async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) }
     }));
 
+    // Execute all of the async futures (sequentially), and flatten the result
     let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
         .map(|x| x.unwrap().into_iter())
         .flatten_iters();
 
-    let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::<F>();
-    let prime_field_ctx = prime_field_validator.context();
-
-    // modulus convert feature vector bits
+    // modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p`
     let converted_feature_vector_bits = convert_bits(
         prime_field_ctx
             .narrow(&Step::ModulusConvertFeatureVectorBits)
@@ -242,6 +282,7 @@ where
         0..FV::BITS,
     );
 
+    // Sum up all the vectors
     converted_feature_vector_bits
         .try_fold(
             vec![S::ZERO; usize::try_from(FV::BITS).unwrap()],
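
Two illustrative sketches follow; neither is part of the diff above.

First, the histogram contract. A minimal sketch, assuming a hypothetical helper `build_histogram` and a `rows_per_user` input (the number of rows each user contributes), of how a caller might construct the histogram the doc comment describes: `histogram[i]` counts the users having at least `i + 1` rows. This also makes `input_rows.len() - histogram[0]` the number of per-user circuit outputs (12 - 3 = 9 for the worked example).

// Hypothetical helper, not part of this diff: builds the per-row-depth
// histogram that `compute_feature_label_dot_product` expects.
fn build_histogram(rows_per_user: &[usize]) -> Vec<usize> {
    let max_rows = rows_per_user.iter().copied().max().unwrap_or(0);
    let mut histogram = vec![0; max_rows];
    for &num_rows in rows_per_user {
        // A user with `num_rows` rows is counted once at each depth it reaches.
        for depth in 0..num_rows {
            histogram[depth] += 1;
        }
    }
    histogram
}

fn main() {
    // The worked example from the doc comment: users with 2, 4 and 6 records.
    assert_eq!(build_histogram(&[2, 4, 6]), vec![3, 3, 2, 2, 1, 1]);
}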
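
Second, the chunking pattern. A self-contained sketch of the `unfold`-based grouping that `chunk_rows_by_user` factors out, using a plain `(u64, char)` tuple as a stand-in for `PrfShardedIpaInputRow` (the name `chunk_by_prf` is made up for illustration; the real function additionally carries the `FV: GaloisField` bound and uses the `state.as_ref()?` form shown in the diff).

use futures::executor::block_on;
use futures::stream::{self, Stream, StreamExt};

// Stand-in for `chunk_rows_by_user`: groups adjacent rows that share a PRF.
// The caller pulls the first row off the stream so every chunk has a seed row.
fn chunk_by_prf<IS>(input: IS, first: (u64, char)) -> impl Stream<Item = Vec<(u64, char)>>
where
    IS: Stream<Item = (u64, char)> + Unpin,
{
    stream::unfold(Some((input, first)), |state| async move {
        // A `None` state means the inner stream was exhausted on a prior call.
        let (mut s, seed) = state?;
        let prf = seed.0;
        let mut chunk = vec![seed];
        while let Some(row) = s.next().await {
            if row.0 == prf {
                chunk.push(row);
            } else {
                // PRF changed: yield the finished chunk; `row` seeds the next one.
                return Some((chunk, Some((s, row))));
            }
        }
        // Inner stream exhausted: yield the final chunk, then terminate.
        Some((chunk, None))
    })
}

fn main() {
    let mut rows = stream::iter(vec![(1, 'a'), (1, 'b'), (2, 'c'), (3, 'd'), (3, 'e')]);
    let first = block_on(rows.next()).unwrap();
    let chunks: Vec<Vec<(u64, char)>> = block_on(chunk_by_prf(rows, first).collect());
    assert_eq!(
        chunks,
        vec![
            vec![(1, 'a'), (1, 'b')],
            vec![(2, 'c')],
            vec![(3, 'd'), (3, 'e')]
        ]
    );
}

The `Option` state lets a single `unfold` both carry the stream between chunks and signal termination: `Some((stream, seed))` means a chunk is in progress, `None` means the final chunk has already been emitted.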