Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow addition and multiplication for references to secret shared values #787

Merged
merged 20 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 55 additions & 36 deletions src/ff/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,57 +15,76 @@ use generic_array::{ArrayLength, GenericArray};
pub use prime_field::Fp31;
pub use prime_field::{Fp32BitPrime, PrimeField};

use crate::secret_sharing::SharedValue;

#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum Error {
#[error("unknown field type {type_str}")]
UnknownField { type_str: String },
}

pub trait ArithmeticOps:
Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Mul<Output = Self>
+ MulAssign
+ Neg<Output = Self>
+ Sized
/// Arithmetic operations that do not require communication in our MPC setting and can be performed
/// locally.
///
/// Note: Neg operation is also local, but is causing issues when added as a bound to this trait.
/// The problem is that it does not use `Rhs` generic and rustc overflows trying to compile functions
/// that use HRTB generics bounded by `LocalArithmeticOps`.
pub trait LocalArithmeticOps<Rhs = Self, Output = Self>:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addsub?

Add<Rhs, Output = Output> + Sub<Rhs, Output = Output> + Sized
{
}

impl<T> ArithmeticOps for T where
T: Add<Output = Self>
+ AddAssign
+ Sub<Output = Self>
+ SubAssign
+ Mul<Output = Self>
+ MulAssign
+ Neg<Output = Self>
+ Sized
impl<T, Rhs, Output> LocalArithmeticOps<Rhs, Output> for T where
T: Add<Rhs, Output = Output> + Sub<Rhs, Output = Output> + Sized
{
}

pub trait ArithmeticRefOps<V: SharedValue>:
for<'a> Add<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ Neg<Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ Mul<V, Output = Self>
pub trait LocalAssignOps<Rhs = Self>: AddAssign<Rhs> + SubAssign<Rhs> {}
impl<T, Rhs> LocalAssignOps<Rhs> for T where T: AddAssign<Rhs> + SubAssign<Rhs> {}

/// Arithmetic operations that may or may not require communication.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these really involve communication? None of these traits are async, so this can't do that. That suggests that these are still local operations, but only those that involve value-based l-values.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to distinguish operations that can be safely performed locally and confusing ones like Mul that someone can accidentally implement for secret shares as (a.0 * b.0, a.1 * b.1). Lets chat later today about it - I'd love to have better names for all of those traits

/// for example, multiplying field values is a local operation, while multiplying secret shares is not.
pub trait ArithmeticOps<Rhs = Self, Output = Self>:
LocalArithmeticOps<Rhs, Output>
+ LocalAssignOps<Rhs>
+ Mul<Rhs, Output = Output>
+ MulAssign<Rhs>
+ Neg<Output = Output>
{
}

impl<T, V> ArithmeticRefOps<V> for T
where
T: for<'a> Add<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ Neg<Output = Self>
+ for<'a> Sub<&'a Self, Output = Self>
+ for<'a> SubAssign<&'a Self>
+ Mul<V, Output = Self>,
V: SharedValue,
impl<T, Rhs, Output> ArithmeticOps<Rhs, Output> for T where
T: LocalArithmeticOps<Rhs, Output>
+ LocalAssignOps<Rhs>
+ Mul<Rhs, Output = Output>
+ MulAssign<Rhs>
+ Neg<Output = Output>
{
}

/// The trait for references which implement local arithmetic operations, taking the
/// second operand either by value or by reference.
///
/// The need for this trait is dictated by [`rust-issue`] that causes us not being able to constrain
/// references to `Self`. Once this issue is fixed, we can simply get rid of it.
///
/// This is automatically implemented for types which implement the operators. The idea is borrowed
/// from [`RefNum`] trait, but I couldn't really make it work with HRTB and secret shares. Primitive
/// types worked just fine though, so it is possible that it is another compiler issue.
///
/// [`RefNum`]: https://docs.rs/num/0.4.1/num/traits/trait.RefNum.html
/// [`rust-issue`]: https://github.com/rust-lang/rust/issues/20671
pub trait RefOps<'a, Base: 'a, R: 'a>:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that ArithmeticOps is the better choice of name here, with the trait above being for ArithmeticAssignOps or something like that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd love to keep Ref somewhere in the name to indicate that it only makes sense to implement/bound if you need l-value references

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArithmeticRefOps then, I guess.

LocalArithmeticOps<Base, Base>
+ LocalArithmeticOps<&'a Base, Base>
+ Mul<R, Output = Base>
+ Mul<&'a R, Output = Base>
{
}
impl<'a, T, Base: 'a, R: 'a> RefOps<'a, Base, R> for T where
T: LocalArithmeticOps<Base, Base>
+ LocalArithmeticOps<&'a Base, Base>
+ 'a
+ Mul<R, Output = Base>
+ Mul<&'a R, Output = Base>
{
}

Expand Down
5 changes: 4 additions & 1 deletion src/protocol/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub use input::SparseAggregateInputRow;
use super::{context::Context, sort::bitwise_to_onehot, step::BitOpStep, RecordId};
use crate::{
error::Error,
ff::{Field, GaloisField, Gf2, PrimeField, Serializable},
ff::{Field, GaloisField, Gf2, PrimeField, RefOps, Serializable},
protocol::{
context::{UpgradableContext, UpgradedContext, Validator},
modulus_conversion::convert_bits,
Expand Down Expand Up @@ -76,6 +76,7 @@ where
+ Serializable
+ DowngradeMalicious<Target = Replicated<F>>
+ 'static,
for<'r> &'r S: RefOps<'r, S, F>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = SB>,
SB: LinearSecretSharing<Gf2> + BasicProtocols<C::UpgradedContext<Gf2>, Gf2> + 'static,
F: PrimeField + ExtendableField,
Expand Down Expand Up @@ -130,6 +131,7 @@ where
I2: Stream<Item = Result<BitDecomposed<S>, Error>> + Send,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
for<'a> &'a S: RefOps<'a, S, F>,
{
let equality_check_ctx = ctx.narrow(&Step::ComputeEqualityChecks);

Expand Down Expand Up @@ -172,6 +174,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
for<'a> &'a S: RefOps<'a, S, F>,
{
let check_times_value_ctx = ctx.narrow(&Step::CheckTimesValue);

Expand Down
7 changes: 5 additions & 2 deletions src/protocol/attribution/apply_attribution_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::{
};
use crate::{
error::Error,
ff::{Field, PrimeField},
ff::{Field, PrimeField, RefOps},
protocol::{
boolean::{greater_than_constant, random_bits_generator::RandomBitsGenerator},
context::{Context, UpgradedContext},
Expand All @@ -37,6 +37,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F> + 'static,
for<'a> &'a S: RefOps<'a, S, F>,
{
if let Some(attribution_window_seconds) = attribution_window_seconds {
let mut t_deltas = prefix_sum_time_deltas(&ctx, input, stop_bits).await?;
Expand Down Expand Up @@ -90,6 +91,7 @@ where
F: Field,
C: Context,
T: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a T: RefOps<'a, T, F>,
{
let num_rows = input.len();

Expand All @@ -108,7 +110,7 @@ where
.map(|(i, ((prev, curr), b))| {
let c = t_delta_context.clone();
let record_id = RecordId::from(i);
let delta = curr.timestamp.clone() - &prev.timestamp;
let delta = &curr.timestamp - &prev.timestamp;
async move { delta.multiply(b, c, record_id).await }
}),
)
Expand Down Expand Up @@ -149,6 +151,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: RefOps<'a, S, F>,
{
let ctx = ctx.set_total_records(input.len());
let random_bits_generator =
Expand Down
8 changes: 6 additions & 2 deletions src/protocol/attribution/credit_capping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ipa_macros::Step;
use super::{do_the_binary_tree_thing, input::CreditCappingInputRow, prefix_or_binary_tree_style};
use crate::{
error::Error,
ff::{Field, PrimeField},
ff::{Field, PrimeField, RefOps},
protocol::{
basics::{if_else, SecureMul},
boolean::{greater_than_constant, random_bits_generator::RandomBitsGenerator},
Expand All @@ -35,6 +35,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: RefOps<'a, S, F>,
{
if cap == 1 {
return Ok(credit_capping_max_one(ctx, input)
Expand Down Expand Up @@ -241,6 +242,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: RefOps<'a, S, F>,
{
let share_of_cap = S::share_known_value(&ctx, F::truncate_from(cap));
let cap_ref = &share_of_cap;
Expand Down Expand Up @@ -295,6 +297,7 @@ where
F: PrimeField,
C: UpgradedContext<F, Share = S>,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: RefOps<'a, S, F>,
{
let ctx_ref = &ctx;
let ctx = ctx.set_total_records(prefix_summed_credits.len());
Expand Down Expand Up @@ -357,6 +360,7 @@ where
F: Field,
C: Context,
T: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a T: RefOps<'a, T, F>,
{
let num_rows = input.len();
let cap_share = T::share_known_value(&ctx, F::try_from(cap.into()).unwrap());
Expand Down Expand Up @@ -413,7 +417,7 @@ where
record_id,
next_credit_exceeds_cap,
&T::ZERO,
&(cap.clone() - next_prefix_summed_credit),
&(cap - next_prefix_summed_credit),
)
.await?,
cap,
Expand Down
5 changes: 4 additions & 1 deletion src/protocol/attribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use self::{
};
use crate::{
error::Error,
ff::{Field, Gf2, PrimeField, Serializable},
ff::{Field, Gf2, PrimeField, RefOps, Serializable},
helpers::query::IpaQueryConfig,
protocol::{
basics::SecureMul,
Expand Down Expand Up @@ -62,11 +62,13 @@ where
+ Serializable
+ DowngradeMalicious<Target = Replicated<F>>
+ 'static,
for<'a> &'a S: RefOps<'a, S, F>,
C::UpgradedContext<Gf2>: UpgradedContext<Gf2, Share = SB> + Context,
SB: LinearSecretSharing<Gf2>
+ BasicProtocols<C::UpgradedContext<Gf2>, Gf2>
+ DowngradeMalicious<Target = Replicated<Gf2>>
+ 'static,
for<'a> &'a SB: RefOps<'a, SB, Gf2>,
F: PrimeField + ExtendableField,
ShuffledPermutationWrapper<S, C::UpgradedContext<F>>: DowngradeMalicious<Target = Vec<u32>>,
{
Expand Down Expand Up @@ -403,6 +405,7 @@ async fn compute_helper_bits_gf2<C, S>(
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
for<'a> &'a S: RefOps<'a, S, Gf2>,
{
let narrowed_ctx = ctx
.narrow(&Step::ComputeHelperBits)
Expand Down
5 changes: 3 additions & 2 deletions src/protocol/basics/if_else.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
error::Error,
ff::Field,
ff::{Field, RefOps},
protocol::{basics::SecureMul, context::Context, RecordId},
secret_sharing::Linear as LinearSecretSharing,
};
Expand All @@ -19,6 +19,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + SecureMul<C>,
for<'a> &'a S: RefOps<'a, S, F>,
{
// If `condition` is a share of 1 (true), then
// = false_value + 1 * (true_value - false_value)
Expand All @@ -30,6 +31,6 @@ where
// = false_value
Ok(false_value.clone()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed a clone opportunity?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh my...

+ &condition
.multiply(&(true_value.clone() - false_value), ctx, record_id)
.multiply(&(true_value - false_value), ctx, record_id)
.await?)
}
11 changes: 5 additions & 6 deletions src/protocol/boolean/add_constant.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
error::Error,
ff::Field,
ff::{Field, RefOps},
protocol::{context::Context, step::BitOpStep, BasicProtocols, RecordId},
secret_sharing::Linear as LinearSecretSharing,
};
Expand Down Expand Up @@ -52,6 +52,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: RefOps<'a, S, F>,
{
let mut output = Vec::with_capacity(a.len() + 1);

Expand All @@ -67,6 +68,7 @@ where
S::share_known_value(&ctx, F::ONE) - &a[0]
};
output.push(result_bit);
let two = F::truncate_from(2_u8);

for (bit_index, bit) in a.iter().enumerate().skip(1) {
let mult_result = if last_carry_known_to_be_zero {
Expand Down Expand Up @@ -94,12 +96,9 @@ where
// the current bit of `a` + the current bit of `b` + the carry from the previous bit `-2*next_carry`
// Since the current bit of `b` has a known value (either 1 or 0), we either add a `share_of_one`, or nothing.
let result_bit = if next_bit_a_one {
-next_carry.clone() * F::truncate_from(2_u128)
+ &S::share_known_value(&ctx, F::ONE)
+ bit
+ &last_carry
-(&next_carry * two) + &S::share_known_value(&ctx, F::ONE) + bit + &last_carry
} else {
-next_carry.clone() * F::truncate_from(2_u128) + bit + &last_carry
-(&next_carry * two) + bit + &last_carry
};
output.push(result_bit);

Expand Down
8 changes: 5 additions & 3 deletions src/protocol/boolean/bitwise_equal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::iter::zip;

use crate::{
error::Error,
ff::{Field, Gf2},
ff::{Field, Gf2, RefOps},
protocol::{boolean::all_zeroes, context::Context, BasicProtocols, RecordId},
secret_sharing::Linear as LinearSecretSharing,
};
Expand All @@ -24,6 +24,7 @@ where
F: Field,
C: Context,
S: LinearSecretSharing<F> + BasicProtocols<C, F>,
for<'a> &'a S: RefOps<'a, S, F>,
{
assert!(a.len() <= 128);

Expand All @@ -36,7 +37,7 @@ where
if ((c >> i) & 1) == 0 {
a_bit.clone()
} else {
one.clone() - a_bit
&one - a_bit
}
})
.collect::<Vec<_>>();
Expand All @@ -56,10 +57,11 @@ pub async fn bitwise_equal_gf2<C, S>(
where
C: Context,
S: LinearSecretSharing<Gf2> + BasicProtocols<C, Gf2>,
for<'a> &'a S: RefOps<'a, S, Gf2>,
{
debug_assert!(a.len() == b.len());
let c = zip(a.iter(), b.iter())
.map(|(a_bit, b_bit)| a_bit.clone() - b_bit)
.map(|(a_bit, b_bit)| a_bit - b_bit)
.collect::<Vec<_>>();

all_zeroes(ctx, record_id, &c).await
Expand Down
Loading