Skip to content

Commit

Permalink
Incorporate Ben's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
richajaindce committed Oct 2, 2023
1 parent 9c00477 commit 071cb78
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 143 deletions.
Binary file added images/tree_aggregation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions src/protocol/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ where
let c = ctx.clone();
async move {
let equality_checks = bitwise_to_onehot(eq_ctx.clone(), i, &bk?).await?;
equality_bits_times_value(&c, equality_checks, num_buckets, v?, i).await
equality_bits_times_value(&c, equality_checks, num_buckets, &mut v?, i).await
}
}),
);
Expand All @@ -167,7 +167,7 @@ async fn equality_bits_times_value<F, C, S>(
ctx: &C,
check_bits: BitDecomposed<S>,
num_buckets: usize,
value_bits: BitDecomposed<S>,
value_bits: &mut BitDecomposed<S>,
record_id: usize,
) -> Result<Vec<S>, Error>
where
Expand Down
323 changes: 182 additions & 141 deletions src/protocol/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::iter::{repeat, zip};

use embed_doc_image::embed_doc_image;
use futures::{stream::iter as stream_iter, TryStreamExt};
use futures_util::{future::try_join, StreamExt};
use ipa_macros::Step;
Expand All @@ -22,7 +23,7 @@ use crate::{
malicious::ExtendableField, semi_honest::AdditiveShare as Replicated,
ReplicatedSecretSharing,
},
BitDecomposed, Linear as LinearSecretSharing, LinearRefOps, SharedValue,
BitDecomposed, Linear as LinearSecretSharing, SharedValue,
},
seq_join::{seq_join, seq_try_join_all},
};
Expand Down Expand Up @@ -176,11 +177,6 @@ pub struct CappedAttributionOutputs {
pub capped_attributed_trigger_value: BitDecomposed<Replicated<Gf2>>,
}

/// Prime-field counterpart of `CappedAttributionOutputs`: the per-row inputs to
/// aggregation after modulus conversion out of the binary (Gf2) domain.
pub struct PrimeFieldAggregationInputs<F: Field> {
    // Bit-decomposed replicated shares of the breakdown key the value was attributed to.
    pub attributed_breakdown_key_bits: BitDecomposed<Replicated<F>>,
    // Replicated share of the capped trigger value in the prime field.
    pub capped_attributed_trigger_value: Replicated<F>,
}

#[derive(Step)]
pub enum UserNthRowStep {
#[dynamic]
Expand All @@ -193,6 +189,18 @@ impl From<usize> for UserNthRowStep {
}
}

/// Step used to narrow the context once per level of the aggregation binary
/// tree, so the multiplications at each depth run under a distinct sub-step.
#[derive(Step)]
pub enum BinaryTreeDepthStep {
    // Dynamic variant: one step instance per tree depth.
    #[dynamic]
    Depth(usize),
}

impl From<usize> for BinaryTreeDepthStep {
    /// Maps a zero-based tree depth to its dynamic `Depth` step variant.
    fn from(v: usize) -> Self {
        Self::Depth(v)
    }
}

#[derive(Step)]
pub(crate) enum Step {
BinaryValidator,
Expand All @@ -205,8 +213,8 @@ pub(crate) enum Step {
ComputeDifferenceToCap,
ComputedCappedAttributedTriggerValueNotSaturatedCase,
ComputedCappedAttributedTriggerValueJustSaturatedCase,
ComputedAttributedBreakdownKey,
ComputedAttributedValue,
ModulusConvertBreakdownKeyBits,
ModulusConvertConversionValueBits,
MoveValueToCorrectBreakdown,
}

Expand Down Expand Up @@ -266,138 +274,6 @@ where
rows_chunked_by_user
}

/// Wrapper that runs per-user attribution and capping over the sharded input
/// rows, then aggregates the capped results per breakdown key in a prime field.
///
/// Note: the rows are cloned into `attribution_and_capping` via `sh_ctx.clone()`
/// so the same base context can afterwards be narrowed for the validator.
///
/// # Errors
/// If there is an issue in multiplication, it will error
pub async fn attribution_and_capping_and_aggregation<C, BK, TV, F, S, SB>(
    sh_ctx: C,
    input_rows: Vec<PrfShardedIpaInputRow<BK, TV>>,
    num_saturating_sum_bits: usize,
) -> Result<Vec<S>, Error>
where
    C: UpgradableContext,
    C::UpgradedContext<F>: UpgradedContext<F, Share = S>,
    S: LinearSecretSharing<F> + Serializable + SecureMul<C::UpgradedContext<F>>,
    for<'a> &'a S: LinearRefOps<'a, S, F>,
    C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
    F: PrimeField + ExtendableField,
    TV: GaloisField,
    BK: GaloisField,
{
    // Call attribution_and_capping: produces one CappedAttributionOutputs per user row.
    let user_level_attributions: Vec<CappedAttributionOutputs> =
        attribution_and_capping(sh_ctx.clone(), input_rows, num_saturating_sum_bits).await?;

    // Upgrade to a validated prime-field context before aggregating.
    let prime_field_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<F>();
    let prime_field_m_ctx = prime_field_validator.context();

    do_aggregation::<_, BK, TV, F, S>(prime_field_m_ctx, user_level_attributions).await
}

/// Step for narrowing the context at each depth of the value-placement binary
/// tree; one dynamic instance exists per depth.
#[derive(Step)]
pub enum BinaryTreeDepthStep {
    #[dynamic]
    Depth(usize),
}

impl From<usize> for BinaryTreeDepthStep {
    // Converts a zero-based depth index into the corresponding step.
    fn from(v: usize) -> Self {
        Self::Depth(v)
    }
}

/// Aggregates capped per-user attributions into per-breakdown-key totals.
///
/// Pipeline:
/// 1. Split each row into breakdown-key bits and trigger-value bits.
/// 2. Modulus-convert both from the binary field into the prime field.
/// 3. For each row, route its value to the correct breakdown bucket via a
///    binary tree of multiplications driven by the breakdown-key bits
///    (most significant bit first).
/// 4. Fold all per-row bucket vectors into a single vector of running sums.
///
/// Returns a vector of `1 << BK::BITS` shares, one per breakdown key.
async fn do_aggregation<C, BK, TV, F, S>(
    ctx: C,
    user_level_attributions: Vec<CappedAttributionOutputs>,
) -> Result<Vec<S>, Error>
where
    C: UpgradedContext<F, Share = S>,
    S: LinearSecretSharing<F> + Serializable + SecureMul<C>,
    for<'a> &'a S: LinearRefOps<'a, S, F>,
    BK: GaloisField,
    TV: GaloisField,
    F: PrimeField + ExtendableField,
{
    let num_records = user_level_attributions.len();
    // Separate breakdown-key bits from trigger-value bits, consuming the rows.
    let (bk_vec, tv_vec): (Vec<_>, Vec<_>) = user_level_attributions
        .into_iter()
        .map(|row| {
            (
                row.attributed_breakdown_key_bits,
                row.capped_attributed_trigger_value,
            )
        })
        .unzip();

    // convert bk (binary field -> prime field, bit by bit)
    let converted_bks = convert_bits(
        ctx.narrow(&Step::ComputedAttributedBreakdownKey)
            .set_total_records(num_records),
        stream_iter(bk_vec),
        0..BK::BITS,
    );
    // convert attributed value
    let converted_values = convert_bits(
        ctx.narrow(&Step::ComputedAttributedValue)
            .set_total_records(num_records),
        stream_iter(tv_vec),
        0..TV::BITS,
    );
    // Recombine the converted value bits into a single additive sharing in F.
    let large_field_value =
        converted_values.map(|val| val.unwrap().to_additive_sharing_in_large_field());

    // One async task per row: place the row's value into its breakdown bucket.
    let row_contributions_stream = converted_bks
        .zip(large_field_value)
        .zip(futures::stream::repeat(
            ctx.narrow(&Step::MoveValueToCorrectBreakdown)
                .set_total_records(num_records),
        ))
        .enumerate()
        .map(|(i, ((bk_bits, cred), ctx))| {
            // Start with the value in every bucket; the tree walk below zeroes
            // out all but the selected one.
            let mut row_contribution = vec![cred.clone(); 1 << BK::BITS];
            let mut step: usize = 1 << BK::BITS;
            let record_id: RecordId = RecordId::from(i);
            let bd_key = bk_bits.unwrap();
            async move {
                // Walk the breakdown-key bits MSB-first; each depth halves the span.
                for (tree_depth, bit_of_bdkey) in bd_key.iter().rev().enumerate() {
                    let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth));
                    let span = step >> 1;
                    let mut futures = vec![];
                    for i in (0..1 << BK::BITS).step_by(step) {
                        let bit_c = depth_c.narrow(&BitOpStep::from(i));

                        if i + span < 1 << BK::BITS {
                            // bit * contribution: the share that moves to the right child.
                            let vb = row_contribution[i].multiply(bit_of_bdkey, bit_c, record_id);
                            futures.push(vb);
                        }
                    }
                    let vbs = ctx.parallel_join(futures).await?;

                    // Left child keeps (contribution - vb); right child receives vb.
                    for (index, vb) in vbs.into_iter().enumerate() {
                        let left_index = index * step;
                        let right_index = left_index + span;

                        row_contribution[left_index] -= &vb;
                        row_contribution[right_index] = vb;
                    }
                    step = span;
                }
                Ok(row_contribution)
            }
        });
    // Run the per-row tasks with bounded concurrency and sum bucket-wise.
    let row_contributions = seq_join(ctx.active_work(), row_contributions_stream);
    row_contributions
        .try_fold(
            vec![S::ZERO; 1 << BK::BITS],
            |mut running_sums, row_contribution| async move {
                for (i, contribution) in row_contribution.iter().enumerate() {
                    running_sums[i] += contribution;
                }
                Ok(running_sums)
            },
        )
        .await
}

/// Sub-protocol of the PRF-sharded IPA Protocol
///
/// After the computation of the per-user PRF, addition of dummy records and shuffling,
Expand Down Expand Up @@ -664,6 +540,171 @@ where
))
}

/// This circuit expects to receive records from multiple users,
/// but with all of the records from a given user adjacent to one another, and in time order.
///
/// This is a wrapper function to do attribution and capping per user followed by aggregating
/// the results per breakdown key
/// # Errors
/// If there is an issue in multiplication, it will error
pub async fn attribution_and_capping_and_aggregation<C, BK, TV, F, S, SB>(
    sh_ctx: C,
    input_rows: Vec<PrfShardedIpaInputRow<BK, TV>>,
    num_saturating_sum_bits: usize,
) -> Result<Vec<S>, Error>
where
    C: UpgradableContext,
    C::UpgradedContext<F>: UpgradedContext<F, Share = S>,
    S: LinearSecretSharing<F> + Serializable + SecureMul<C::UpgradedContext<F>>,
    C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
    F: PrimeField + ExtendableField,
    TV: GaloisField,
    BK: GaloisField,
{
    // Derive the validated prime-field context first, so `sh_ctx` can then be
    // moved (not cloned) into attribution_and_capping.
    let prime_field_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<F>();
    let prime_field_m_ctx = prime_field_validator.context();

    // Per-user attribution and capping over the binary-field shares.
    let user_level_attributions: Vec<CappedAttributionOutputs> =
        attribution_and_capping(sh_ctx, input_rows, num_saturating_sum_bits).await?;

    // Aggregate the capped contributions per breakdown key in the prime field.
    do_aggregation::<_, BK, TV, F, S>(prime_field_m_ctx, user_level_attributions).await
}

/// Sub-protocol of the PRF-sharded IPA Protocol
///
/// This function receives capped user level contributions to breakdown key buckets. It does the following
/// 1. Convert bit-shares of breakdown keys and conversion values from binary field to prime field
/// 2. Transform conversion value bits to additive sharing
/// 3. Move all conversion values to corresponding breakdown key bucket
///
/// At the end of the function, all conversions are aggregated and placed in the appropriate breakdown key bucket
async fn do_aggregation<C, BK, TV, F, S>(
    ctx: C,
    user_level_attributions: Vec<CappedAttributionOutputs>,
) -> Result<Vec<S>, Error>
where
    C: UpgradedContext<F, Share = S>,
    S: LinearSecretSharing<F> + Serializable + SecureMul<C>,
    BK: GaloisField,
    TV: GaloisField,
    F: PrimeField + ExtendableField,
{
    let num_records = user_level_attributions.len();
    // Split each row into its breakdown-key bits and trigger-value bits.
    let (bk_vec, tv_vec): (Vec<_>, Vec<_>) = user_level_attributions
        .into_iter()
        .map(|row| {
            (
                row.attributed_breakdown_key_bits,
                row.capped_attributed_trigger_value,
            )
        })
        .unzip();

    // modulus convert breakdown keys
    let converted_bks = convert_bits(
        ctx.narrow(&Step::ModulusConvertBreakdownKeyBits)
            .set_total_records(num_records),
        stream_iter(bk_vec),
        0..BK::BITS,
    );
    // modulus convert attributed value
    let converted_values = convert_bits(
        ctx.narrow(&Step::ModulusConvertConversionValueBits)
            .set_total_records(num_records),
        stream_iter(tv_vec),
        0..TV::BITS,
    );

    // transform value bits to large field
    let large_field_values = converted_values
        .map(|val| BitDecomposed::to_additive_sharing_in_large_field_consuming(val.unwrap()));

    // move each value to the correct bucket
    let row_contributions_stream = converted_bks
        .zip(large_field_values)
        .zip(futures::stream::repeat(
            ctx.narrow(&Step::MoveValueToCorrectBreakdown)
                .set_total_records(num_records),
        ))
        .enumerate()
        .map(|(i, ((bk_bits, value), ctx))| {
            let record_id: RecordId = RecordId::from(i);
            let bd_key = bk_bits.unwrap();
            // Seed every bucket with the value; move_single_value_to_bucket
            // narrows it down to the one selected by `bd_key`.
            let row_contribution = vec![value.clone(); 1 << BK::BITS];
            async move {
                move_single_value_to_bucket::<BK, _, _, _>(ctx, record_id, bd_key, row_contribution)
                    .await
            }
        });

    // aggregate all row level contributions
    let row_contributions = seq_join(ctx.active_work(), row_contributions_stream);
    row_contributions
        .try_fold(
            vec![S::ZERO; 1 << BK::BITS],
            |mut running_sums, row_contribution| async move {
                for (i, contribution) in row_contribution.iter().enumerate() {
                    running_sums[i] += contribution;
                }
                Ok(running_sums)
            },
        )
        .await
}

#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")]
/// This function moves a single value to a correct bucket using tree aggregation approach
///
/// Here is how it works
/// The combined value, [`value`] forms the root of a binary tree as follows:
/// ![Tree propagation][tree-aggregation]
///
/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications.
/// In the first round, r=BK-1, multiply the most significant bit, [`bd_key`]r by the value to get [`bd_key`]r.[`value`]. From that,
/// produce [`row_contribution`]r,0 =[`value`]-[`bd_key`]r.[`value`] and [`row_contribution`]r,1=[`bd_key`]r.[`value`].
/// This takes the most significant bit of `bd_key` and places value in one of the two child nodes of the binary tree.
/// At each successive round, the next most significant bit is propagated from the leaf nodes of the tree into further leaf nodes:
/// [`row_contribution`]r+1,q,0 =[`row_contribution`]r,q - [`bd_key`]r+1.[`row_contribution`]r,q and [`row_contribution`]r+1,q,1 =[`bd_key`]r+1.[`row_contribution`]r,q.
/// The work of each iteration therefore doubles relative to the one preceding.
async fn move_single_value_to_bucket<BK, C, S, F>(
    ctx: C,
    record_id: RecordId,
    bd_key: BitDecomposed<S>,
    // Expected to arrive with the value replicated in all 1 << BK::BITS slots.
    mut row_contribution: Vec<S>,
) -> Result<Vec<S>, Error>
where
    BK: GaloisField,
    C: UpgradedContext<F, Share = S>,
    S: LinearSecretSharing<F> + Serializable + SecureMul<C>,
    F: PrimeField + ExtendableField,
{
    // Span of the subtree handled at the current depth; halves each round.
    let mut step: usize = 1 << BK::BITS;

    // Iterate MSB-first over the breakdown-key bits (hence .rev()).
    for (tree_depth, bit_of_bdkey) in bd_key.iter().rev().enumerate() {
        let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth));
        let span = step >> 1;
        let mut futures = vec![];
        for i in (0..1 << BK::BITS).step_by(step) {
            let bit_c = depth_c.narrow(&BitOpStep::from(i));

            if i + span < 1 << BK::BITS {
                // bit * contribution: the portion that flows to the right child.
                futures.push(row_contribution[i].multiply(bit_of_bdkey, bit_c, record_id));
            }
        }
        let contributions = ctx.parallel_join(futures).await?;

        // Left child keeps (contribution - bit*contribution); right child gets bit*contribution.
        for (index, bdbit_contribution) in contributions.into_iter().enumerate() {
            let left_index = index * step;
            let right_index = left_index + span;

            row_contribution[left_index] -= &bdbit_contribution;
            row_contribution[right_index] = bdbit_contribution;
        }
        step = span;
    }
    Ok(row_contribution)
}

#[cfg(all(test, unit_test))]
pub mod tests {
use super::{attribution_and_capping, CappedAttributionOutputs, PrfShardedIpaInputRow};
Expand Down Expand Up @@ -851,7 +892,7 @@ pub mod tests {
}

#[test]
fn semi_honest_aggregation() {
fn semi_honest_aggregation_capping_attribution() {
run(|| async move {
let world = TestWorld::default();

Expand Down
Loading

0 comments on commit 071cb78

Please sign in to comment.