
Merge pull request #789 from richajaindce/user_aggregation
Tree Aggregation on top of user level attribution and capping
benjaminsavage authored Oct 3, 2023
2 parents 0599ba6 + 448a5b7 commit 35ccd66
Showing 3 changed files with 258 additions and 8 deletions.
Binary file added images/tree_aggregation.png
254 changes: 246 additions & 8 deletions src/protocol/prf_sharding/mod.rs
@@ -1,25 +1,32 @@
use std::iter::{repeat, zip};

use futures_util::future::try_join;
use embed_doc_image::embed_doc_image;
use futures::{stream::iter as stream_iter, TryStreamExt};
use futures_util::{future::try_join, StreamExt};
use ipa_macros::Step;

use super::{basics::if_else, boolean::saturating_sum::SaturatingSum, step::BitOpStep};
use super::{
basics::if_else, boolean::saturating_sum::SaturatingSum, modulus_conversion::convert_bits,
step::BitOpStep,
};
use crate::{
error::Error,
ff::{Field, GaloisField, Gf2},
ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
protocol::{
basics::{SecureMul, ShareKnownValue},
boolean::or::or,
context::{UpgradableContext, UpgradedContext, Validator},
RecordId,
},
secret_sharing::{
replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing},
BitDecomposed, SharedValue,
replicated::{
malicious::ExtendableField, semi_honest::AdditiveShare as Replicated,
ReplicatedSecretSharing,
},
BitDecomposed, Linear as LinearSecretSharing, SharedValue,
},
seq_join::seq_try_join_all,
seq_join::{seq_join, seq_try_join_all},
};

pub struct PrfShardedIpaInputRow<BK: GaloisField, TV: GaloisField> {
prf_of_match_key: u64,
is_trigger_bit: Replicated<Gf2>,
@@ -182,6 +189,18 @@ impl From<usize> for UserNthRowStep {
}
}

#[derive(Step)]
pub enum BinaryTreeDepthStep {
#[dynamic]
Depth(usize),
}

impl From<usize> for BinaryTreeDepthStep {
fn from(v: usize) -> Self {
Self::Depth(v)
}
}

#[derive(Step)]
pub(crate) enum Step {
BinaryValidator,
@@ -194,6 +213,9 @@ pub(crate) enum Step {
ComputeDifferenceToCap,
ComputedCappedAttributedTriggerValueNotSaturatedCase,
ComputedCappedAttributedTriggerValueJustSaturatedCase,
ModulusConvertBreakdownKeyBits,
ModulusConvertConversionValueBits,
MoveValueToCorrectBreakdown,
}

fn compute_histogram_of_users_with_row_count<S>(rows_chunked_by_user: &[Vec<S>]) -> Vec<usize> {
@@ -518,11 +540,176 @@ where
))
}

/// This circuit expects to receive records from multiple users,
/// but with all of the records from a given user adjacent to one another, and in time order.
///
/// This is a wrapper function that performs attribution and capping per user, followed by
/// aggregation of the results per breakdown key.
/// # Errors
/// If there is an issue in multiplication, it will error
pub async fn attribution_and_capping_and_aggregation<C, BK, TV, F, S, SB>(
sh_ctx: C,
input_rows: Vec<PrfShardedIpaInputRow<BK, TV>>,
num_saturating_sum_bits: usize,
) -> Result<Vec<S>, Error>
where
C: UpgradableContext,
C::UpgradedContext<F>: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + Serializable + SecureMul<C::UpgradedContext<F>>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
F: PrimeField + ExtendableField,
TV: GaloisField,
BK: GaloisField,
{
let prime_field_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<F>();
let prime_field_m_ctx = prime_field_validator.context();

let user_level_attributions: Vec<CappedAttributionOutputs> =
attribution_and_capping(sh_ctx, input_rows, num_saturating_sum_bits).await?;

do_aggregation::<_, BK, TV, F, S>(prime_field_m_ctx, user_level_attributions).await
}

/// Sub-protocol of the PRF-sharded IPA Protocol
///
/// This function receives capped, user-level contributions to breakdown key buckets. It does the following:
/// 1. Convert bit-shares of breakdown keys and conversion values from the binary field to the prime field
/// 2. Transform conversion value bits to an additive sharing
/// 3. Move all conversion values to their corresponding breakdown key bucket
///
/// At the end of the function, all conversions are aggregated and placed into the appropriate breakdown key buckets
async fn do_aggregation<C, BK, TV, F, S>(
ctx: C,
user_level_attributions: Vec<CappedAttributionOutputs>,
) -> Result<Vec<S>, Error>
where
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + Serializable + SecureMul<C>,
BK: GaloisField,
TV: GaloisField,
F: PrimeField + ExtendableField,
{
let num_records = user_level_attributions.len();
let (bk_vec, tv_vec): (Vec<_>, Vec<_>) = user_level_attributions
.into_iter()
.map(|row| {
(
row.attributed_breakdown_key_bits,
row.capped_attributed_trigger_value,
)
})
.unzip();

// modulus convert breakdown keys
let converted_bks = convert_bits(
ctx.narrow(&Step::ModulusConvertBreakdownKeyBits)
.set_total_records(num_records),
stream_iter(bk_vec),
0..BK::BITS,
);
// modulus convert attributed value
let converted_values = convert_bits(
ctx.narrow(&Step::ModulusConvertConversionValueBits)
.set_total_records(num_records),
stream_iter(tv_vec),
0..TV::BITS,
);

// transform value bits to the large field
let large_field_values = converted_values
.map(|val| BitDecomposed::to_additive_sharing_in_large_field_consuming(val.unwrap()));

// move each value to the correct bucket
let row_contributions_stream = converted_bks
.zip(large_field_values)
.zip(futures::stream::repeat(
ctx.narrow(&Step::MoveValueToCorrectBreakdown)
.set_total_records(num_records),
))
.enumerate()
.map(|(i, ((bk_bits, value), ctx))| {
let record_id: RecordId = RecordId::from(i);
let bd_key = bk_bits.unwrap();
async move {
move_single_value_to_bucket::<BK, _, _, _>(ctx, record_id, bd_key, value).await
}
});

// aggregate all row level contributions
let row_contributions = seq_join(ctx.active_work(), row_contributions_stream);
row_contributions
.try_fold(
vec![S::ZERO; 1 << BK::BITS],
|mut running_sums, row_contribution| async move {
for (i, contribution) in row_contribution.iter().enumerate() {
running_sums[i] += contribution;
}
Ok(running_sums)
},
)
.await
}
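// ---------------------------------------------------------------------------
// Editorial sketch, not part of this commit: the same flow as `do_aggregation`
// above, but over cleartext bits instead of secret shares, so the modulus
// conversion and the MPC multiplications disappear. All names here are
// hypothetical; bit vectors are assumed to be little-endian.
// ---------------------------------------------------------------------------
fn plaintext_aggregation(rows: &[(Vec<bool>, Vec<bool>)], bk_bits: usize) -> Vec<u64> {
    // Recompose a little-endian bit vector into an integer: the cleartext
    // analogue of `to_additive_sharing_in_large_field_consuming`.
    fn recompose(bits: &[bool]) -> u64 {
        bits.iter()
            .enumerate()
            .fold(0, |acc, (i, &b)| acc + (u64::from(b) << i))
    }
    // One bucket per possible breakdown key, just like `vec![S::ZERO; 1 << BK::BITS]`.
    let mut buckets = vec![0u64; 1 << bk_bits];
    for (bk_bits_le, value_bits_le) in rows {
        // `move_single_value_to_bucket` routes the value to its bucket obliviously;
        // with cleartext values we can simply index into the bucket vector.
        buckets[recompose(bk_bits_le) as usize] += recompose(value_bits_le);
    }
    buckets
}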

#[embed_doc_image("tree-aggregation", "images/tree_aggregation.png")]
/// This function moves a single value to the correct bucket using a tree aggregation approach
///
/// Here is how it works:
/// The combined value, [`value`], forms the root of a binary tree as follows:
/// ![Tree propagation][tree-aggregation]
///
/// This value is propagated through the tree, with each subsequent iteration doubling the number of multiplications.
/// In the first round, r=BK-1, multiply the most significant bit, [`bd_key`]_r, by the value to get [`bd_key`]_r.[`value`]. From that,
/// produce [`row_contribution`]_r,0 = [`value`] - [`bd_key`]_r.[`value`] and [`row_contribution`]_r,1 = [`bd_key`]_r.[`value`].
/// This takes the most significant bit of `bd_key` and places the value in one of the two child nodes of the binary tree.
/// At each successive round, the values in the current leaf nodes are propagated into further leaf nodes, selected by the next most significant bit of `bd_key`:
/// [`row_contribution`]_(r-1),q,0 = [`row_contribution`]_r,q - [`bd_key`]_(r-1).[`row_contribution`]_r,q and [`row_contribution`]_(r-1),q,1 = [`bd_key`]_(r-1).[`row_contribution`]_r,q.
/// The work of each iteration therefore doubles relative to the one preceding.
async fn move_single_value_to_bucket<BK, C, S, F>(
ctx: C,
record_id: RecordId,
bd_key: BitDecomposed<S>,
value: S,
) -> Result<Vec<S>, Error>
where
BK: GaloisField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + Serializable + SecureMul<C>,
F: PrimeField + ExtendableField,
{
let mut step: usize = 1 << BK::BITS;
let mut row_contribution = vec![value; 1 << BK::BITS];

for (tree_depth, bit_of_bdkey) in bd_key.iter().rev().enumerate() {
let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth));
let span = step >> 1;
let mut futures = Vec::with_capacity((1 << BK::BITS) / step);
for i in (0..1 << BK::BITS).step_by(step) {
let bit_c = depth_c.narrow(&BitOpStep::from(i));

if i + span < 1 << BK::BITS {
futures.push(row_contribution[i].multiply(bit_of_bdkey, bit_c, record_id));
}
}
let contributions = ctx.parallel_join(futures).await?;

for (index, bdbit_contribution) in contributions.into_iter().enumerate() {
let left_index = index * step;
let right_index = left_index + span;

row_contribution[left_index] -= &bdbit_contribution;
row_contribution[right_index] = bdbit_contribution;
}
step = span;
}
Ok(row_contribution)
}
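// ---------------------------------------------------------------------------
// Editorial sketch, not part of this commit: the same tree routing as
// `move_single_value_to_bucket` above, but with a cleartext breakdown key, so
// each secure multiplication becomes a plain select. After all bits are
// consumed, exactly one of the 2^BK slots holds `value` and the rest are zero;
// e.g. with bd_key_bits_le = [false, true] (the key 2), the value lands in slot 2.
// ---------------------------------------------------------------------------
fn plaintext_move_value_to_bucket(bd_key_bits_le: &[bool], value: u64) -> Vec<u64> {
    let bk_bits = bd_key_bits_le.len();
    let mut step = 1usize << bk_bits;
    let mut row_contribution = vec![value; 1 << bk_bits];

    // Walk the breakdown-key bits from most significant to least significant.
    for &bit in bd_key_bits_le.iter().rev() {
        let span = step >> 1;
        for i in (0..1 << bk_bits).step_by(step) {
            // Always true when the key has exactly bk_bits bits; kept to mirror
            // the guard in the protocol code above.
            if i + span < (1 << bk_bits) {
                // Cleartext stand-in for `row_contribution[i].multiply(bit_of_bdkey, ...)`.
                let moved = if bit { row_contribution[i] } else { 0 };
                row_contribution[i] -= moved;       // left child keeps the value when bit == 0
                row_contribution[i + span] = moved; // right child receives it when bit == 1
            }
        }
        step = span;
    }
    row_contribution
}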

#[cfg(all(test, unit_test))]
pub mod tests {
use super::{attribution_and_capping, CappedAttributionOutputs, PrfShardedIpaInputRow};
use crate::{
ff::{Field, GaloisField, Gf2, Gf3Bit, Gf5Bit},
ff::{Field, Fp32BitPrime, GaloisField, Gf2, Gf3Bit, Gf5Bit},
protocol::prf_sharding::attribution_and_capping_and_aggregation,
rand::Rng,
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, IntoShares,
@@ -702,4 +889,55 @@ pub mod tests {
assert_eq!(result, &expected);
});
}

#[test]
fn semi_honest_aggregation_capping_attribution() {
run(|| async move {
let world = TestWorld::default();

let records: Vec<PreShardedAndSortedOPRFTestInput<Gf5Bit, Gf3Bit>> = vec![
/* First User */
test_input(123, false, 17, 0),
test_input(123, true, 0, 7),
test_input(123, false, 20, 0),
test_input(123, true, 0, 3),
/* Second User */
test_input(234, false, 12, 0),
test_input(234, true, 0, 5),
/* Third User */
test_input(345, false, 20, 0),
test_input(345, true, 0, 7),
test_input(345, false, 18, 0),
test_input(345, false, 12, 0),
test_input(345, true, 0, 7),
test_input(345, true, 0, 7),
test_input(345, true, 0, 7),
test_input(345, true, 0, 7),
];

let mut expected = [0_u128; 32];
expected[12] = 30;
expected[17] = 7;
expected[20] = 10;
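// Editorial note, not in the original test: these expected values appear to follow
// from a per-user cap of 2^5 = 32 (five saturating-sum bits, set just below). The
// third user attributes 7 + 4*7 = 35, so the final trigger value is reduced from 7
// to 4: breakdown key 20 keeps 7 and breakdown key 12 receives 7 + 7 + 7 + 4 = 25.
// Combined with the first two users: expected[12] = 5 + 25, expected[17] = 7,
// expected[20] = 3 + 7.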

let num_saturating_bits: usize = 5;

let result: Vec<_> = world
.semi_honest(records.into_iter(), |ctx, input_rows| async move {
attribution_and_capping_and_aggregation::<
_,
Gf5Bit,
Gf3Bit,
Fp32BitPrime,
_,
Replicated<Gf2>,
>(ctx, input_rows, num_saturating_bits)
.await
.unwrap()
})
.await
.reconstruct();
assert_eq!(result, &expected);
});
}
}
12 changes: 12 additions & 0 deletions src/secret_sharing/decomposed.rs
@@ -67,6 +67,18 @@ impl<S> BitDecomposed<S> {
acc + (b * F::truncate_from(1_u128 << i))
})
}

// Same as above, but without the need for HRTBs, since this doesn't use references
// but rather takes ownership of the BitDecomposed
pub fn to_additive_sharing_in_large_field_consuming<F>(bits: BitDecomposed<S>) -> S
where
S: LinearSecretSharing<F>,
F: PrimeField,
{
bits.into_iter().enumerate().fold(S::ZERO, |acc, (i, b)| {
acc + (b * F::truncate_from(1_u128 << i))
})
}
}
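// Editorial illustration, not part of this commit: the recomposition both helpers
// above perform, written over plain integers. `bits` is little-endian, so bit i
// contributes bit_i * 2^i; e.g. [1, 0, 1, 1] recomposes to 1 + 4 + 8 = 13.
fn recompose_plaintext_bits(bits: &[u8]) -> u128 {
    bits.iter()
        .enumerate()
        .fold(0u128, |acc, (i, &b)| acc + (u128::from(b) << i))
}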

impl<S> TryFrom<Vec<S>> for BitDecomposed<S> {
