diff --git a/src/protocol/prf_sharding/feature_label_dot_product.rs b/src/protocol/prf_sharding/feature_label_dot_product.rs index 66dca764e..3fb35835d 100644 --- a/src/protocol/prf_sharding/feature_label_dot_product.rs +++ b/src/protocol/prf_sharding/feature_label_dot_product.rs @@ -2,6 +2,8 @@ use futures::{stream::iter as stream_iter, TryStreamExt}; use futures_util::{future::try_join, StreamExt}; use ipa_macros::Step; +use std::iter::zip; + use stream_flatten_iters::StreamExt as _; use crate::{ @@ -237,7 +239,7 @@ where for rows_for_user in rows_chunked_by_user { let num_user_rows = rows_for_user.len(); futures.push(evaluate_per_user_attribution_circuit( - &ctx_for_row_number, + ctx_for_row_number[..rows_for_user.len() - 1].to_owned(), num_users_who_encountered_row_depth .iter() .take(num_user_rows) @@ -280,7 +282,7 @@ where } async fn evaluate_per_user_attribution_circuit( - ctx_for_row_number: &[C], + ctx_for_row_number: Vec, record_id_for_each_depth: Vec, rows_for_user: Vec>, ) -> Result>>, Error> @@ -296,12 +298,15 @@ where let mut prev_row_inputs = initialize_new_device_attribution_variables(first_row); let mut output = Vec::with_capacity(rows_for_user.len() - 1); - for (i, row) in rows_for_user.iter().skip(1).enumerate() { - let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0 + // skip the first row as it requires no multiplications + // no context was created for the first row + for (i, (row, ctx)) in + zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate() + { let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0 let capped_attribution_outputs = prev_row_inputs - .compute_row_with_previous(ctx_for_this_row_depth, record_id_for_this_row_depth, row) + .compute_row_with_previous(ctx, record_id_for_this_row_depth, row) .await?; output.push(capped_attribution_outputs);