diff --git a/src/protocol/ipa_prf/prf_sharding/mod.rs b/src/protocol/ipa_prf/prf_sharding/mod.rs index db62c95e0..5e0ae01d5 100644 --- a/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -349,6 +349,8 @@ where /// Takes an input stream of `PrfShardedIpaInputRecordRow` which is assumed to have all records with a given PRF adjacent /// and converts it into a stream of vectors of `PrfShardedIpaInputRecordRow` having the same PRF. /// +/// Filters out any users that only have a single row, since they will produce no attributed conversions. +/// fn chunk_rows_by_user( input_stream: IS, first_row: PrfShardedIpaInputRow, @@ -361,13 +363,16 @@ where { unfold(Some((input_stream, first_row)), |state| async move { let (mut s, last_row) = state?; - let last_row_prf = last_row.prf_of_match_key; + let mut last_row_prf = last_row.prf_of_match_key; let mut current_chunk = vec![last_row]; while let Some(row) = s.next().await { if row.prf_of_match_key == last_row_prf { current_chunk.push(row); - } else { + } else if current_chunk.len() > 1 { return Some((current_chunk, Some((s, row)))); + } else { + last_row_prf = row.prf_of_match_key; + current_chunk = vec![row]; } } Some((current_chunk, None)) @@ -435,8 +440,11 @@ where let first_row = first_row.unwrap(); let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row); + let mut collected = rows_chunked_by_user.collect::>().await; + collected.sort_by(|a, b| std::cmp::Ord::cmp(&b.len(), &a.len())); + // Convert to a stream of async futures that represent the result of executing the per-user circuit - let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| { + let stream_of_per_user_circuits = pin!(stream_iter(collected).then(|rows_for_user| { let num_user_rows = rows_for_user.len(); let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned(); let record_ids = record_id_for_row_depth[..num_user_rows].to_owned(); @@ -458,16 +466,15 @@ where })); // Execute all of the async futures (sequentially), and flatten the result - let collected_per_user_results = stream_of_per_user_circuits.collect::>().await; - let per_user_attribution_outputs = sh_ctx.parallel_join(collected_per_user_results).await?; - let flattenned_stream = per_user_attribution_outputs.into_iter().flatten(); + let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) + .flat_map(|x| stream_iter(x.unwrap())); // modulus convert breakdown keys and trigger values let converted_bks_and_tvs = convert_bits( prime_field_ctx .narrow(&Step::ModulusConvertBreakdownKeyBitsAndTriggerValues) .set_total_records(num_outputs), - stream_iter(flattenned_stream), + flattenned_stream, 0..BK::BITS + TV::BITS, );