From 17af9244be6675b3c000d457f294a2cffa8d7517 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 7 Nov 2023 00:41:15 +0800 Subject: [PATCH 1/3] debugging stall in streaming OPRF ipa --- src/protocol/ipa_prf/prf_sharding/mod.rs | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/protocol/ipa_prf/prf_sharding/mod.rs b/src/protocol/ipa_prf/prf_sharding/mod.rs index db62c95e0..29c29385d 100644 --- a/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -102,6 +102,7 @@ impl InputsRequiredFromPrevRow { pub async fn compute_row_with_previous( &mut self, ctx: C, + depth: usize, record_id: RecordId, input_row: &PrfShardedIpaInputRow, num_saturating_sum_bits: usize, @@ -113,6 +114,7 @@ impl InputsRequiredFromPrevRow { TV: GaloisField, TS: GaloisField, { + println!("depth: {}, record_id: {}", depth, usize::from(record_id)); let (bd_key, tv, timestamp) = ( input_row.breakdown_key_bits(), input_row.trigger_value_bits(), @@ -349,6 +351,8 @@ where /// Takes an input stream of `PrfShardedIpaInputRecordRow` which is assumed to have all records with a given PRF adjacent /// and converts it into a stream of vectors of `PrfShardedIpaInputRecordRow` having the same PRF. /// +/// Filters out users that only have a single row, since they will produce no attributed conversions; the final user in the stream is the exception, as their chunk is always emitted even if it holds a single row. 
+/// fn chunk_rows_by_user( input_stream: IS, first_row: PrfShardedIpaInputRow, @@ -361,13 +365,18 @@ where { unfold(Some((input_stream, first_row)), |state| async move { let (mut s, last_row) = state?; - let last_row_prf = last_row.prf_of_match_key; + let mut last_row_prf = last_row.prf_of_match_key; let mut current_chunk = vec![last_row]; while let Some(row) = s.next().await { if row.prf_of_match_key == last_row_prf { current_chunk.push(row); } else { - return Some((current_chunk, Some((s, row)))); + if current_chunk.len() > 1 { + return Some((current_chunk, Some((s, row)))); + } else { + last_row_prf = row.prf_of_match_key; + current_chunk = vec![row]; + } } } Some((current_chunk, None)) @@ -413,6 +422,8 @@ where assert!(BK::BITS > 0); assert!(TS::BITS > 0); + println!("histogram: {:?}", histogram); + // Get the validator and context to use for Gf2 multiplication operations let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); let binary_m_ctx = binary_validator.context(); @@ -458,16 +469,15 @@ where })); // Execute all of the async futures (sequentially), and flatten the result - let collected_per_user_results = stream_of_per_user_circuits.collect::>().await; - let per_user_attribution_outputs = sh_ctx.parallel_join(collected_per_user_results).await?; - let flattenned_stream = per_user_attribution_outputs.into_iter().flatten(); + let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) + .flat_map(|x| stream_iter(x.unwrap())); // modulus convert breakdown keys and trigger values let converted_bks_and_tvs = convert_bits( prime_field_ctx .narrow(&Step::ModulusConvertBreakdownKeyBitsAndTriggerValues) .set_total_records(num_outputs), - stream_iter(flattenned_stream), + flattenned_stream, 0..BK::BITS + TV::BITS, ); @@ -543,6 +553,7 @@ where let capped_attribution_outputs = prev_row_inputs .compute_row_with_previous( ctx_for_this_row_depth, + i, record_id_for_this_row_depth, row, num_saturating_sum_bits, From 
8e8b0e4af6847e04aeb29461ab1cce59df13a0a8 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 7 Nov 2023 21:48:38 +0800 Subject: [PATCH 2/3] OMG it works --- src/protocol/ipa_prf/prf_sharding/mod.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/protocol/ipa_prf/prf_sharding/mod.rs b/src/protocol/ipa_prf/prf_sharding/mod.rs index 29c29385d..636a41598 100644 --- a/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -102,7 +102,6 @@ impl InputsRequiredFromPrevRow { pub async fn compute_row_with_previous( &mut self, ctx: C, - depth: usize, record_id: RecordId, input_row: &PrfShardedIpaInputRow, num_saturating_sum_bits: usize, @@ -114,7 +113,6 @@ impl InputsRequiredFromPrevRow { TV: GaloisField, TS: GaloisField, { - println!("depth: {}, record_id: {}", depth, usize::from(record_id)); let (bd_key, tv, timestamp) = ( input_row.breakdown_key_bits(), input_row.trigger_value_bits(), @@ -422,8 +420,6 @@ where assert!(BK::BITS > 0); assert!(TS::BITS > 0); - println!("histogram: {:?}", histogram); - // Get the validator and context to use for Gf2 multiplication operations let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); let binary_m_ctx = binary_validator.context(); @@ -446,8 +442,11 @@ where let first_row = first_row.unwrap(); let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row); + let mut collected = rows_chunked_by_user.collect::>().await; + collected.sort_by(|a, b| std::cmp::Ord::cmp(&b.len(), &a.len())); + // Convert to a stream of async futures that represent the result of executing the per-user circuit - let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| { + let stream_of_per_user_circuits = pin!(stream_iter(collected).then(|rows_for_user| { let num_user_rows = rows_for_user.len(); let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); let record_ids = record_id_for_row_depth[..num_user_rows].to_owned(); @@ 
-553,7 +552,6 @@ where let capped_attribution_outputs = prev_row_inputs .compute_row_with_previous( ctx_for_this_row_depth, - i, record_id_for_this_row_depth, row, num_saturating_sum_bits, From 50ef10e7c4da378b36ac4233a7419624b990d27b Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 7 Nov 2023 21:55:38 +0800 Subject: [PATCH 3/3] making clippy happy --- src/protocol/ipa_prf/prf_sharding/mod.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/protocol/ipa_prf/prf_sharding/mod.rs b/src/protocol/ipa_prf/prf_sharding/mod.rs index 636a41598..5e0ae01d5 100644 --- a/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -368,13 +368,11 @@ where while let Some(row) = s.next().await { if row.prf_of_match_key == last_row_prf { current_chunk.push(row); + } else if current_chunk.len() > 1 { + return Some((current_chunk, Some((s, row)))); } else { - if current_chunk.len() > 1 { - return Some((current_chunk, Some((s, row)))); - } else { - last_row_prf = row.prf_of_match_key; - current_chunk = vec![row]; - } + last_row_prf = row.prf_of_match_key; + current_chunk = vec![row]; } } Some((current_chunk, None))