Skip to content

Commit

Permalink
Streams in some places
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Sep 15, 2023
1 parent 90a9ba3 commit 905f997
Showing 1 changed file with 36 additions and 57 deletions.
93 changes: 36 additions & 57 deletions src/protocol/aggregation/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod input;

use futures::{future::try_join, stream::iter as stream_iter, TryStreamExt};
use futures::{stream::iter as stream_iter, Stream, TryStreamExt};
use futures_util::StreamExt;
pub use input::SparseAggregateInputRow;
use ipa_macros::step;
Expand Down Expand Up @@ -65,34 +65,19 @@ where
BK: GaloisField,
{
let validator = sh_ctx.narrow(&Step::Validator).validator::<F>();
let ctx = validator.context();
let ctx = validator.context().set_total_records(input_rows.len());
let contributions = input_rows.iter().map(|row| &row.contribution_value);
let breakdowns = input_rows.iter().map(|row| &row.breakdown_key);

// convert the input from `[Z2]^u` into `[Zp]^u`
let (converted_value_bits, converted_breakdown_key_bits) = try_join(
upgrade_bit_shares(
ctx.narrow(&Step::ConvertValueBits),
input_rows,
CV::BITS,
|row, i| {
Replicated::new(
Gf2::truncate_from(row.contribution_value.left()[i]),
Gf2::truncate_from(row.contribution_value.right()[i]),
)
},
),
let (converted_value_bits, converted_breakdown_key_bits) = (
upgrade_bit_shares(ctx.narrow(&Step::ConvertValueBits), contributions, CV::BITS),
upgrade_bit_shares(
ctx.narrow(&Step::ConvertBreakdownKeyBits),
input_rows,
breakdowns,
BK::BITS,
|row, i| {
Replicated::new(
Gf2::truncate_from(row.breakdown_key.left()[i]),
Gf2::truncate_from(row.breakdown_key.right()[i]),
)
},
),
)
.await?;
);

let output = sparse_aggregate_values_per_bucket(
ctx,
Expand All @@ -111,39 +96,38 @@ where
/// # Errors
/// propagates errors from multiplications
#[tracing::instrument(name = "aggregate_values_per_bucket", skip_all)]
pub async fn sparse_aggregate_values_per_bucket<F, C, S>(
pub async fn sparse_aggregate_values_per_bucket<F, I1, I2, C, S>(
ctx: C,
contribution_values: Vec<BitDecomposed<S>>,
breakdown_keys: Vec<BitDecomposed<S>>,
contribution_values: I1,
breakdown_keys: I2,
num_buckets: usize,
) -> Result<Vec<S>, Error>
where
F: PrimeField,
I1: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
I2: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
{
debug_assert!(contribution_values.len() == breakdown_keys.len());
let num_records = contribution_values.len();
// TODO: use exactsizestream trait
// debug_assert!(contribution_values.len() == breakdown_keys.len());

let equality_check_ctx = ctx
.narrow(&Step::ComputeEqualityChecks)
.set_total_records(num_records);
let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks);

// Generate N streams for each bucket specified by the `num_buckets`.
// A stream is pipeline of contribution values multiplied by the "equality bit". An equality
// bit is a bit that is a share of 1 if the breakdown key matches the bucket, or 0 otherwise.
let streams = seq_join(
ctx.active_work(),
stream_iter(breakdown_keys)
.zip(stream_iter(contribution_values))
breakdown_keys
.zip(contribution_values)
.enumerate()
.map(|(i, (bk, v))| {
let eq_ctx = &equality_check_ctx;
let c = ctx.clone();
async move {
let equality_checks = check_everything(eq_ctx.clone(), i, &bk).await?;
equality_bits_times_value(&c, equality_checks, num_buckets, v, num_records, i)
.await
let equality_checks = check_everything(eq_ctx.clone(), i, &bk?).await?;
equality_bits_times_value(&c, equality_checks, num_buckets, v?, i).await
}
}),
);
Expand All @@ -163,17 +147,14 @@ async fn equality_bits_times_value<F, C, S>(
check_bits: BitDecomposed<S>,
num_buckets: usize,
value_bits: BitDecomposed<S>,
num_records: usize,
record_id: usize,
) -> Result<Vec<S>, Error>
where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
{
let check_times_value_ctx = ctx
.narrow(&Step::CheckTimesValue)
.set_total_records(num_records);
let check_times_value_ctx = ctx.narrow(&Step::CheckTimesValue);

ctx.try_join(
check_bits
Expand All @@ -191,34 +172,32 @@ where
.await
}

async fn upgrade_bit_shares<F, C, S, H, CV, BK>(
fn upgrade_bit_shares<'a, F, C, S, I, G>(
ctx: C,
input_rows: &[SparseAggregateInputRow<CV, BK>],
input_rows: I,
num_bits: u32,
f: H,
) -> Result<Vec<BitDecomposed<S>>, Error>
) -> impl Stream<Item = Result<BitDecomposed<S>, Error>> + 'a
where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
C: UpgradedContext<F, Share = S> + 'a,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
H: Fn(&SparseAggregateInputRow<CV, BK>, u32) -> Replicated<Gf2>,
CV: GaloisField,
BK: GaloisField,
I: Iterator<Item = &'a Replicated<G>> + Send + 'a,
G: GaloisField,
{
let num_records = input_rows.len();
let gf2_bits = input_rows
.iter()
.map(|row| BitDecomposed::decompose(num_bits, |i| f(row, i)))
.collect::<Vec<_>>();
let gf2_bits = input_rows.map(move |row| {
BitDecomposed::decompose(num_bits, |idx| {
Replicated::new(
Gf2::truncate_from(row.left()[idx]),
Gf2::truncate_from(row.right()[idx]),
)
})
});

convert_bits(
ctx.narrow(&Step::ConvertValueBits)
.set_total_records(num_records),
ctx.narrow(&Step::ConvertValueBits),
stream_iter(gf2_bits),
0..num_bits,
)
.try_collect::<Vec<_>>()
.await
}

#[cfg(all(test, unit_test))]
Expand Down

0 comments on commit 905f997

Please sign in to comment.