diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index 29faff337..536cb7d14 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -10,6 +10,10 @@ use super::super::{ use crate::{ error::Error, helpers::{Direction, Message, ReceivingEnd, Role}, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + SharedValue, + }, }; #[derive(Step)] @@ -25,16 +29,11 @@ pub(crate) enum OPRFShuffleStep { /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle( - ctx: C, - batch_size: u32, - shares: (Sl, Sr), -) -> Result<(Vec, Vec), Error> +pub async fn shuffle(ctx: C, batch_size: u32, shares: I) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - S: Clone + Add + Message, + I: IntoIterator>, + S: SharedValue + Add + Message, for<'b> &'b S: Add, for<'b> &'b S: Add<&'b S, Output = S>, Standard: Distribution, @@ -49,24 +48,23 @@ where match ctx.role() { Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await, Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await, - Role::H3 => run_h3(&ctx, batch_size, shares, pis, zs).await, + Role::H3 => run_h3(&ctx, batch_size, pis, zs).await, } } -async fn run_h1( +async fn run_h1( ctx: &C, batch_size: u32, - (a, b): (Sl, Sr), + shares: I, (pi_31, pi_12): (Vec, Vec), (z_31, z_12): (Zl, Zr), ) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + I: IntoIterator>, + S: SharedValue + Add + Message, Zl: IntoIterator, Zr: IntoIterator, - S: Clone + Add + Message, for<'a> &'a S: Add, Standard: Distribution, { @@ -80,7 +78,10 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); // 2. Run computations - let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); + let a_add_b_iter = shares + .into_iter() + .map(|s: AdditiveShare| s.left().add(s.right())); + let mut x_1: Vec = add_single_shares(a_add_b_iter, z_12).collect(); apply_permutation(&pi_12, &mut x_1); let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); @@ -91,20 +92,19 @@ where Ok((a_hat, b_hat)) } -async fn run_h2( +async fn run_h2( ctx: &C, batch_size: u32, - (_b, c): (Sl, Sr), + shares: I, (pi_12, pi_23): (Vec, Vec), (z_12, z_23): (Zl, Zr), ) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + I: IntoIterator>, + S: SharedValue + Add + Message, Zl: IntoIterator, Zr: IntoIterator, - S: Clone + Add + Message, for<'a> &'a S: Add, for<'a> &'a S: Add<&'a S, Output = S>, Standard: Distribution, @@ -115,6 +115,7 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); // 2. Run computations + let c = shares.into_iter().map(|s| s.right()); let mut y_1: Vec = add_single_shares(c, z_12).collect(); apply_permutation(&pi_12, &mut y_1); @@ -153,17 +154,15 @@ where Ok((b_hat, c_hat)) } -async fn run_h3( +async fn run_h3( ctx: &C, batch_size: u32, - (_c, _a): (Sl, Sr), (pi_23, pi_31): (Vec, Vec), (z_23, z_31): (Zl, Zr), ) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + S: SharedValue + Add + Message, Zl: IntoIterator, Zr: IntoIterator, S: Clone + Add + Message, diff --git a/src/query/runner/oprf_shuffle/query.rs b/src/query/runner/oprf_shuffle/query.rs index cdcac3a31..21365fc0b 100644 --- a/src/query/runner/oprf_shuffle/query.rs +++ b/src/query/runner/oprf_shuffle/query.rs @@ -9,6 +9,7 @@ use crate::{ }, one_off_fns::assert_stream_send, protocol::{context::Context, oprf::shuffle::shuffle}, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, }; pub struct OPRFShuffleQuery { @@ -39,11 +40,7 @@ impl OPRFShuffleQuery { )) })?; - let shares = ( - split_shares(input.as_slice(), Direction::Left), - split_shares(input.as_slice(), Direction::Right), - ); - + let shares = split_shares(&input); let (res_l, res_r) = shuffle(ctx, batch_size, shares).await?; Ok(combine_shares(res_l, res_r)) } @@ -51,9 +48,13 @@ impl OPRFShuffleQuery { fn split_shares( input_rows: &[ShuffleInputRow], - direction: Direction, -) -> impl Iterator + '_ { - let f = move |input_row| ShuffleShare::from_input_row(input_row, direction); +) -> impl Iterator> + '_ { + let f = move |input_row| { + let l = ShuffleShare::from_input_row(input_row, Direction::Left); + let r = ShuffleShare::from_input_row(input_row, Direction::Right); + ReplicatedSecretSharing::new(l, r) + }; + input_rows.iter().map(f) } diff --git a/src/query/runner/oprf_shuffle/share.rs b/src/query/runner/oprf_shuffle/share.rs index 168691a40..8d1edac2b 100644 --- a/src/query/runner/oprf_shuffle/share.rs +++ b/src/query/runner/oprf_shuffle/share.rs @@ -1,4 +1,4 @@ -use std::ops::Add; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use generic_array::GenericArray; use rand::{distributions::Standard, Rng}; @@ -7,13 +7,14 @@ use typenum::Unsigned; use super::ShuffleInputRow; use crate::{ ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, - helpers::{Direction, Message}, + helpers::Direction, + secret_sharing::SharedValue, }; pub type ShuffleShareMK = Gf40Bit; pub type ShuffleShareBK = Gf8Bit; pub type ShuffleShareF = Gf32Bit; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ShuffleShare { pub timestamp: ShuffleShareF, pub mk: ShuffleShareMK, @@ -22,6 +23,91 @@ pub struct ShuffleShare { pub trigger_value: ShuffleShareF, } +impl AddAssign for ShuffleShare { + fn add_assign(&mut self, rhs: Self) { + self.timestamp += rhs.timestamp; + self.mk += rhs.mk; + self.is_trigger_bit += rhs.is_trigger_bit; + self.breakdown_key += rhs.breakdown_key; + self.trigger_value += rhs.trigger_value; + } +} +impl Sub for ShuffleShare { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self::Output { + timestamp: self.timestamp - rhs.timestamp, + mk: self.mk - rhs.mk, + is_trigger_bit: self.is_trigger_bit - rhs.is_trigger_bit, + breakdown_key: self.breakdown_key - rhs.breakdown_key, + trigger_value: self.trigger_value - rhs.trigger_value, + } + } +} +impl SubAssign for ShuffleShare { + fn sub_assign(&mut self, rhs: Self) { + self.timestamp -= rhs.timestamp; + self.mk -= rhs.mk; + self.is_trigger_bit -= rhs.is_trigger_bit; + self.breakdown_key -= rhs.breakdown_key; + self.trigger_value -= rhs.trigger_value; + } +} +impl Neg for ShuffleShare { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::Output { + timestamp: self.timestamp.neg(), + mk: self.mk.neg(), + is_trigger_bit: self.is_trigger_bit.neg(), + breakdown_key: self.breakdown_key.neg(), + trigger_value: self.trigger_value.neg(), + } + } +} +impl Mul for ShuffleShare { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self::Output { + timestamp: self.timestamp * rhs.timestamp, + mk: self.mk * rhs.mk, + is_trigger_bit: self.is_trigger_bit * rhs.is_trigger_bit, + breakdown_key: self.breakdown_key * rhs.breakdown_key, + trigger_value: self.trigger_value * rhs.trigger_value, + } + } +} +impl MulAssign for ShuffleShare { + fn mul_assign(&mut self, rhs: Self) { + self.timestamp *= rhs.timestamp; + self.mk *= rhs.mk; + self.is_trigger_bit *= rhs.is_trigger_bit; + self.breakdown_key *= rhs.breakdown_key; + self.trigger_value *= rhs.trigger_value; + } +} + +impl SharedValue for ShuffleShare { + type Storage = ::Storage; + + const BITS: u32 = ShuffleShareF::BITS + + ShuffleShareMK::BITS + + ShuffleShareF::BITS + + ShuffleShareBK::BITS + + ShuffleShareF::BITS; + + const ZERO: Self = Self { + timestamp: ShuffleShareF::ZERO, + mk: ShuffleShareMK::ZERO, + is_trigger_bit: ShuffleShareF::ZERO, + breakdown_key: ShuffleShareBK::ZERO, + trigger_value: ShuffleShareF::ZERO, + }; +} + impl ShuffleShare { #[must_use] pub fn from_input_row(input_row: &ShuffleInputRow, shared_with: Direction) -> Self { @@ -153,5 +239,3 @@ impl Serializable for ShuffleShare { } } } - -impl Message for ShuffleShare {} diff --git a/src/secret_sharing/replicated/semi_honest/additive_share.rs b/src/secret_sharing/replicated/semi_honest/additive_share.rs index 10b8b2b39..b1eb2a850 100644 --- a/src/secret_sharing/replicated/semi_honest/additive_share.rs +++ b/src/secret_sharing/replicated/semi_honest/additive_share.rs @@ -20,7 +20,6 @@ pub struct AdditiveShare(V, V); impl SecretSharing for AdditiveShare { const ZERO: Self = AdditiveShare::ZERO; } - impl LinearSecretSharing for AdditiveShare {} impl Debug for AdditiveShare {