Skip to content

Commit

Permalink
issue with unpin
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage committed Oct 6, 2023
1 parent 0106fb3 commit 3462409
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions src/protocol/prf_sharding/feature_label_dot_product.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use futures::{stream::iter as stream_iter, TryStreamExt};
use futures_util::{future::try_join, StreamExt};
use futures_util::{future::try_join, stream, StreamExt};
use ipa_macros::Step;

use std::iter::zip;
Expand Down Expand Up @@ -235,24 +235,36 @@ where
let binary_m_ctx = binary_validator.context();
let mut num_users_who_encountered_row_depth = vec![0_u32; histogram.len()];
let ctx_for_row_number = set_up_contexts(&binary_m_ctx, &histogram);
let mut futures = Vec::with_capacity(rows_chunked_by_user.len());
for rows_for_user in rows_chunked_by_user {
let num_user_rows = rows_for_user.len();
futures.push(evaluate_per_user_attribution_circuit(
ctx_for_row_number[..rows_for_user.len() - 1].to_owned(),
num_users_who_encountered_row_depth
.iter()
.take(num_user_rows)
.map(|x| RecordId(*x))
.collect(),
rows_for_user,
));
for i in 0..num_user_rows {
num_users_who_encountered_row_depth[i] += 1;
}
}
let stream_of_per_user_circuits = stream::unfold(
(
num_users_who_encountered_row_depth,
ctx_for_row_number,
stream_iter(rows_chunked_by_user),
),
|state| async move {
let (mut count_by_row_depth, contexts, s) = state;
if let Some(rows_for_user) = s.next().await {
let num_user_rows = rows_for_user.len();
let yielded = evaluate_per_user_attribution_circuit(
contexts[..num_user_rows - 1].to_owned(),
num_users_who_encountered_row_depth
.iter()
.take(num_user_rows)
.map(|x| RecordId(*x))
.collect(),
rows_for_user,
);
for i in 0..num_user_rows {
count_by_row_depth[i] += 1;
}
Some((yielded, (count_by_row_depth, contexts, s)))
} else {
None
}
},
);

let flattenned_stream = seq_join(sh_ctx.active_work(), stream_iter(futures))
let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
.map(|x| x.unwrap().into_iter())
.flatten_iters();

Expand Down

0 comments on commit 3462409

Please sign in to comment.