Skip to content

Commit

Permalink
no more stalling
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage committed Sep 22, 2023
1 parent 261872e commit 34c32d8
Showing 1 changed file with 192 additions and 76 deletions.
268 changes: 192 additions & 76 deletions src/protocol/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use ipa_macros::step;
use std::iter::repeat;
use strum::AsRefStr;

use futures_util::future::try_join_all;
use metrics_util::registry::Storage;
use ipa_macros::step;
use strum::AsRefStr;

use super::step::BitOpStep;
use crate::{
error::Error,
ff::{Field, GaloisField, Gf2},
protocol::{
basics::{SecureMul, ShareKnownValue},
context::{Context, UpgradableContext, UpgradedContext, Validator},
context::{UpgradableContext, UpgradedContext, Validator},
BasicProtocols, RecordId,
},
repeat64str,
Expand Down Expand Up @@ -74,32 +73,17 @@ pub(crate) enum Step {
ComputedCappedAttributedTriggerValueJustSaturatedCase,
}

fn compute_histogram_of_users_with_row_count<BK, TV>(
input_rows: &[PrfShardedIpaInputRow<BK, TV>],
) -> Vec<usize>
where
BK: GaloisField,
TV: GaloisField,
{
let (_, _, hist) = input_rows.iter().fold(
(0, 0, vec![]),
|(last_prf, rows_for_user, mut histogram), input_row| {
if last_prf == input_row.prf_of_match_key {
if rows_for_user >= histogram.len() {
histogram.push(0);
}
histogram[rows_for_user] += 1;
(input_row.prf_of_match_key, rows_for_user + 1, histogram)
} else {
if histogram.is_empty() {
histogram.push(0);
}
histogram[0] += 1;
(input_row.prf_of_match_key, 1, histogram)
fn compute_histogram_of_users_with_row_count<S>(rows_chunked_by_user: &[Vec<S>]) -> Vec<usize> {
let mut output = vec![];
for user_rows in rows_chunked_by_user {
for j in 0..user_rows.len() {
if j >= output.len() {
output.push(0);
}
},
);
hist
output[j] += 1;
}
}
output
}

fn set_up_contexts<C>(root_ctx: C, histogram: Vec<usize>) -> Vec<C>
Expand All @@ -120,6 +104,35 @@ where
context_per_row_depth
}

fn chunk_rows_by_user<BK, TV>(
input_rows: Vec<PrfShardedIpaInputRow<BK, TV>>,
) -> Vec<Vec<PrfShardedIpaInputRow<BK, TV>>>
where
BK: GaloisField,
TV: GaloisField,
{
let mut rows_for_user = vec![];

let mut rows_chunked_by_user = vec![];
for row in input_rows {
if rows_for_user.is_empty() {
rows_for_user.push(row);
} else {
if row.prf_of_match_key == rows_for_user[0].prf_of_match_key {
rows_for_user.push(row);
} else {
rows_chunked_by_user.push(rows_for_user);
rows_for_user = vec![row];
}
}
}
if !rows_for_user.is_empty() {
rows_chunked_by_user.push(rows_for_user);
}

rows_chunked_by_user
}

/// Sub-protocol of the PRF-sharded IPA Protocol
///
/// After the computation of the per-user PRF, addition of dummy records and shuffling,
Expand All @@ -139,7 +152,7 @@ where
/// Propagates errors from multiplications
pub async fn attribution_and_capping<C, BK, TV>(
sh_ctx: C,
input_rows: &[PrfShardedIpaInputRow<BK, TV>],
input_rows: Vec<PrfShardedIpaInputRow<BK, TV>>,
num_breakdown_key_bits: usize,
num_trigger_value_bits: usize,
num_saturating_sum_bits: usize,
Expand All @@ -154,66 +167,90 @@ where
assert!(num_trigger_value_bits > 0);
assert!(num_breakdown_key_bits > 0);

let rows_chunked_by_user = chunk_rows_by_user(input_rows);
let histogram = compute_histogram_of_users_with_row_count(&rows_chunked_by_user);

let binary_validator = sh_ctx.narrow(&Step::BinaryValidator).validator::<Gf2>();
// TODO: fix num total records to be not a hard-coded constant, but variable per step
// based on the histogram of how many users have how many records a piece
let binary_m_ctx = binary_validator.context();

let histogram = compute_histogram_of_users_with_row_count(input_rows);
let ctx_for_row_number = set_up_contexts(binary_m_ctx.clone(), histogram);

let mut output = vec![];
let mut futures = Vec::with_capacity(rows_chunked_by_user.len());
let mut num_users_who_encountered_row_depth = vec![];
for rows_for_user in rows_chunked_by_user {
for i in 0..rows_for_user.len() {
if i >= num_users_who_encountered_row_depth.len() {
num_users_who_encountered_row_depth.push(0);
}
num_users_who_encountered_row_depth[i] += 1;
}

futures.push(evaluate_per_user_attribution_circuit(
&ctx_for_row_number,
num_users_who_encountered_row_depth
.iter()
.take(rows_for_user.len())
.map(|x| RecordId(x - 1))
.collect(),
rows_for_user,
num_breakdown_key_bits,
num_trigger_value_bits,
num_saturating_sum_bits,
));
}
let outputs_chunked_by_user = try_join_all(futures).await?;
Ok(outputs_chunked_by_user
.into_iter()
.flatten()
.collect::<Vec<CappedAttributionOutputs>>())
}

assert!(!input_rows.is_empty());
let first_row = &input_rows[0];
let mut prev_prf = first_row.prf_of_match_key;
async fn evaluate_per_user_attribution_circuit<C, BK, TV>(
ctx_for_row_number: &[C],
record_id_for_each_depth: Vec<RecordId>,
rows_for_user: Vec<PrfShardedIpaInputRow<BK, TV>>,
num_breakdown_key_bits: usize,
num_trigger_value_bits: usize,
num_saturating_sum_bits: usize,
) -> Result<Vec<CappedAttributionOutputs>, Error>
where
C: UpgradedContext<Gf2, Share = Replicated<Gf2>>,
BK: GaloisField,
TV: GaloisField,
{
assert!(!rows_for_user.is_empty());
if rows_for_user.len() == 1 {
return Ok(vec![]);
}
let first_row = &rows_for_user[0];
let mut prev_row_inputs = initialize_new_device_attribution_variables(
Replicated::share_known_value(&binary_m_ctx, Gf2::ONE),
Replicated::share_known_value(&ctx_for_row_number[0], Gf2::ONE),
first_row,
num_breakdown_key_bits,
num_trigger_value_bits,
num_saturating_sum_bits,
);
let mut i: usize = 1;
let mut num_users_who_encountered_row_depth = vec![];
let mut row_for_user = 0;
while i < input_rows.len() {
let cur_row = &input_rows[i];
if prev_prf == cur_row.prf_of_match_key {
if row_for_user >= num_users_who_encountered_row_depth.len() {
num_users_who_encountered_row_depth.push(0);
}

let ctx_for_this_row_depth = ctx_for_row_number[row_for_user].clone();
// Do some actual computation
let (inputs_required_for_next_row, capped_attribution_outputs) =
compute_row_with_previous(
ctx_for_this_row_depth,
RecordId(num_users_who_encountered_row_depth[row_for_user]),
cur_row,
&prev_row_inputs,
num_breakdown_key_bits,
num_trigger_value_bits,
num_saturating_sum_bits,
)
.await?;
output.push(capped_attribution_outputs);
prev_row_inputs = inputs_required_for_next_row;
num_users_who_encountered_row_depth[row_for_user] += 1;
let mut output = Vec::with_capacity(rows_for_user.len() - 1);
for (i, row) in rows_for_user.iter().skip(1).enumerate() {
let ctx_for_this_row_depth = ctx_for_row_number[i].clone(); // no context was created for row 0
let record_id_for_this_row_depth = record_id_for_each_depth[i + 1]; // skip row 0

let (inputs_required_for_next_row, capped_attribution_outputs) = compute_row_with_previous(
ctx_for_this_row_depth,
record_id_for_this_row_depth,
row,
&prev_row_inputs,
num_breakdown_key_bits,
num_trigger_value_bits,
num_saturating_sum_bits,
)
.await?;

row_for_user += 1;
} else {
prev_prf = cur_row.prf_of_match_key;
prev_row_inputs = initialize_new_device_attribution_variables(
Replicated::share_known_value(&binary_m_ctx, Gf2::ONE),
cur_row,
num_breakdown_key_bits,
num_trigger_value_bits,
num_saturating_sum_bits,
);
row_for_user = 0;
}
i += 1;
output.push(capped_attribution_outputs);
prev_row_inputs = inputs_required_for_next_row;
}

Ok(output)
Expand Down Expand Up @@ -686,6 +723,34 @@ pub mod tests {
attributed_breakdown_key: 12,
capped_attributed_trigger_value: 5,
},
PreAggregationTestOutput {
attributed_breakdown_key: 20,
capped_attributed_trigger_value: 7,
},
PreAggregationTestOutput {
attributed_breakdown_key: 18,
capped_attributed_trigger_value: 0,
},
PreAggregationTestOutput {
attributed_breakdown_key: 12,
capped_attributed_trigger_value: 0,
},
PreAggregationTestOutput {
attributed_breakdown_key: 12,
capped_attributed_trigger_value: 7,
},
PreAggregationTestOutput {
attributed_breakdown_key: 12,
capped_attributed_trigger_value: 7,
},
PreAggregationTestOutput {
attributed_breakdown_key: 12,
capped_attributed_trigger_value: 7,
},
PreAggregationTestOutput {
attributed_breakdown_key: 12,
capped_attributed_trigger_value: 4,
},
];
const NUM_BREAKDOWN_KEY_BITS: usize = 5;
const NUM_TRIGGER_VALUE_BITS: usize = 3;
Expand All @@ -695,6 +760,7 @@ pub mod tests {
let world = TestWorld::default();

let records: Vec<PreShardedAndSortedOPRFTestInput<Gf8Bit, Gf8Bit>> = vec![
/* First User */
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 123,
is_trigger_bit: Gf2::ZERO,
Expand All @@ -719,25 +785,75 @@ pub mod tests {
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(3_u8),
},
/* Second User */
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 234,
is_trigger_bit: Gf2::ZERO,
breakdown_key: Gf8Bit::truncate_from(12_u8),
trigger_value: Gf8Bit::truncate_from(3_u8),
trigger_value: Gf8Bit::truncate_from(0_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 234,
is_trigger_bit: Gf2::ONE,
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(5_u8),
},
/* Third User */
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ZERO,
breakdown_key: Gf8Bit::truncate_from(20_u8),
trigger_value: Gf8Bit::truncate_from(0_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ONE,
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(7_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ZERO,
breakdown_key: Gf8Bit::truncate_from(18_u8),
trigger_value: Gf8Bit::truncate_from(0_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ZERO,
breakdown_key: Gf8Bit::truncate_from(12_u8),
trigger_value: Gf8Bit::truncate_from(0_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ONE,
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(7_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ONE,
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(7_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ONE,
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(7_u8),
},
PreShardedAndSortedOPRFTestInput {
prf_of_match_key: 345,
is_trigger_bit: Gf2::ONE,
breakdown_key: Gf8Bit::truncate_from(0_u8),
trigger_value: Gf8Bit::truncate_from(7_u8),
},
];

let result: Vec<_> = world
.semi_honest(records.into_iter(), |ctx, input_rows| async move {
attribution_and_capping(
ctx,
input_rows.as_slice(),
input_rows,
NUM_BREAKDOWN_KEY_BITS,
NUM_TRIGGER_VALUE_BITS,
NUM_SATURATING_SUM_BITS,
Expand Down

0 comments on commit 34c32d8

Please sign in to comment.