Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow addition and multiplication for references to secret shared values #787

Merged
merged 20 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 8 additions & 42 deletions src/ff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod field;
mod galois_field;
mod prime_field;

use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::ops::{Add, AddAssign, Sub, SubAssign};

pub use field::{Field, FieldType};
pub use galois_field::{GaloisField, Gf2, Gf32Bit, Gf3Bit, Gf40Bit, Gf5Bit, Gf8Bit};
Expand All @@ -15,59 +15,25 @@ use generic_array::{ArrayLength, GenericArray};
pub use prime_field::Fp31;
pub use prime_field::{Fp32BitPrime, PrimeField};

use crate::secret_sharing::SharedValue;

#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum Error {
#[error("unknown field type {type_str}")]
UnknownField { type_str: String },
}

pub trait ArithmeticOps:
Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Mul<Output = Self>
+ MulAssign
+ Neg<Output = Self>
+ Sized
{
}

impl<T> ArithmeticOps for T where
T: Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Mul<Output = Self>
+ MulAssign
+ Neg<Output = Self>
+ Sized
/// Addition and subtraction operations that are supported by secret sharings and shared values.
pub trait AddSub<Rhs = Self, Output = Self>:
Add<Rhs, Output = Output> + Sub<Rhs, Output = Output> + Sized
{
}

pub trait ArithmeticRefOps<V: SharedValue>:
for<'a> Add<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ Neg<Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ Mul<V, Output = Self>
impl<T, Rhs, Output> AddSub<Rhs, Output> for T where
T: Add<Rhs, Output = Output> + Sub<Rhs, Output = Output> + Sized
{
}

impl<T, V> ArithmeticRefOps<V> for T
where
T: for<'a> Add<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ Neg<Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ Mul<V, Output = Self>,
V: SharedValue,
{
}
pub trait AddSubAssign<Rhs = Self>: AddAssign<Rhs> + SubAssign<Rhs> {}
impl<T, Rhs> AddSubAssign<Rhs> for T where T: AddAssign<Rhs> + SubAssign<Rhs> {}

/// Trait for items that have fixed-byte length representation.
pub trait Serializable: Sized {
Expand Down
5 changes: 4 additions & 1 deletion src/protocol/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
semi_honest::AdditiveShare as Replicated,
ReplicatedSecretSharing,
},
BitDecomposed, Linear as LinearSecretSharing,
BitDecomposed, Linear as LinearSecretSharing, LinearRefOps,
},
seq_join::seq_join,
};
Expand Down Expand Up @@ -76,6 +76,7 @@ where
+ Serializable
+ DowngradeMalicious<Target = Replicated<F>>
+ 'static,
for<'r> &'r S: LinearRefOps<'r, S, F>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = SB>,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C::UpgradedContext<Gf2>, Gf2> + 'static,
F: PrimeField + ExtendableField,
Expand Down Expand Up @@ -130,6 +131,7 @@ where
I2: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks);

Expand Down Expand Up @@ -172,6 +174,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let check_times_value_ctx = ctx.narrow(&Step::CheckTimesValue);

Expand Down
7 changes: 5 additions & 2 deletions src/protocol/attribution/apply_attribution_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
context::{Context, UpgradedContext},
BasicProtocols, RecordId,
},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// This protocol applies the specified attribution window to trigger events. All trigger values of
Expand All @@ -37,6 +37,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
if let Some(attribution_window_seconds) = attribution_window_seconds {
let mut t_deltas = prefix_sum_time_deltas(&ctx, input, stop_bits).await?;
Expand Down Expand Up @@ -90,6 +91,7 @@ where
F: Field,
C: Context,
T: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a T: LinearRefOps<'a, T, F>,
{
let num_rows = input.len();

Expand All @@ -108,7 +110,7 @@ where
.map(|(i, ((prev, curr), b))| {
let c = t_delta_context.clone();
let record_id = RecordId::from(i);
let delta = curr.timestamp.clone() - &prev.timestamp;
let delta = &curr.timestamp - &prev.timestamp;
async move { delta.multiply(b, c, record_id).await }
}),
)
Expand Down Expand Up @@ -149,6 +151,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let ctx = ctx.set_total_records(input.len());
let random_bits_generator =
Expand Down
8 changes: 6 additions & 2 deletions src/protocol/attribution/credit_capping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
context::{Context, UpgradedContext},
BasicProtocols, RecordId,
},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
seq_join::seq_join,
};

Expand All @@ -35,6 +35,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
if cap == 1 {
return Ok(credit_capping_max_one(ctx, input)
Expand Down Expand Up @@ -241,6 +242,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let share_of_cap = S::share_known_value(&ctx, F::truncate_from(cap));
let cap_ref = &share_of_cap;
Expand Down Expand Up @@ -295,6 +297,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let ctx_ref = &ctx;
let ctx = ctx.set_total_records(prefix_summed_credits.len());
Expand Down Expand Up @@ -357,6 +360,7 @@ where
F: Field,
C: Context,
T: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a T: LinearRefOps<'a, T, F>,
{
let num_rows = input.len();
let cap_share = T::share_known_value(&ctx, F::try_from(cap.into()).unwrap());
Expand Down Expand Up @@ -413,7 +417,7 @@ where
record_id,
next_credit_exceeds_cap,
&T::ZERO,
&(cap.clone() - next_prefix_summed_credit),
&(cap - next_prefix_summed_credit),
)
.await?,
cap,
Expand Down
5 changes: 4 additions & 1 deletion src/protocol/attribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::{
malicious::{DowngradeMalicious, ExtendableField},
semi_honest::AdditiveShare as Replicated,
},
Linear as LinearSecretSharing,
Linear as LinearSecretSharing, LinearRefOps,
},
seq_join::assert_send,
};
Expand All @@ -62,11 +62,13 @@ where
+ Serializable
+ DowngradeMalicious<Target = Replicated<F>>
+ 'static,
for<'a> &'a S: LinearRefOps<'a, S, F>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = SB> + Context,
SB: LinearSecretSharing<Gf2>
+ BasicProtocols<C::UpgradedContext<Gf2>, Gf2>
+ DowngradeMalicious<Target = Replicated<Gf2>>
+ 'static,
for<'a> &'a SB: LinearRefOps<'a, SB, Gf2>,
F: PrimeField + ExtendableField,
ShuffledPermutationWrapper<S, C::UpgradedContext<F>>: DowngradeMalicious<Target = Vec<u32>>,
{
Expand Down Expand Up @@ -403,6 +405,7 @@ async fn compute_helper_bits_gf2<C, S>(
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
for<'a> &'a S: LinearRefOps<'a, S, Gf2>,
{
let narrowed_ctx = ctx
.narrow(&Step::ComputeHelperBits)
Expand Down
7 changes: 4 additions & 3 deletions src/protocol/basics/if_else.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
error::Error,
ff::Field,
protocol::{basics::SecureMul, context::Context, RecordId},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// Returns `true_value` if `condition` is a share of 1, else `false_value`.
Expand All @@ -19,6 +19,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + SecureMul<C>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
// If `condition` is a share of 1 (true), then
// = false_value + 1 * (true_value - false_value)
Expand All @@ -28,8 +29,8 @@ where
// If `condition` is a share of 0 (false), then
// = false_value + 0 * (true_value - false_value)
// = false_value
Ok(false_value.clone()
Ok(false_value
+ &condition
.multiply(&(true_value.clone() - false_value), ctx, record_id)
.multiply(&(true_value - false_value), ctx, record_id)
.await?)
}
11 changes: 5 additions & 6 deletions src/protocol/boolean/add_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
error::Error,
ff::Field,
protocol::{context::Context, step::BitOpStep, BasicProtocols, RecordId},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// This is an implementation of a Bitwise Sum of a bitwise-shared number with a constant.
Expand Down Expand Up @@ -52,6 +52,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
let mut output = Vec::with_capacity(a.len() + 1);

Expand All @@ -67,6 +68,7 @@ where
S::share_known_value(&ctx, F::ONE) - &a[0]
};
output.push(result_bit);
let two = F::truncate_from(2_u8);

for (bit_index, bit) in a.iter().enumerate().skip(1) {
let mult_result = if last_carry_known_to_be_zero {
Expand Down Expand Up @@ -94,12 +96,9 @@ where
// the current bit of `a` + the current bit of `b` + the carry from the previous bit `-2*next_carry`
// Since the current bit of `b` has a known value (either 1 or 0), we either add a `share_of_one`, or nothing.
let result_bit = if next_bit_a_one {
-next_carry.clone() * F::truncate_from(2_u128)
+ &S::share_known_value(&ctx, F::ONE)
+ bit
+ &last_carry
-(&next_carry * two) + &S::share_known_value(&ctx, F::ONE) + bit + &last_carry
} else {
-next_carry.clone() * F::truncate_from(2_u128) + bit + &last_carry
-(&next_carry * two) + bit + &last_carry
};
output.push(result_bit);

Expand Down
8 changes: 5 additions & 3 deletions src/protocol/boolean/bitwise_equal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
error::Error,
ff::{Field, Gf2},
protocol::{boolean::all_zeroes, context::Context, BasicProtocols, RecordId},
secret_sharing::Linear as LinearSecretSharing,
secret_sharing::{Linear as LinearSecretSharing, LinearRefOps},
};

/// Compares `[a]` and `c`, and returns 1 iff `a == c`
Expand All @@ -24,6 +24,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: LinearRefOps<'a, S, F>,
{
assert!(a.len() <= 128);

Expand All @@ -36,7 +37,7 @@ where
if ((c >> i) & 1) == 0 {
a_bit.clone()
} else {
one.clone() - a_bit
&one - a_bit
}
})
.collect::<Vec<_>>();
Expand All @@ -56,10 +57,11 @@ pub async fn bitwise_equal_gf2<C, S>(
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
for<'a> &'a S: LinearRefOps<'a, S, Gf2>,
{
debug_assert!(a.len() == b.len());
let c = zip(a.iter(), b.iter())
.map(|(a_bit, b_bit)| a_bit.clone() - b_bit)
.map(|(a_bit, b_bit)| a_bit - b_bit)
.collect::<Vec<_>>();

all_zeroes(ctx, record_id, &c).await
Expand Down
Loading