Commit 7ca5813
comments from Alex
benjaminsavage committed Oct 10, 2023
1 parent 7fd070c commit 7ca5813
Showing 2 changed files with 14 additions and 11 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
@@ -95,7 +95,6 @@
 tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
 typenum = "1.16"
 # hpke is pinned to it
 x25519-dalek = "2.0.0-pre.0"
-stream-flatten-iters = "0.2.0"
 
 [target.'cfg(not(target_env = "msvc"))'.dependencies]
 tikv-jemallocator = "0.5.0"
24 changes: 14 additions & 10 deletions src/protocol/prf_sharding/feature_label_dot_product.rs
@@ -3,7 +3,6 @@
 use std::{iter::zip, pin::pin};
 use futures::{stream::iter as stream_iter, TryStreamExt};
 use futures_util::{future::try_join, stream::unfold, Stream, StreamExt};
 use ipa_macros::Step;
-use stream_flatten_iters::StreamExt as _;
 
 use crate::{
     error::Error,
@@ -171,8 +170,7 @@
     IS: Stream<Item = PrfShardedIpaInputRow<FV>> + Unpin,
 {
     unfold(Some((input_stream, first_row)), |state| async move {
-        state.as_ref()?;
-        let (mut s, last_row) = state.unwrap();
+        let (mut s, last_row) = state?;
         let last_row_prf = last_row.prf_of_match_key;
         let mut current_chunk = vec![last_row];
         while let Some(row) = s.next().await {
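
The `state?` rewrite above is the idiomatic way to consume an `Option` state inside an `unfold` closure. A minimal standalone sketch (using only the futures crate, with a toy integer state in place of the repository's stream-and-row pair): `?` returns `None` from the async block, which ends the stream, and otherwise moves the inner value out in one step, replacing the old `as_ref()?` plus `unwrap()` pair.

    use futures::{executor::block_on, stream, StreamExt};

    fn main() {
        // unfold turns a seed into a stream; returning None ends the stream.
        let s = stream::unfold(Some(0u32), |state| async move {
            let n = state?; // None state: the block yields None and the stream stops
            let next = if n < 3 { Some(n + 1) } else { None };
            Some((n, next)) // emit n, continue from next
        });
        assert_eq!(block_on(s.collect::<Vec<_>>()), vec![0, 1, 2, 3]);
    }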
@@ -221,7 +219,6 @@
 /// Propagates errors from multiplications
 /// # Panics
 /// Propagates errors from multiplications
-#[allow(clippy::async_yields_async)]
 pub async fn compute_feature_label_dot_product<C, FV, F, S>(
     sh_ctx: C,
     input_rows: Vec<PrfShardedIpaInputRow<FV>>,
@@ -252,7 +249,11 @@
 
     // Chunk the incoming stream of records into stream of vectors of records with the same PRF
     let mut input_stream = stream_iter(input_rows);
-    let first_row = input_stream.next().await.unwrap();
+    let first_row = input_stream.next().await;
+    if first_row.is_none() {
+        return Ok(vec![]);
+    }
+    let first_row = first_row.unwrap();
     let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row);
 
     // Convert to a stream of async futures that represent the result of executing the per-user circuit
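
The new guard above returns early instead of panicking when the input is empty. A self-contained sketch of the same guard, condensed with `let`-`else` (an equivalent formulation, not the exact code in this commit; `first_or_empty` is a hypothetical stand-in for `compute_feature_label_dot_product`):

    use futures::{executor::block_on, stream, StreamExt};

    // An empty input stream yields Ok(vec![]) instead of panicking on unwrap().
    async fn first_or_empty(rows: Vec<u32>) -> Result<Vec<u32>, String> {
        let mut input_stream = stream::iter(rows);
        let Some(first_row) = input_stream.next().await else {
            return Ok(vec![]); // no rows: nothing to compute
        };
        Ok(vec![first_row])
    }

    fn main() {
        assert_eq!(block_on(first_or_empty(vec![])), Ok(vec![]));
        assert_eq!(block_on(first_or_empty(vec![42, 7])), Ok(vec![42]));
    }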
@@ -261,16 +262,19 @@
         let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();
         let record_ids = record_id_for_row_depth[..num_user_rows].to_owned();
 
-        for count in record_id_for_row_depth.iter_mut().take(rows_for_user.len()) {
+        for count in &mut record_id_for_row_depth[..num_user_rows] {
             *count += 1;
         }
-        async move { evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user) }
+        #[allow(clippy::async_yields_async)]
+        // this is ok, because seq join wants a stream of futures
+        async move {
+            evaluate_per_user_attribution_circuit(contexts, record_ids, rows_for_user)
+        }
     }));
 
     // Execute all of the async futures (sequentially), and flatten the result
     let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
-        .map(|x| x.unwrap().into_iter())
-        .flatten_iters();
+        .flat_map(|x| stream_iter(x.unwrap()));
 
     // modulus convert feature vector bits from shares in `Z_2` to shares in `Z_p`
     let converted_feature_vector_bits = convert_bits(
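
The relocated `#[allow(clippy::async_yields_async)]` now sits on the one expression that trips the lint: an `async move` block whose output is the still-unawaited future from `evaluate_per_user_attribution_circuit`. That is intentional here, because `seq_join` consumes a stream of futures. A standalone sketch of that stream-of-futures shape, with `buffered(2)` standing in for `seq_join`'s `active_work` bound (an analogy, not ipa's actual `seq_join`):

    use futures::{executor::block_on, stream, StreamExt};

    async fn per_user_circuit(x: u32) -> u32 {
        x * 2
    }

    fn main() {
        // Map each item to a future without awaiting it; the combinator
        // decides when each future is polled.
        let results = stream::iter(1..=4)
            .map(per_user_circuit) // a stream of futures, none polled yet
            .buffered(2) // at most two in flight, results in order
            .collect::<Vec<_>>();
        assert_eq!(block_on(results), vec![2, 4, 6, 8]);
    }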
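The `.flat_map(|x| stream_iter(x.unwrap()))` rewrite is what lets the commit drop the stream-flatten-iters dependency from Cargo.toml: the futures crate alone can flatten a stream of `Vec`s by mapping each one into a sub-stream. A minimal sketch of the equivalence (toy integer vectors in place of the per-user result rows):

    use futures::{executor::block_on, stream, StreamExt};

    fn main() {
        let per_user_results = stream::iter(vec![vec![1, 2], vec![], vec![3]]);
        // Each Vec becomes a sub-stream; flat_map chains them end to end.
        let flattened = per_user_results.flat_map(stream::iter);
        assert_eq!(block_on(flattened.collect::<Vec<_>>()), vec![1, 2, 3]);
    }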
@@ -317,7 +321,7 @@
     for (i, (row, ctx)) in
         zip(rows_for_user.iter().skip(1), ctx_for_row_number.into_iter()).enumerate()
     {
-        let record_id_for_this_row_depth = RecordId(record_id_for_each_depth[i + 1]); // skip row 0
+        let record_id_for_this_row_depth = RecordId::from(record_id_for_each_depth[i + 1]); // skip row 0
 
         let capped_attribution_outputs = prev_row_inputs
             .compute_row_with_previous(ctx, record_id_for_this_row_depth, row)
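
Switching to `RecordId::from` constructs the value through the type's `From` conversion rather than naming the tuple field directly, which only compiles where that field is visible. An illustrative newtype (a hypothetical stand-in, not ipa's actual `RecordId` definition) showing why the conversion form works everywhere:

    mod protocol {
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub struct RecordId(u32); // field is private outside this module

        impl From<u32> for RecordId {
            fn from(v: u32) -> Self {
                Self(v)
            }
        }
    }

    fn main() {
        // protocol::RecordId(7) would not compile here: the field is private.
        let id = protocol::RecordId::from(7);
        assert_eq!(id, protocol::RecordId::from(7));
    }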
