diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index 3fb35835d..32615d1cc 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -1,5 +1,5 @@ use futures::{stream::iter as stream_iter, TryStreamExt}; -use futures_util::{future::try_join, StreamExt}; +use futures_util::{future::try_join, stream, StreamExt}; use ipa_macros::Step; use std::iter::zip; @@ -235,24 +235,36 @@ where let binary_m_ctx = binary_validator.context(); let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()]; let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram); - let mut futures = Vec::with_capacity(rows_chunked_by_user.len()); - for rows_for_user in rows_chunked_by_user { - let num_user_rows = rows_for_user.len(); - futures.push(evaluate_per_user_attribution_circuit( - ctx_for_row_number[..rows_for_user.len() - 1].to_owned(), - num_users_who_encountered_row_depth - .iter() - .take(num_user_rows) - .map(|x| RecordId(*x)) - .collect(), - rows_for_user, - )); - for i in 0..num_user_rows { - num_users_who_encountered_row_depth[i] += 1; - } - } + let stream_of_per_user_circuits = stream::unfold( + ( + num_users_who_encountered_row_depth, + ctx_for_row_number, + stream_iter(rows_chunked_by_user), + ), + |state| async move { + let (mut count_by_row_depth, contexts, s) = state; + if let Some(rows_for_user) = s.next().await { + let num_user_rows = rows_for_user.len(); + let yielded = evaluate_per_user_attribution_circuit( + contexts[..num_user_rows - 1].to_owned(), + num_users_who_encountered_row_depth + .iter() + .take(num_user_rows) + .map(|x| RecordId(*x)) + .collect(), + rows_for_user, + ); + for i in 0..num_user_rows { + count_by_row_depth[i] += 1; + } + Some((yielded, (count_by_row_depth, contexts, s))) + } else { + None + } + }, + ); - let flattenned_stream = seq_join(sh_ctx.active_work(), stream_iter(futures)) + let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits) .map(|x| x.unwrap().into_iter()) .flatten_iters();