debugging stall in streaming OPRF ipa
benjaminsavage committed Nov 7, 2023
1 parent 079b476 commit 17af924
Showing 1 changed file with 17 additions and 6 deletions.
src/protocol/ipa_prf/prf_sharding/mod.rs (+17, -6)
@@ -102,6 +102,7 @@ impl InputsRequiredFromPrevRow {
     pub async fn compute_row_with_previous<C, BK, TV, TS>(
         &mut self,
         ctx: C,
+        depth: usize,
         record_id: RecordId,
         input_row: &PrfShardedIpaInputRow<BK, TV, TS>,
         num_saturating_sum_bits: usize,
@@ -113,6 +114,7 @@ impl InputsRequiredFromPrevRow {
         TV: GaloisField,
         TS: GaloisField,
     {
+        println!("depth: {}, record_id: {}", depth, usize::from(record_id));
         let (bd_key, tv, timestamp) = (
             input_row.breakdown_key_bits(),
             input_row.trigger_value_bits(),
@@ -349,6 +351,8 @@ where
 /// Takes an input stream of `PrfShardedIpaInputRecordRow` which is assumed to have all records with a given PRF adjacent
 /// and converts it into a stream of vectors of `PrfShardedIpaInputRecordRow` having the same PRF.
 ///
+/// Filters out any users that only have a single row, since they will produce no attributed conversions.
+///
 fn chunk_rows_by_user<IS, BK, TV, TS>(
     input_stream: IS,
     first_row: PrfShardedIpaInputRow<BK, TV, TS>,
@@ -361,13 +365,18 @@
 {
     unfold(Some((input_stream, first_row)), |state| async move {
         let (mut s, last_row) = state?;
-        let last_row_prf = last_row.prf_of_match_key;
+        let mut last_row_prf = last_row.prf_of_match_key;
         let mut current_chunk = vec![last_row];
         while let Some(row) = s.next().await {
             if row.prf_of_match_key == last_row_prf {
                 current_chunk.push(row);
             } else {
-                return Some((current_chunk, Some((s, row))));
+                if current_chunk.len() > 1 {
+                    return Some((current_chunk, Some((s, row))));
+                } else {
+                    last_row_prf = row.prf_of_match_key;
+                    current_chunk = vec![row];
+                }
             }
         }
         Some((current_chunk, None))
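
Note: the new else branch above is the behavioral change in this hunk. Below is a minimal, standalone sketch of the same grouping logic, using plain (prf, value) tuples in place of the crate's PrfShardedIpaInputRow; the types and names here are illustrative only, with just the futures crate assumed.

    use futures::stream::{self, Stream, StreamExt};

    /// Groups adjacent rows that share a key into one chunk. A chunk that
    /// turns out to contain a single row is dropped rather than emitted,
    /// mirroring the `current_chunk.len() > 1` check in the diff above.
    fn chunk_adjacent<S>(input: S, first: (u64, u32)) -> impl Stream<Item = Vec<(u64, u32)>>
    where
        S: Stream<Item = (u64, u32)> + Unpin,
    {
        stream::unfold(Some((input, first)), |state| async move {
            let (mut s, last) = state?;
            let mut key = last.0;
            let mut chunk = vec![last];
            while let Some(row) = s.next().await {
                if row.0 == key {
                    chunk.push(row);
                } else if chunk.len() > 1 {
                    // Finished a multi-row chunk; resume from `row` next time.
                    return Some((chunk, Some((s, row))));
                } else {
                    // Single-row user: drop the chunk and start over from `row`.
                    key = row.0;
                    chunk = vec![row];
                }
            }
            // Note the final chunk is emitted unconditionally, as in the diff.
            Some((chunk, None))
        })
    }

    fn main() {
        futures::executor::block_on(async {
            // User 2 has only one row, so it never appears in the output.
            let rest = stream::iter(vec![(1, 10), (2, 20), (3, 30), (3, 31)]);
            let chunks: Vec<_> = chunk_adjacent(rest, (1, 11)).collect().await;
            assert_eq!(chunks, vec![vec![(1, 11), (1, 10)], vec![(3, 30), (3, 31)]]);
        });
    }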
@@ -413,6 +422,8 @@ where
     assert!(BK::BITS > 0);
     assert!(TS::BITS > 0);
 
+    println!("histogram: {:?}", histogram);
+
     // Get the validator and context to use for Gf2 multiplication operations
     let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
     let binary_m_ctx = binary_validator.context();
@@ -458,16 +469,15 @@ where
     }));
 
     // Execute all of the async futures (sequentially), and flatten the result
-    let collected_per_user_results = stream_of_per_user_circuits.collect::<Vec<_>>().await;
-    let per_user_attribution_outputs = sh_ctx.parallel_join(collected_per_user_results).await?;
-    let flattenned_stream = per_user_attribution_outputs.into_iter().flatten();
+    let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
+        .flat_map(|x| stream_iter(x.unwrap()));
 
     // modulus convert breakdown keys and trigger values
     let converted_bks_and_tvs = convert_bits(
         prime_field_ctx
             .narrow(&Step::ModulusConvertBreakdownKeyBitsAndTriggerValues)
             .set_total_records(num_outputs),
-        stream_iter(flattenned_stream),
+        flattenned_stream,
         0..BK::BITS + TV::BITS,
     );

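Note: this hunk appears to be the actual stall fix. The old code buffered every per-user circuit with collect::<Vec<_>>() and then drove them all at once via parallel_join; seq_join instead keeps at most active_work futures in flight and yields results in order, so downstream stages can consume output as it is produced. A rough standalone analog of that dataflow using futures' StreamExt::buffered (seq_join and active_work are this crate's own helpers; only the shape of the change is modeled here):

    use futures::stream::{self, StreamExt};

    fn main() {
        futures::executor::block_on(async {
            let window = 4; // stands in for sh_ctx.active_work()

            // Stands in for stream_of_per_user_circuits: one future per
            // user, each resolving to that user's attribution outputs.
            let per_user = stream::iter((0u32..10).map(|u| async move { vec![u * 2, u * 2 + 1] }));

            let flattened: Vec<u32> = per_user
                .buffered(window)       // at most `window` futures in flight
                .flat_map(stream::iter) // flatten each Vec into the stream
                .collect()
                .await;

            assert_eq!(flattened.len(), 20);
        });
    }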
@@ -543,6 +553,7 @@ where
             let capped_attribution_outputs = prev_row_inputs
                 .compute_row_with_previous(
                     ctx_for_this_row_depth,
+                    i,
                     record_id_for_this_row_depth,
                     row,
                     num_saturating_sum_bits,
