Skip to content

Commit

Permalink
[oprf][shuffle] Reimplemented the input to shuffle function as an int…
Browse files Browse the repository at this point in the history
…erator of ReplicatedSecretShares

Instead of accepting a tuple of 2 iterators returning SharedValues
accept a single iterator of AdditiveShares: ReplicatedSecretShares

This should make interaction with TestWorld easier
  • Loading branch information
Artem Ignatyev committed Oct 23, 2023
1 parent adc46da commit 38ca599
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 38 deletions.
47 changes: 23 additions & 24 deletions src/protocol/oprf/shuffle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use super::super::{
use crate::{
error::Error,
helpers::{Direction, Message, ReceivingEnd, Role},
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue,
},
};

#[derive(Step)]
Expand All @@ -25,16 +29,11 @@ pub(crate) enum OPRFShuffleStep {

/// # Errors
/// Will propagate errors from transport and a few typecasts
pub async fn shuffle<C, S, Sl, Sr>(
ctx: C,
batch_size: u32,
shares: (Sl, Sr),
) -> Result<(Vec<S>, Vec<S>), Error>
pub async fn shuffle<C, I, S>(ctx: C, batch_size: u32, shares: I) -> Result<(Vec<S>, Vec<S>), Error>
where
C: Context,
Sl: IntoIterator<Item = S>,
Sr: IntoIterator<Item = S>,
S: Clone + Add<Output = S> + Message,
I: IntoIterator<Item = AdditiveShare<S>>,
S: SharedValue + Add<Output = S> + Message,
for<'b> &'b S: Add<S, Output = S>,
for<'b> &'b S: Add<&'b S, Output = S>,
Standard: Distribution<S>,
Expand All @@ -49,24 +48,23 @@ where
match ctx.role() {
Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await,
Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await,
Role::H3 => run_h3(&ctx, batch_size, shares, pis, zs).await,
Role::H3 => run_h3(&ctx, batch_size, pis, zs).await,
}
}

async fn run_h1<C, Sl, S, Sr, Zl, Zr>(
async fn run_h1<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: u32,
(a, b): (Sl, Sr),
shares: I,
(pi_31, pi_12): (Vec<u32>, Vec<u32>),
(z_31, z_12): (Zl, Zr),
) -> Result<(Vec<S>, Vec<S>), Error>
where
C: Context,
Sl: IntoIterator<Item = S>,
Sr: IntoIterator<Item = S>,
I: IntoIterator<Item = AdditiveShare<S>>,
S: SharedValue + Add<Output = S> + Message,
Zl: IntoIterator<Item = S>,
Zr: IntoIterator<Item = S>,
S: Clone + Add<Output = S> + Message,
for<'a> &'a S: Add<Output = S>,
Standard: Distribution<S>,
{
Expand All @@ -80,7 +78,10 @@ where
generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect();

// 2. Run computations
let mut x_1: Vec<S> = add_single_shares(add_single_shares(a, b), z_12).collect();
let a_add_b_iter = shares
.into_iter()
.map(|s: AdditiveShare<S>| s.left().add(s.right()));
let mut x_1: Vec<S> = add_single_shares(a_add_b_iter, z_12).collect();
apply_permutation(&pi_12, &mut x_1);

let mut x_2: Vec<S> = add_single_shares(x_1, z_31).collect();
Expand All @@ -91,20 +92,19 @@ where
Ok((a_hat, b_hat))
}

async fn run_h2<C, S, Sl, Sr, Zl, Zr>(
async fn run_h2<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: u32,
(_b, c): (Sl, Sr),
shares: I,
(pi_12, pi_23): (Vec<u32>, Vec<u32>),
(z_12, z_23): (Zl, Zr),
) -> Result<(Vec<S>, Vec<S>), Error>
where
C: Context,
Sl: IntoIterator<Item = S>,
Sr: IntoIterator<Item = S>,
I: IntoIterator<Item = AdditiveShare<S>>,
S: SharedValue + Add<Output = S> + Message,
Zl: IntoIterator<Item = S>,
Zr: IntoIterator<Item = S>,
S: Clone + Add<Output = S> + Message,
for<'a> &'a S: Add<S, Output = S>,
for<'a> &'a S: Add<&'a S, Output = S>,
Standard: Distribution<S>,
Expand All @@ -115,6 +115,7 @@ where
generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect();

// 2. Run computations
let c = shares.into_iter().map(|s| s.right());
let mut y_1: Vec<S> = add_single_shares(c, z_12).collect();
apply_permutation(&pi_12, &mut y_1);

Expand Down Expand Up @@ -153,17 +154,15 @@ where
Ok((b_hat, c_hat))
}

async fn run_h3<C, S, Sl, Sr, Zl, Zr>(
async fn run_h3<C, S, Zl, Zr>(
ctx: &C,
batch_size: u32,
(_c, _a): (Sl, Sr),
(pi_23, pi_31): (Vec<u32>, Vec<u32>),
(z_23, z_31): (Zl, Zr),
) -> Result<(Vec<S>, Vec<S>), Error>
where
C: Context,
Sl: IntoIterator<Item = S>,
Sr: IntoIterator<Item = S>,
S: SharedValue + Add<Output = S> + Message,
Zl: IntoIterator<Item = S>,
Zr: IntoIterator<Item = S>,
S: Clone + Add<Output = S> + Message,
Expand Down
17 changes: 9 additions & 8 deletions src/query/runner/oprf_shuffle/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
},
one_off_fns::assert_stream_send,
protocol::{context::Context, oprf::shuffle::shuffle},
secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
};

pub struct OPRFShuffleQuery {
Expand Down Expand Up @@ -39,21 +40,21 @@ impl OPRFShuffleQuery {
))
})?;

let shares = (
split_shares(input.as_slice(), Direction::Left),
split_shares(input.as_slice(), Direction::Right),
);

let shares = split_shares(&input);
let (res_l, res_r) = shuffle(ctx, batch_size, shares).await?;
Ok(combine_shares(res_l, res_r))
}
}

fn split_shares(
input_rows: &[ShuffleInputRow],
direction: Direction,
) -> impl Iterator<Item = ShuffleShare> + '_ {
let f = move |input_row| ShuffleShare::from_input_row(input_row, direction);
) -> impl Iterator<Item = AdditiveShare<ShuffleShare>> + '_ {
let f = move |input_row| {
let l = ShuffleShare::from_input_row(input_row, Direction::Left);
let r = ShuffleShare::from_input_row(input_row, Direction::Right);
ReplicatedSecretSharing::new(l, r)
};

input_rows.iter().map(f)
}

Expand Down
94 changes: 89 additions & 5 deletions src/query/runner/oprf_shuffle/share.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::Add;
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use generic_array::GenericArray;
use rand::{distributions::Standard, Rng};
Expand All @@ -7,13 +7,14 @@ use typenum::Unsigned;
use super::ShuffleInputRow;
use crate::{
ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable},
helpers::{Direction, Message},
helpers::Direction,
secret_sharing::SharedValue,
};
pub type ShuffleShareMK = Gf40Bit;
pub type ShuffleShareBK = Gf8Bit;
pub type ShuffleShareF = Gf32Bit;

#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShuffleShare {
pub timestamp: ShuffleShareF,
pub mk: ShuffleShareMK,
Expand All @@ -22,6 +23,91 @@ pub struct ShuffleShare {
pub trigger_value: ShuffleShareF,
}

impl AddAssign for ShuffleShare {
fn add_assign(&mut self, rhs: Self) {
self.timestamp += rhs.timestamp;
self.mk += rhs.mk;
self.is_trigger_bit += rhs.is_trigger_bit;
self.breakdown_key += rhs.breakdown_key;
self.trigger_value += rhs.trigger_value;
}
}
impl Sub for ShuffleShare {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self::Output {
timestamp: self.timestamp - rhs.timestamp,
mk: self.mk - rhs.mk,
is_trigger_bit: self.is_trigger_bit - rhs.is_trigger_bit,
breakdown_key: self.breakdown_key - rhs.breakdown_key,
trigger_value: self.trigger_value - rhs.trigger_value,
}
}
}
impl SubAssign for ShuffleShare {
fn sub_assign(&mut self, rhs: Self) {
self.timestamp -= rhs.timestamp;
self.mk -= rhs.mk;
self.is_trigger_bit -= rhs.is_trigger_bit;
self.breakdown_key -= rhs.breakdown_key;
self.trigger_value -= rhs.trigger_value;
}
}
impl Neg for ShuffleShare {
type Output = Self;

fn neg(self) -> Self::Output {
Self::Output {
timestamp: self.timestamp.neg(),
mk: self.mk.neg(),
is_trigger_bit: self.is_trigger_bit.neg(),
breakdown_key: self.breakdown_key.neg(),
trigger_value: self.trigger_value.neg(),
}
}
}
impl Mul for ShuffleShare {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
Self::Output {
timestamp: self.timestamp * rhs.timestamp,
mk: self.mk * rhs.mk,
is_trigger_bit: self.is_trigger_bit * rhs.is_trigger_bit,
breakdown_key: self.breakdown_key * rhs.breakdown_key,
trigger_value: self.trigger_value * rhs.trigger_value,
}
}
}
impl MulAssign for ShuffleShare {
fn mul_assign(&mut self, rhs: Self) {
self.timestamp *= rhs.timestamp;
self.mk *= rhs.mk;
self.is_trigger_bit *= rhs.is_trigger_bit;
self.breakdown_key *= rhs.breakdown_key;
self.trigger_value *= rhs.trigger_value;
}
}

impl SharedValue for ShuffleShare {
type Storage = <ShuffleShareF as SharedValue>::Storage;

const BITS: u32 = ShuffleShareF::BITS
+ ShuffleShareMK::BITS
+ ShuffleShareF::BITS
+ ShuffleShareBK::BITS
+ ShuffleShareF::BITS;

const ZERO: Self = Self {
timestamp: ShuffleShareF::ZERO,
mk: ShuffleShareMK::ZERO,
is_trigger_bit: ShuffleShareF::ZERO,
breakdown_key: ShuffleShareBK::ZERO,
trigger_value: ShuffleShareF::ZERO,
};
}

impl ShuffleShare {
#[must_use]
pub fn from_input_row(input_row: &ShuffleInputRow, shared_with: Direction) -> Self {
Expand Down Expand Up @@ -153,5 +239,3 @@ impl Serializable for ShuffleShare {
}
}
}

impl Message for ShuffleShare {}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub struct AdditiveShare<V: SharedValue>(V, V);
impl<V: SharedValue> SecretSharing<V> for AdditiveShare<V> {
const ZERO: Self = AdditiveShare::ZERO;
}

impl<V: SharedValue> LinearSecretSharing<V> for AdditiveShare<V> {}

impl<V: SharedValue + Debug> Debug for AdditiveShare<V> {
Expand Down

0 comments on commit 38ca599

Please sign in to comment.