From 8e8b0e4af6847e04aeb29461ab1cce59df13a0a8 Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Tue, 7 Nov 2023 21:48:38 +0800 Subject: [PATCH] OMG it works --- src/protocol/ipa_prf/prf_sharding/mod.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/protocol/ipa_prf/prf_sharding/mod.rs b/src/protocol/ipa_prf/prf_sharding/mod.rs index 29c29385d..636a41598 100644 --- a/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -102,7 +102,6 @@ impl InputsRequiredFromPrevRow { pub async fn compute_row_with_previous( &mut self, ctx: C, - depth: usize, record_id: RecordId, input_row: &PrfShardedIpaInputRow, num_saturating_sum_bits: usize, @@ -114,7 +113,6 @@ impl InputsRequiredFromPrevRow { TV: GaloisField, TS: GaloisField, { - println!("depth: {}, record_id: {}", depth, usize::from(record_id)); let (bd_key, tv, timestamp) = ( input_row.breakdown_key_bits(), input_row.trigger_value_bits(), @@ -422,8 +420,6 @@ where assert!(BK::BITS > 0); assert!(TS::BITS > 0); - println!("histogram: {:?}", histogram); - // Get the validator and context to use for Gf2 multiplication operations let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::(); let binary_m_ctx = binary_validator.context(); @@ -446,8 +442,11 @@ where let first_row = first_row.unwrap(); let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row); + let mut collected = rows_chunked_by_user.collect::>().await; + collected.sort_by(|a, b| std::cmp::Ord::cmp(&b.len(), &a.len())); + // Convert to a stream of async futures that represent the result of executing the per-user circuit - let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| { + let stream_of_per_user_circuits = pin!(stream_iter(collected).then(|rows_for_user| { let num_user_rows = rows_for_user.len(); let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); let record_ids = record_id_for_row_depth[..num_user_rows].to_owned(); @@ -553,7 +552,6 @@ where let capped_attribution_outputs = prev_row_inputs .compute_row_with_previous( ctx_for_this_row_depth, - i, record_id_for_this_row_depth, row, num_saturating_sum_bits,