Skip to content

Commit

Permalink
Merge pull request #787 from private-attribution/ref-refs
Browse files Browse the repository at this point in the history
Allow addition and multiplication for references to secret shared values
  • Loading branch information
martinthomson authored Sep 29, 2023
2 parents 2b64044 + 16e8768 commit 0599ba6
Show file tree
Hide file tree
Showing 22 changed files with 394 additions and 116 deletions.
50 changes: 8 additions & 42 deletions src/ff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod field;
mod galois_field;
mod prime_field;

use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::ops::{Add, AddAssign, Sub, SubAssign};

pub use field::{Field, FieldType};
pub use galois_field::{GaloisField, Gf2, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit};
Expand All @@ -15,59 +15,25 @@ use generic_array::{ArrayLength, GenericArray};
pub use prime_field::Fp31;
pub use prime_field::{Fp32BitPrime, PrimeField};

use crate::secret_sharing::SharedValue;

#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum Error {
#[error("unknown field type {type_str}")]
UnknownField { type_str: String },
}

pub trait ArithmeticOps:
Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Mul<Output = Self>
+ MulAssign
+ Neg<Output = Self>
+ Sized
{
}

impl<T> ArithmeticOps for T where
T: Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Mul<Output = Self>
+ MulAssign
+ Neg<Output = Self>
+ Sized
/// Addition and subtraction operations that are supported by secret sharings and shared values.
pub trait AddSub<Rhs = Self, Output = Self>:
Add<Rhs, Output = Output> + Sub<Rhs, Output = Output> + Sized
{
}

pub trait ArithmeticRefOps<V: SharedValue>:
for<'a> Add<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ Neg<Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ Mul<V, Output = Self>
impl<T, Rhs, Output> AddSub<Rhs, Output> for T where
T: Add<Rhs, Output = Output> + Sub<Rhs, Output = Output> + Sized
{
}

impl<T, V> ArithmeticRefOps<V> for T
where
T: for<'a> Add<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ Neg<Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ Mul<V, Output = Self>,
V: SharedValue,
{
}
pub trait AddSubAssign<Rhs = Self>: AddAssign<Rhs> + SubAssign<Rhs> {}
impl<T, Rhs> AddSubAssign<Rhs> for T where T: AddAssign<Rhs> + SubAssign<Rhs> {}

/// Trait for items that have fixed-byte length representation.
pub trait Serializable: Sized {
Expand Down
5 changes: 4 additions & 1 deletion src/protocol/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
semi_honest::AdditiveShare as Replicated,
ReplicatedSecretSharing,
},
BitDecomposed, Linear as LinearSecretSharing,
BitDecomposed, Linear as LinearSecretSharing, LinearRefOps,
},
seq_join::seq_join,
};
Expand Down Expand Up @@ -76,6 +76,7 @@ where
+ Serializable
+ DowngradeMalicious<Target = Replicated<F>>
+ 'static,
for<'r> &'r S: LinearRefOps<'r, S, F>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = SB>,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C::UpgradedContext<Gf2>, Gf2> + 'static,
F: PrimeField + ExtendableField,
Expand Down Expand Up @@ -130,6 +131,7 @@ where
I2: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks);

Expand Down Expand Up @@ -172,6 +174,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let check_times_value_ctx = ctx.narrow(&Step::CheckTimesValue);

Expand Down
7 changes: 5 additions & 2 deletions src/protocol/attribution/apply_attribution_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
context::{Context, UpgradedContext},
BasicProtocols, RecordId,
},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// This protocol applies the specified attribution window to trigger events. All trigger values of
Expand All @@ -37,6 +37,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
if let Some(attribution_window_seconds) = attribution_window_seconds {
let mut t_deltas = prefix_sum_time_deltas(&ctx, input, stop_bits).await?;
Expand Down Expand Up @@ -90,6 +91,7 @@ where
F: Field,
C: Context,
T: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a T: LinearRefOps<'a, T, F>,
{
let num_rows = input.len();

Expand All @@ -108,7 +110,7 @@ where
.map(|(i, ((prev, curr), b))| {
let c = t_delta_context.clone();
let record_id = RecordId::from(i);
let delta = curr.timestamp.clone() - &prev.timestamp;
let delta = &curr.timestamp - &prev.timestamp;
async move { delta.multiply(b, c, record_id).await }
}),
)
Expand Down Expand Up @@ -149,6 +151,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let ctx = ctx.set_total_records(input.len());
let random_bits_generator =
Expand Down
8 changes: 6 additions & 2 deletions src/protocol/attribution/credit_capping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
context::{Context, UpgradedContext},
BasicProtocols, RecordId,
},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
seq_join::seq_join,
};

Expand All @@ -35,6 +35,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
if cap == 1 {
return Ok(credit_capping_max_one(ctx, input)
Expand Down Expand Up @@ -241,6 +242,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let share_of_cap = S::share_known_value(&ctx, F::truncate_from(cap));
let cap_ref = &share_of_cap;
Expand Down Expand Up @@ -295,6 +297,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let ctx_ref = &ctx;
let ctx = ctx.set_total_records(prefix_summed_credits.len());
Expand Down Expand Up @@ -357,6 +360,7 @@ where
F: Field,
C: Context,
T: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a T: LinearRefOps<'a, T, F>,
{
let num_rows = input.len();
let cap_share = T::share_known_value(&ctx, F::try_from(cap.into()).unwrap());
Expand Down Expand Up @@ -413,7 +417,7 @@ where
record_id,
next_credit_exceeds_cap,
&T::ZERO,
&(cap.clone() - next_prefix_summed_credit),
&(cap - next_prefix_summed_credit),
)
.await?,
cap,
Expand Down
5 changes: 4 additions & 1 deletion src/protocol/attribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::{
malicious::{DowngradeMalicious, ExtendableField},
semi_honest::AdditiveShare as Replicated,
},
Linear as LinearSecretSharing,
Linear as LinearSecretSharing, LinearRefOps,
},
seq_join::assert_send,
};
Expand All @@ -62,11 +62,13 @@ where
+ Serializable
+ DowngradeMalicious<Target = Replicated<F>>
+ 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = SB> + Context,
SB: LinearSecretSharing<Gf2>
+ BasicProtocols<C::UpgradedContext<Gf2>, Gf2>
+ DowngradeMalicious<Target = Replicated<Gf2>>
+ 'static,
for<'a> &'a SB: LinearRefOps<'a, SB, Gf2>,
F: PrimeField + ExtendableField,
ShuffledPermutationWrapper<S, C::UpgradedContext<F>>: DowngradeMalicious<Target = Vec<u32>>,
{
Expand Down Expand Up @@ -403,6 +405,7 @@ async fn compute_helper_bits_gf2<C, S>(
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
for<'a> &'a S: LinearRefOps<'a, S, Gf2>,
{
let narrowed_ctx = ctx
.narrow(&Step::ComputeHelperBits)
Expand Down
7 changes: 4 additions & 3 deletions src/protocol/basics/if_else.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
error::Error,
ff::Field,
protocol::{basics::SecureMul, context::Context, RecordId},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// Returns `true_value` if `condition` is a share of 1, else `false_value`.
Expand All @@ -19,6 +19,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + SecureMul<C>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
// If `condition` is a share of 1 (true), then
// = false_value + 1 * (true_value - false_value)
Expand All @@ -28,8 +29,8 @@ where
// If `condition` is a share of 0 (false), then
// = false_value + 0 * (true_value - false_value)
// = false_value
Ok(false_value.clone()
Ok(false_value
+ &condition
.multiply(&(true_value.clone() - false_value), ctx, record_id)
.multiply(&(true_value - false_value), ctx, record_id)
.await?)
}
11 changes: 5 additions & 6 deletions src/protocol/boolean/add_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
error::Error,
ff::Field,
protocol::{context::Context, step::BitOpStep, BasicProtocols, RecordId},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// This is an implementation of a Bitwise Sum of a bitwise-shared number with a constant.
Expand Down Expand Up @@ -52,6 +52,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let mut output = Vec::with_capacity(a.len() + 1);

Expand All @@ -67,6 +68,7 @@ where
S::share_known_value(&ctx, F::ONE) - &a[0]
};
output.push(result_bit);
let two = F::truncate_from(2_u8);

for (bit_index, bit) in a.iter().enumerate().skip(1) {
let mult_result = if last_carry_known_to_be_zero {
Expand Down Expand Up @@ -94,12 +96,9 @@ where
// the current bit of `a` + the current bit of `b` + the carry from the previous bit `-2*next_carry`
// Since the current bit of `b` has a known value (either 1 or 0), we either add a `share_of_one`, or nothing.
let result_bit = if next_bit_a_one {
-next_carry.clone() * F::truncate_from(2_u128)
+ &S::share_known_value(&ctx, F::ONE)
+ bit
+ &last_carry
-(&next_carry * two) + &S::share_known_value(&ctx, F::ONE) + bit + &last_carry
} else {
-next_carry.clone() * F::truncate_from(2_u128) + bit + &last_carry
-(&next_carry * two) + bit + &last_carry
};
output.push(result_bit);

Expand Down
8 changes: 5 additions & 3 deletions src/protocol/boolean/bitwise_equal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::Error,
ff::{Field, Gf2},
protocol::{boolean::all_zeroes, context::Context, BasicProtocols, RecordId},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// Compares `[a]` and `c`, and returns 1 iff `a == c`
Expand All @@ -24,6 +24,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
assert!(a.len() <= 128);

Expand All @@ -36,7 +37,7 @@ where
if ((c >> i) & 1) == 0 {
a_bit.clone()
} else {
one.clone() - a_bit
&one - a_bit
}
})
.collect::<Vec<_>>();
Expand All @@ -56,10 +57,11 @@ pub async fn bitwise_equal_gf2<C, S>(
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
for<'a> &'a S: LinearRefOps<'a, S, Gf2>,
{
debug_assert!(a.len() == b.len());
let c = zip(a.iter(), b.iter())
.map(|(a_bit, b_bit)| a_bit.clone() - b_bit)
.map(|(a_bit, b_bit)| a_bit - b_bit)
.collect::<Vec<_>>();

all_zeroes(ctx, record_id, &c).await
Expand Down
Loading

0 comments on commit 0599ba6

Please sign in to comment.