From 17af9244be6675b3c000d457f294a2cffa8d7517 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 7 Nov 2023 00:41:15 +0800 Subject: [PATCH] debugging stall in streaming OPRF ipa --- src/protocol/ipa_prf/prf_sharding/mod.rs | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/protocol/ipa_prf/prf_sharding/mod.rs b/src/protocol/ipa_prf/prf_sharding/mod.rs index db62c95e0..29c29385d 100644 --- a/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -102,6 +102,7 @@ impl InputsRequiredFromPrevRow { pub async fn compute_row_with_previous( &mut self, ctx: C, + depth: usize, record_id: RecordId, input_row: &PrfShardedIpaInputRow, num_saturating_sum_bits: usize, @@ -113,6 +114,7 @@ impl InputsRequiredFromPrevRow { TV: GaloisField, TS: GaloisField, { + println!("depth: {}, record_id: {}", depth, usize::from(record_id)); let (bd_key, tv, timestamp) = ( input_row.breakdown_key_bits(), input_row.trigger_value_bits(), @@ -349,6 +351,8 @@ where /// Takes an input stream of `PrfShardedIpaInputRecordRow` which is assumed to have all records with a given PRF adjacent /// and converts it into a stream of vectors of `PrfShardedIpaInputRecordRow` having the same PRF. /// +/// Filters out any users that only have a single row, since they will produce no attributed conversions. +/// fn chunk_rows_by_user( input_stream: IS, first_row: PrfShardedIpaInputRow, @@ -361,13 +365,18 @@ where { unfold(Some((input_stream, first_row)), |state| async move { let (mut s, last_row) = state?; - let last_row_prf = last_row.prf_of_match_key; + let mut last_row_prf = last_row.prf_of_match_key; let mut current_chunk = vec![last_row]; while let Some(row) = s.next().await { if row.prf_of_match_key == last_row_prf { current_chunk.push(row); } else { - return Some((current_chunk, Some((s, row)))); + if current_chunk.len() > 1 { + return Some((current_chunk, Some((s, row)))); + } else { + last_row_prf = row.prf_of_match_key; + current_chunk = vec![row]; + } } } Some((current_chunk, None)) @@ -413,6 +422,8 @@ where assert!(BK::BITS > 0); assert!(TS::BITS > 0); + println!("histogram: {:?}", histogram); + // Get the validator and context to use for Gf2 multiplication operations let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); let binary_m_ctx = binary_validator.context(); @@ -458,16 +469,15 @@ where })); // Execute all of the async futures (sequentially), and flatten the result - let collected_per_user_results = stream_of_per_user_circuits.collect::>().await; - let per_user_attribution_outputs = sh_ctx.parallel_join(collected_per_user_results).await?; - let flattenned_stream = per_user_attribution_outputs.into_iter().flatten(); + let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) + .flat_map(|x| stream_iter(x.unwrap())); // modulus convert breakdown keys and trigger values let converted_bks_and_tvs = convert_bits( prime_field_ctx .narrow(&Step::ModulusConvertBreakdownKeyBitsAndTriggerValues) .set_total_records(num_outputs), - stream_iter(flattenned_stream), + flattenned_stream, 0..BK::BITS + TV::BITS, ); @@ -543,6 +553,7 @@ where let capped_attribution_outputs = prev_row_inputs .compute_row_with_previous( ctx_for_this_row_depth, + i, record_id_for_this_row_depth, row, num_saturating_sum_bits,