Skip to content

Commit

Permalink
Moved SaturatingSum into a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage committed Sep 23, 2023
1 parent 34c32d8 commit 636e193
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 94 deletions.
1 change: 1 addition & 0 deletions src/protocol/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod comparison;
pub mod generate_random_bits;
pub mod or;
pub mod random_bits_generator;
pub mod saturating_sum;
pub mod solved_bits;
mod xor;

Expand Down
171 changes: 171 additions & 0 deletions src/protocol/boolean/saturating_sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
use crate::{
error::Error,
ff::Gf2,
protocol::{context::Context, step::BitOpStep, BasicProtocols, RecordId},
secret_sharing::{BitDecomposed, Linear as LinearSecretSharing},
};

#[derive(Debug)]
pub struct SaturatingSum<S: LinearSecretSharing<Gf2>> {
pub sum: BitDecomposed<S>,
pub is_saturated: S,
}

impl<S: LinearSecretSharing<Gf2>> SaturatingSum<S> {
pub fn new(value: BitDecomposed<S>, is_saturated: S) -> SaturatingSum<S> {
SaturatingSum {
sum: value,
is_saturated,
}
}

pub async fn add<C>(
&self,
ctx: C,
record_id: RecordId,
value: &BitDecomposed<S>,
) -> Result<SaturatingSum<S>, Error>
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
{
assert!(self.sum.len() >= value.len());

let mut output_sum = Vec::with_capacity(self.sum.len());
let mut carry_in = S::ZERO;
for i in 0..self.sum.len() {
let c = ctx.narrow(&BitOpStep::from(i));
let (sum_bit, carry_out) = if i < value.len() {
one_bit_adder(c, record_id, &value[i], &self.sum[i], &carry_in).await?
} else {
one_bit_adder(c, record_id, &S::ZERO, &self.sum[i], &carry_in).await?
};

output_sum.push(sum_bit);
carry_in = carry_out;
}
let is_saturated = -carry_in
.clone()
.multiply(
&self.is_saturated,
ctx.narrow(&BitOpStep::from(self.sum.len())),
record_id,
)
.await?
+ &carry_in
+ &self.is_saturated;

Ok(SaturatingSum::new(
BitDecomposed::new(output_sum),
is_saturated,
))
}
}

///
/// Returns (`sum_bit`, `carry_out`)
///
async fn one_bit_adder<C, SB>(
ctx: C,
record_id: RecordId,
x: &SB,
y: &SB,
carry_in: &SB,
) -> Result<(SB, SB), Error>
where
C: Context,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
{
// compute sum bit as x XOR y XOR carry_in
let sum_bit = x.clone() + y + carry_in;

let x_xor_carry_in = x.clone() + carry_in;
let y_xor_carry_in = y.clone() + carry_in;
let carry_out = x_xor_carry_in
.multiply(&y_xor_carry_in, ctx, record_id)
.await?
+ carry_in;

Ok((sum_bit, carry_out))
}

#[cfg(all(test, unit_test))]
mod tests {
use super::SaturatingSum;
use crate::{
ff::{Field, Gf2},
protocol::{context::Context, RecordId},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, SharedValue,
},
test_fixture::{get_bits, Reconstruct, Runner, TestWorld},
};

impl Reconstruct<u128> for [SaturatingSum<Replicated<Gf2>>; 3] {
fn reconstruct(&self) -> u128 {
let [s0, s1, s2] = self;

let sum_bits: BitDecomposed<Gf2> = BitDecomposed::new(
s0.sum
.iter()
.zip(s1.sum.iter())
.zip(s2.sum.iter())
.map(|((a, b), c)| [a, b, c].reconstruct()),
);

let is_saturated = [&s0.is_saturated, &s1.is_saturated, &s2.is_saturated].reconstruct();

if is_saturated == Gf2::ZERO {
sum_bits
.iter()
.map(Field::as_u128)
.enumerate()
.fold(0_u128, |acc, (i, x)| acc + (x << i))
} else {
2_u128.pow(s0.sum.len() as u32)
}
}
}

#[tokio::test]
pub async fn simple() {
assert_eq!(2, saturating_add(1, 2, 1, 2).await);
assert_eq!(3, saturating_add(2, 2, 1, 2).await);
assert_eq!(4, saturating_add(3, 2, 1, 2).await);
assert_eq!(4, saturating_add(3, 2, 2, 2).await);
assert_eq!(4, saturating_add(3, 2, 3, 2).await);
assert_eq!(6, saturating_add(3, 5, 3, 3).await);
assert_eq!(6, saturating_add(3, 5, 3, 5).await);
assert_eq!(14, saturating_add(7, 5, 7, 3).await);
assert_eq!(14, saturating_add(7, 5, 7, 5).await);
assert_eq!(31, saturating_add(26, 5, 5, 3).await);
assert_eq!(32, saturating_add(26, 5, 6, 3).await);
assert_eq!(32, saturating_add(26, 5, 7, 3).await);
assert_eq!(32, saturating_add(31, 5, 7, 3).await);
assert_eq!(63, saturating_add(60, 6, 3, 3).await);
assert_eq!(64, saturating_add(60, 6, 4, 3).await);
assert_eq!(64, saturating_add(60, 6, 5, 3).await);
}

async fn saturating_add(a: u32, num_a_bits: u32, b: u32, num_b_bits: u32) -> u128 {
let world = TestWorld::default();

let a_bits = get_bits::<Gf2>(a, num_a_bits);
//let a_saturated = Gf2::ZERO;
let b_bits = get_bits::<Gf2>(b, num_b_bits);

let foo = world
.semi_honest(
(a_bits, b_bits),
|ctx, (a_bits, b_bits): (BitDecomposed<_>, BitDecomposed<_>)| async move {
let a = SaturatingSum::new(a_bits, Replicated::ZERO);
a.add(ctx.set_total_records(1), RecordId(0), &b_bits)
.await
.unwrap()
},
)
.await;

foo.reconstruct()
}
}
121 changes: 27 additions & 94 deletions src/protocol/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use futures_util::future::try_join_all;
use ipa_macros::step;
use strum::AsRefStr;

use super::step::BitOpStep;
use super::{boolean::saturating_sum::SaturatingSum, step::BitOpStep};
use crate::{
error::Error,
ff::{Field, GaloisField, Gf2},
Expand All @@ -30,8 +30,7 @@ pub struct PrfShardedIpaInputRow<BK: GaloisField, TV: GaloisField> {
struct InputsRequiredFromPrevRow {
ever_encountered_a_source_event: Replicated<Gf2>,
attributed_breakdown_key_bits: BitDecomposed<Replicated<Gf2>>,
saturating_sum: BitDecomposed<Replicated<Gf2>>,
is_saturated: Replicated<Gf2>,
saturating_sum: SaturatingSum<Replicated<Gf2>>,
difference_to_cap: BitDecomposed<Replicated<Gf2>>,
}

Expand Down Expand Up @@ -275,83 +274,16 @@ where
Gf2::truncate_from(input_row.breakdown_key.right()[i]),
)
}),
saturating_sum: BitDecomposed::new(vec![Replicated::ZERO; num_saturating_sum_bits]),
is_saturated: Replicated::ZERO,
saturating_sum: SaturatingSum::new(
BitDecomposed::new(vec![Replicated::ZERO; num_saturating_sum_bits]),
Replicated::ZERO,
),
// This is incorrect in the case that the CAP is less than the maximum value of "trigger value" for a single row
// Not a problem if you assume that's an invalid input
difference_to_cap: BitDecomposed::new(vec![Replicated::ZERO; num_trigger_value_bits]),
}
}

///
/// Returns (`sum_bit`, `carry_out`)
///
async fn one_bit_adder<C, SB>(
ctx: C,
record_id: RecordId,
x: &SB,
y: &SB,
carry_in: &SB,
) -> Result<(SB, SB), Error>
where
C: UpgradedContext<Gf2, Share = SB>,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
{
// compute sum bit as x XOR y XOR carry_in
let sum_bit = x.clone() + y + carry_in;

let x_xor_carry_in = x.clone() + carry_in;
let y_xor_carry_in = y.clone() + carry_in;
let carry_out = x_xor_carry_in
.multiply(&y_xor_carry_in, ctx, record_id)
.await?
+ carry_in;

Ok((sum_bit, carry_out))
}

async fn compute_saturating_sum<C, SB>(
ctx: C,
record_id: RecordId,
cur_value: &BitDecomposed<SB>,
prev_sum: &BitDecomposed<SB>,
prev_is_saturated: &SB,
num_trigger_value_bits: usize,
num_saturating_sum_bits: usize,
) -> Result<(BitDecomposed<SB>, SB), Error>
where
C: UpgradedContext<Gf2, Share = SB>,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
{
assert!(cur_value.len() == num_trigger_value_bits);
assert!(prev_sum.len() == num_saturating_sum_bits);

let mut carry_in = SB::ZERO;
let mut output = vec![];
for i in 0..num_saturating_sum_bits {
let c = ctx.narrow(&BitOpStep::from(i));
let (sum_bit, carry_out) = if i < num_trigger_value_bits {
one_bit_adder(c, record_id, &cur_value[i], &prev_sum[i], &carry_in).await?
} else {
one_bit_adder(c, record_id, &SB::ZERO, &prev_sum[i], &carry_in).await?
};

output.push(sum_bit);
carry_in = carry_out;
}
let updated_is_saturated = -carry_in
.clone()
.multiply(
prev_is_saturated,
ctx.narrow(&BitOpStep::from(num_saturating_sum_bits)),
record_id,
)
.await?
+ &carry_in
+ prev_is_saturated;
Ok((BitDecomposed::new(output), updated_is_saturated))
}

///
/// Returns (`difference_bit`, `carry_out`)
///
Expand Down Expand Up @@ -389,19 +321,19 @@ where
async fn compute_truncated_difference_to_cap<C, SB>(
ctx: C,
record_id: RecordId,
cur_sum: &BitDecomposed<SB>,
cur_sum: &SaturatingSum<SB>,
num_trigger_value_bits: usize,
num_saturating_sum_bits: usize,
) -> Result<BitDecomposed<SB>, Error>
where
C: UpgradedContext<Gf2, Share = SB>,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
{
assert!(cur_sum.len() == num_saturating_sum_bits);
assert!(cur_sum.sum.len() == num_saturating_sum_bits);

let mut carry_in = SB::share_known_value(&ctx, Gf2::ONE);
let mut output = vec![];
for (i, bit) in cur_sum.iter().enumerate().take(num_trigger_value_bits) {
for (i, bit) in cur_sum.sum.iter().enumerate().take(num_trigger_value_bits) {
let c = ctx.narrow(&BitOpStep::from(i));
let (difference_bit, carry_out) =
one_bit_subtractor(c, record_id, &SB::ZERO, bit, &carry_in).await?;
Expand Down Expand Up @@ -447,7 +379,7 @@ where
== num_breakdown_key_bits
);
assert!(tv.len() == num_trigger_value_bits);
assert!(inputs_required_from_previous_row.saturating_sum.len() == num_saturating_sum_bits);
assert!(inputs_required_from_previous_row.saturating_sum.sum.len() == num_saturating_sum_bits);

let share_of_one = Replicated::share_known_value(&ctx, Gf2::ONE);

Expand Down Expand Up @@ -519,21 +451,23 @@ where
.await?,
);

let (saturating_sum, is_saturated) = compute_saturating_sum(
ctx.narrow(&Step::ComputeSaturatingSum),
record_id,
&attributed_trigger_value,
&inputs_required_from_previous_row.saturating_sum,
&inputs_required_from_previous_row.is_saturated,
num_trigger_value_bits,
num_saturating_sum_bits,
)
.await?;
let updated_sum = inputs_required_from_previous_row
.saturating_sum
.add(
ctx.narrow(&Step::ComputeSaturatingSum),
record_id,
&attributed_trigger_value,
)
.await?;

// TODO: compute is_saturated_and_prev_row_not_saturated and difference_to_cap in parallel
let is_saturated_and_prev_row_not_saturated = is_saturated
let is_saturated_and_prev_row_not_saturated = updated_sum
.is_saturated
.multiply(
&(share_of_one - &inputs_required_from_previous_row.is_saturated),
&(share_of_one
- &inputs_required_from_previous_row
.saturating_sum
.is_saturated),
ctx.narrow(&Step::IsSaturatedAndPrevRowNotSaturated),
record_id,
)
Expand All @@ -543,7 +477,7 @@ where
compute_truncated_difference_to_cap(
ctx.narrow(&Step::ComputeDifferenceToCap),
record_id,
&saturating_sum,
&updated_sum,
num_trigger_value_bits,
num_saturating_sum_bits,
)
Expand All @@ -557,7 +491,7 @@ where
attributed_trigger_value
.iter()
.zip(inputs_required_from_previous_row.difference_to_cap.iter())
.zip(repeat(is_saturated.clone()))
.zip(repeat(updated_sum.is_saturated.clone()))
.zip(repeat(is_saturated_and_prev_row_not_saturated))
.enumerate()
.map(
Expand Down Expand Up @@ -589,8 +523,7 @@ where
let inputs_required_for_next_row = InputsRequiredFromPrevRow {
ever_encountered_a_source_event,
attributed_breakdown_key_bits: BitDecomposed::new(attributed_breakdown_key_bits.clone()),
saturating_sum,
is_saturated,
saturating_sum: updated_sum,
difference_to_cap,
};
let outputs_for_aggregation = CappedAttributionOutputs {
Expand Down

0 comments on commit 636e193

Please sign in to comment.