Skip to content

Commit

Permalink
Now it compiles. thank you martin
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage committed Oct 5, 2023
1 parent 094eb28 commit 0106fb3
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/protocol/prf_sharding/feature_label_dot_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use futures::{stream::iter as stream_iter, TryStreamExt};
use futures_util::{future::try_join, StreamExt};
use ipa_macros::Step;

use std::iter::zip;

use stream_flatten_iters::StreamExt as _;

use crate::{
Expand Down Expand Up @@ -237,7 +239,7 @@ where
for rows_for_user in rows_chunked_by_user {
let num_user_rows = rows_for_user.len();
futures.push(evaluate_per_user_attribution_circuit(
&ctx_for_row_number,
ctx_for_row_number[..rows_for_user.len() - 1].to_owned(),
num_users_who_encountered_row_depth
.iter()
.take(num_user_rows)
Expand Down Expand Up @@ -280,7 +282,7 @@ where
}

async fn evaluate_per_user_attribution_circuit<C, FV>(
ctx_for_row_number: &[C],
ctx_for_row_number: Vec<C>,
record_id_for_each_depth: Vec<RecordId>,
rows_for_user: Vec<PrfShardedIpaInputRow<FV>>,
) -> Result<Vec<BitDecomposed<Replicated<Gf2>>>, Error>
Expand All @@ -296,12 +298,15 @@ where
let mut prev_row_inputs = initialize_new_device_attribution_variables(first_row);

let mut output = Vec::with_capacity(rows_for_user.len() - 1);
for (i, row) in rows_for_user.iter().skip(1).enumerate() {
let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0
// skip the first row as it requires no multiplications
// no context was created for the first row
for (i, (row, ctx)) in
zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate()
{
let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0

let capped_attribution_outputs = prev_row_inputs
.compute_row_with_previous(ctx_for_this_row_depth, record_id_for_this_row_depth, row)
.compute_row_with_previous(ctx, record_id_for_this_row_depth, row)
.await?;

output.push(capped_attribution_outputs);
Expand Down

0 comments on commit 0106fb3

Please sign in to comment.