Skip to content

Commit

Permalink
Cleaned up and converted everything to streams
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage committed Oct 9, 2023
1 parent c3c6d96 commit 9f6b094
Showing 1 changed file with 71 additions and 30 deletions.
101 changes: 71 additions & 30 deletions src/protocol/prf_sharding/feature_label_dot_product.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use futures::{stream::iter as stream_iter, TryStreamExt};
use futures_util::{future::try_join, stream::unfold, StreamExt};
use futures_util::{future::try_join, stream::unfold, Stream, StreamExt};
use ipa_macros::Step;

use std::{iter::zip, pin::pin};
Expand Down Expand Up @@ -159,23 +159,70 @@ where
context_per_row_depth
}

/// Groups a stream of rows into per-user chunks.
///
/// Takes an input stream of `PrfShardedIpaInputRow`, which is assumed to have all rows with a
/// given PRF adjacent to one another, and converts it into a stream of vectors of
/// `PrfShardedIpaInputRow`, where every row in a vector has the same `prf_of_match_key`.
///
/// `first_row` is the row that precedes everything in `input_stream`; it seeds the first chunk,
/// so the returned stream always yields at least one (non-empty) vector.
///
/// Rows are only compared against the PRF of the chunk currently being built, so rows sharing a
/// PRF that are *not* adjacent in the input end up in separate chunks.
fn chunk_rows_by_user<FV, IS>(
    input_stream: IS,
    first_row: PrfShardedIpaInputRow<FV>,
) -> impl Stream<Item = Vec<PrfShardedIpaInputRow<FV>>>
where
    FV: GaloisField,
    IS: Stream<Item = PrfShardedIpaInputRow<FV>> + Unpin,
{
    // The unfold state is `Some((remaining stream, first row of the next chunk))`,
    // or `None` once the final chunk has been emitted.
    unfold(Some((input_stream, first_row)), |state| async move {
        // A `None` state means the input was exhausted on the previous step: end the output.
        let (mut s, last_row) = state?;
        let last_row_prf = last_row.prf_of_match_key;
        let mut current_chunk = vec![last_row];
        while let Some(row) = s.next().await {
            if row.prf_of_match_key == last_row_prf {
                current_chunk.push(row);
            } else {
                // `row` belongs to the next user: emit the finished chunk and carry `row`
                // forward as the seed of the next one.
                return Some((current_chunk, Some((s, row))));
            }
        }
        // Input exhausted: emit the last chunk and signal termination on the next poll.
        Some((current_chunk, None))
    })
}

/// Sub-protocol of the PRF-sharded IPA Protocol
///
/// After the computation of the per-user PRF, addition of dummy records and shuffling,
/// the PRF column can be revealed. After that, all of the records corresponding to a single
/// device can be processed together.
///
/// This circuit expects to receive records from multiple users,
/// but with all of the records from a given user adjacent to one another, and in time order.
/// but with all of the records from a given user adjacent to one another, and in reverse time order (most recent event comes first).
///
/// This circuit will compute attribution, and per-user capping.
///
/// The output of this circuit is the input to the next stage: Aggregation.
/// After those steps, source events to which trigger events were attributed will contribute their feature vectors to an aggregate
///
/// The aggregate is just the sum of all the feature vectors of source events which received attribution
///
/// This is useful for performing logistic regression: `https://github.com/patcg-individual-drafts/ipa/blob/main/logistic_regression.md`
///
/// Due to limitation in our infra, it's necessary to set the total number of records each channel will ever need to process.
/// The number of records each channel processes is a function of the distribution of number of records per user.
/// Rather than calculate this histogram within this function (challenging to do while streaming), at present the caller must pass this in.
///
/// The count at index `i` indicates the number of users having at least `i + 1` rows of data.
///
/// Example:
/// If the input is from 3 users,
/// - the first having 2 records
/// - the second having 4 records
/// - the third having 6 records
/// Then the histogram that should be provided is:
/// - [3, 3, 2, 2, 1, 1]
///
/// # Errors
/// Propagates errors from multiplications
/// # Panics
/// Panics if `FV::BITS` is zero, if `input_rows` or `histogram` is empty, or if evaluating a per-user circuit fails
#[allow(clippy::async_yields_async)]
pub async fn compute_feature_label_dot_product<C, FV, F, S>(
sh_ctx: C,
input_rows: Vec<PrfShardedIpaInputRow<FV>>,
Expand All @@ -191,49 +238,42 @@ where
{
assert!(FV::BITS > 0);

// Get the validator and context to use for Gf2 multiplication operations
let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
let binary_m_ctx = binary_validator.context();

// Get the validator and context to use for `Z_p` operations (modulus conversion)
let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::<F>();
let prime_field_ctx = prime_field_validator.context();

// Tricky hacks to work around the limitations of our current infrastructure
let num_outputs = input_rows.len() - histogram[0];
let mut record_id_for_row_depth = vec![0_u32; histogram.len()];
let ctx_for_row_number = set_up_contexts(&binary_m_ctx, histogram);

// Chunk the incoming stream of records into stream of vectors of records with the same PRF
let mut input_stream = stream_iter(input_rows);
let first_row = input_stream.next().await.unwrap();
let rows_chunked_by_user = unfold(Some((input_stream, first_row)), |state| async move {
if state.is_none() {
return None;
}
let (mut s, last_row) = state.unwrap();
let last_row_prf = last_row.prf_of_match_key;
let mut current_chunk = vec![last_row];
while let Some(row) = s.next().await {
if row.prf_of_match_key == last_row_prf {
current_chunk.push(row);
} else {
return Some((current_chunk, Some((s, row))));
}
}
Some((current_chunk, None))
});
let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row);

let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
let binary_m_ctx = binary_validator.context();
let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()];
let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram);
// Convert to a stream of async futures that represent the result of executing the per-user circuit
let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| {
let num_user_rows = rows_for_user.len();
let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();
let record_ids = num_users_who_encountered_row_depth[..num_user_rows].to_owned();
let record_ids = record_id_for_row_depth[..num_user_rows].to_owned();

for i in 0..rows_for_user.len() {
num_users_who_encountered_row_depth[i] += 1;
for count in record_id_for_row_depth.iter_mut().take(rows_for_user.len()) {
*count += 1;
}
async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) }
}));

// Execute all of the async futures (sequentially), and flatten the result
let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
.map(|x| x.unwrap().into_iter())
.flatten_iters();

let prime_field_validator = sh_ctx.narrow(&Step::PrimeFieldValidator).validator::<F>();
let prime_field_ctx = prime_field_validator.context();

// modulus convert feature vector bits
// modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p`
let converted_feature_vector_bits = convert_bits(
prime_field_ctx
.narrow(&Step::ModulusConvertFeatureVectorBits)
Expand All @@ -242,6 +282,7 @@ where
0..FV::BITS,
);

// Sum up all the vectors
converted_feature_vector_bits
.try_fold(
vec![S::ZERO; usize::try_from(FV::BITS).unwrap()],
Expand Down

0 comments on commit 9f6b094

Please sign in to comment.