Skip to content

Commit

Permalink
Merge pull request #927 from private-attribution/backport_better_aggr…
Browse files Browse the repository at this point in the history
…egation

Backport of improved aggregation logic to old IPA
  • Loading branch information
benjaminsavage authored Jan 19, 2024
2 parents d450bb4 + f5989ad commit 3b9eb5c
Show file tree
Hide file tree
Showing 5 changed files with 316 additions and 364 deletions.
58 changes: 18 additions & 40 deletions ipa-core/src/protocol/attribution/aggregate_credit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use crate::{
ff::{Gf2, PrimeField, Serializable},
protocol::{
context::{UpgradableContext, UpgradedContext, Validator},
ipa_prf::prf_sharding::bucket::move_single_value_to_bucket,
modulus_conversion::convert_bits,
sort::{bitwise_to_onehot, generate_permutation::ShuffledPermutationWrapper},
step::BitOpStep,
sort::generate_permutation::ShuffledPermutationWrapper,
BasicProtocols, RecordId,
},
secret_sharing::{
Expand All @@ -23,11 +23,6 @@ use crate::{
seq_join::seq_join,
};

/// This is the number of breakdown keys above which it is more efficient to SORT by breakdown key.
/// Below this number, it's more efficient to just do a ton of equality checks.
/// This number was determined empirically on 27 Feb 2023
const SIMPLE_AGGREGATION_BREAK_EVEN_POINT: u32 = 32;

/// Aggregation step for Oblivious Attribution protocol.
/// # Panics
/// It probably won't
Expand Down Expand Up @@ -56,16 +51,9 @@ where
ShuffledPermutationWrapper<S, C::UpgradedContext<F>>: DowngradeMalicious<Target = Vec<u32>>,
{
let m_ctx = validator.context();

if max_breakdown_key <= SIMPLE_AGGREGATION_BREAK_EVEN_POINT {
let res = simple_aggregate_credit(m_ctx, breakdown_keys, capped_credits, max_breakdown_key)
.await?;
Ok((validator, res))
} else {
Err(Error::Unsupported(
format!("query uses {max_breakdown_key} breakdown keys; only {SIMPLE_AGGREGATION_BREAK_EVEN_POINT} are supported")
))
}
let res =
simple_aggregate_credit(m_ctx, breakdown_keys, capped_credits, max_breakdown_key).await?;
Ok((validator, res))
}

async fn simple_aggregate_credit<F, C, IC, IB, S>(
Expand All @@ -84,17 +72,10 @@ where
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
{
let record_count = breakdown_keys.len();
// The number of records we compute is currently too high as the last row cannot have
// any credit associated with it. TODO: don't compute that row when cap > 1.

let to_take = usize::try_from(max_breakdown_key).unwrap();
let valid_bits_count = u32::BITS - (max_breakdown_key - 1).leading_zeros();

let equality_check_context = ctx
.narrow(&Step::ComputeEqualityChecks)
.set_total_records(record_count);
let check_times_credit_context = ctx
.narrow(&Step::CheckTimesCredit)
let move_value_to_bucket_context = ctx
.narrow(&Step::MoveValueToBucket)
.set_total_records(record_count);

let converted_bk = convert_bits(
Expand All @@ -110,23 +91,21 @@ where
.zip(stream_iter(capped_credits))
.enumerate()
.map(|(i, (bk, cred))| {
let ceq = &equality_check_context;
let cmul = &check_times_credit_context;
let ctx = move_value_to_bucket_context.clone();
async move {
let equality_checks = bitwise_to_onehot(ceq.clone(), i, &bk?).await?;
ceq.try_join(equality_checks.into_iter().take(to_take).enumerate().map(
|(check_idx, check)| {
let step = BitOpStep::from(check_idx);
let c = cmul.narrow(&step);
let record_id = RecordId::from(i);
let credit = &cred;
async move { check.multiply(credit, c, record_id).await }
},
))
move_single_value_to_bucket(
ctx,
RecordId::from(i),
bk.unwrap(),
cred,
usize::try_from(max_breakdown_key).unwrap(),
true,
)
.await
}
}),
);

let aggregate = increments
.try_fold(
vec![S::ZERO; max_breakdown_key as usize],
Expand All @@ -143,8 +122,7 @@ where

#[derive(Step)]
pub(crate) enum Step {
ComputeEqualityChecks,
CheckTimesCredit,
MoveValueToBucket,
ModConvBreakdownKeyBits,
}

Expand Down
24 changes: 12 additions & 12 deletions ipa-core/src/protocol/ipa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1125,9 +1125,9 @@ pub mod tests {
cap_one(),
SemiHonest,
PerfMetrics {
records_sent: 14_421,
bytes_sent: 47_100,
indexed_prss: 19_137,
records_sent: 14_397,
bytes_sent: 47_004,
indexed_prss: 19_113,
seq_prss: 1118,
},
)
Expand All @@ -1140,9 +1140,9 @@ pub mod tests {
cap_three(),
SemiHonest,
PerfMetrics {
records_sent: 21_756,
bytes_sent: 76_440,
indexed_prss: 28_146,
records_sent: 21_732,
bytes_sent: 76_344,
indexed_prss: 28_122,
seq_prss: 1118,
},
)
Expand All @@ -1155,9 +1155,9 @@ pub mod tests {
cap_one(),
Malicious,
PerfMetrics {
records_sent: 35_163,
bytes_sent: 130_068,
indexed_prss: 72_447,
records_sent: 35_115,
bytes_sent: 129_876,
indexed_prss: 72_375,
seq_prss: 1132,
},
)
Expand All @@ -1170,9 +1170,9 @@ pub mod tests {
cap_three(),
Malicious,
PerfMetrics {
records_sent: 53_865,
bytes_sent: 204_876,
indexed_prss: 109_734,
records_sent: 53_817,
bytes_sent: 204_684,
indexed_prss: 109_662,
seq_prss: 1132,
},
)
Expand Down
Loading

0 comments on commit 3b9eb5c

Please sign in to comment.