From dae74499631fa3974541154e1a8ffa461526efac Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Wed, 11 Oct 2023 14:28:22 -0700 Subject: [PATCH 01/15] [oprf][shuffle] OPRF Shuffle using a 2-round 4-message shuffle protocol 1. New query type 2. Implementation of the protocol (not sharded) --- .../transport/{query.rs => query/mod.rs} | 5 + src/helpers/transport/query/oprf_shuffle.rs | 17 + src/net/http_serde.rs | 8 + src/protocol/context/mod.rs | 2 + src/protocol/context/oprf.rs | 88 +++ src/protocol/mod.rs | 1 + src/protocol/oprf/mod.rs | 610 ++++++++++++++++++ src/query/executor.rs | 16 +- src/query/runner/mod.rs | 6 +- src/query/runner/oprf_shuffle.rs | 48 ++ 10 files changed, 798 insertions(+), 3 deletions(-) rename src/helpers/transport/{query.rs => query/mod.rs} (98%) create mode 100644 src/helpers/transport/query/oprf_shuffle.rs create mode 100644 src/protocol/context/oprf.rs create mode 100644 src/protocol/oprf/mod.rs create mode 100644 src/query/runner/oprf_shuffle.rs diff --git a/src/helpers/transport/query.rs b/src/helpers/transport/query/mod.rs similarity index 98% rename from src/helpers/transport/query.rs rename to src/helpers/transport/query/mod.rs index ab4e0761b..7a5f5c3bb 100644 --- a/src/helpers/transport/query.rs +++ b/src/helpers/transport/query/mod.rs @@ -1,3 +1,5 @@ +pub mod oprf_shuffle; + use std::{ fmt::{Debug, Display, Formatter}, num::NonZeroU32, @@ -206,6 +208,7 @@ pub enum QueryType { MaliciousIpa(IpaQueryConfig), SemiHonestSparseAggregate(SparseAggregateQueryConfig), MaliciousSparseAggregate(SparseAggregateQueryConfig), + OPRFShuffle(oprf_shuffle::QueryConfig), } impl QueryType { @@ -214,6 +217,7 @@ impl QueryType { pub const MALICIOUS_IPA_STR: &'static str = "malicious-ipa"; pub const SEMIHONEST_AGGREGATE_STR: &'static str = "semihonest-sparse-aggregate"; pub const MALICIOUS_AGGREGATE_STR: &'static str = "malicious-sparse-aggregate"; + pub const OPRF_SHUFFLE_STR: &'static str = "oprf-shuffle"; } /// TODO: should this `AsRef` impl (used for `Substep`) take into account config of IPA? @@ -226,6 +230,7 @@ impl AsRef for QueryType { QueryType::MaliciousIpa(_) => Self::MALICIOUS_IPA_STR, QueryType::SemiHonestSparseAggregate(_) => Self::SEMIHONEST_AGGREGATE_STR, QueryType::MaliciousSparseAggregate(_) => Self::MALICIOUS_AGGREGATE_STR, + QueryType::OPRFShuffle(_) => Self::OPRF_SHUFFLE_STR, } } } diff --git a/src/helpers/transport/query/oprf_shuffle.rs b/src/helpers/transport/query/oprf_shuffle.rs new file mode 100644 index 000000000..86a23884f --- /dev/null +++ b/src/helpers/transport/query/oprf_shuffle.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))] +pub struct QueryConfig { + pub bk_size: u8, // breakdown key size bits + pub tv_size: u8, // trigger value size bits +} + +impl Default for QueryConfig { + fn default() -> Self { + Self { + bk_size: 40, + tv_size: 40, + } + } +} diff --git a/src/net/http_serde.rs b/src/net/http_serde.rs index faa2ec13d..eaa86a65e 100644 --- a/src/net/http_serde.rs +++ b/src/net/http_serde.rs @@ -139,6 +139,10 @@ pub mod query { let Query(q) = req.extract().await?; Ok(QueryType::MaliciousSparseAggregate(q)) } + QueryType::OPRF_SHUFFLE_STR => { + let Query(q) = req.extract().await?; + Ok(QueryType::OPRFShuffle(q)) + } other => Err(Error::bad_query_value("query_type", other)), }?; Ok(QueryConfigQueryParams(QueryConfig { @@ -188,6 +192,10 @@ pub mod query { Ok(()) } + QueryType::OPRFShuffle(config) => { + write!(f, "&bk_size={}&tv_size={}", config.bk_size, config.tv_size)?; + Ok(()) + } } } } diff --git a/src/protocol/context/mod.rs b/src/protocol/context/mod.rs index 48b1efec9..c2d4cc6aa 100644 --- a/src/protocol/context/mod.rs +++ b/src/protocol/context/mod.rs @@ -1,4 +1,5 @@ pub mod malicious; +pub mod oprf; pub mod prss; pub mod semi_honest; pub mod upgrade; @@ -8,6 +9,7 @@ use std::{num::NonZeroUsize, sync::Arc}; use async_trait::async_trait; pub use malicious::{Context as MaliciousContext, Upgraded as UpgradedMaliciousContext}; +pub use oprf::Context as OPRFContext; use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness}; pub use semi_honest::{Context as SemiHonestContext, Upgraded as UpgradedSemiHonestContext}; pub use upgrade::{UpgradeContext, UpgradeToMalicious}; diff --git a/src/protocol/context/oprf.rs b/src/protocol/context/oprf.rs new file mode 100644 index 000000000..12eb29439 --- /dev/null +++ b/src/protocol/context/oprf.rs @@ -0,0 +1,88 @@ +use std::num::NonZeroUsize; + +use crate::{ + helpers::{Gateway, Message, ReceivingEnd, Role, SendingEnd, TotalRecords}, + protocol::{ + context::{ + Base, InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness, + }, + prss::Endpoint as PrssEndpoint, + step::{Gate, Step, StepNarrow}, + }, + seq_join::SeqJoin, +}; + +#[derive(Clone)] +pub struct Context<'a> { + inner: Base<'a>, +} + +impl<'a> Context<'a> { + pub fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { + Self { + inner: Base::new(participant, gateway), + } + } + + #[cfg(test)] + #[must_use] + pub fn from_base(base: Base<'a>) -> Self { + Self { inner: base } + } +} + +impl<'a> super::Context for Context<'a> { + fn role(&self) -> Role { + self.inner.role() + } + + fn gate(&self) -> &Gate { + self.inner.gate() + } + + fn narrow(&self, step: &S) -> Self + where + Gate: StepNarrow, + { + Self { + inner: self.inner.narrow(step), + } + } + + fn set_total_records>(&self, total_records: T) -> Self { + Self { + inner: self.inner.set_total_records(total_records), + } + } + + fn total_records(&self) -> TotalRecords { + self.inner.total_records() + } + + fn prss(&self) -> InstrumentedIndexedSharedRandomness<'_> { + self.inner.prss() + } + + fn prss_rng( + &self, + ) -> ( + InstrumentedSequentialSharedRandomness, + InstrumentedSequentialSharedRandomness, + ) { + self.inner.prss_rng() + } + + fn send_channel(&self, role: Role) -> SendingEnd { + self.inner.send_channel(role) + } + + fn recv_channel(&self, role: Role) -> ReceivingEnd { + self.inner.recv_channel(role) + } +} + +impl<'a> SeqJoin for Context<'a> { + fn active_work(&self) -> NonZeroUsize { + self.inner.active_work() + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 7931a1e37..6bb6ae1fc 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -6,6 +6,7 @@ pub mod context; pub mod dp; pub mod ipa; pub mod modulus_conversion; +pub mod oprf; #[cfg(feature = "descriptive-gate")] pub mod prf_sharding; pub mod prss; diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs new file mode 100644 index 000000000..4d1f0f039 --- /dev/null +++ b/src/protocol/oprf/mod.rs @@ -0,0 +1,610 @@ +use std::ops::Add; + +use futures_util::try_join; +use generic_array::GenericArray; +use ipa_macros::Step; +use rand::{seq::SliceRandom, Rng}; +use typenum::Unsigned; + +use super::{context::Context, ipa::IPAInputRow, RecordId}; +use crate::{ + error::Error, + ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, + helpers::{query::oprf_shuffle::QueryConfig, Direction, Message, ReceivingEnd, Role}, +}; + +type OprfMK = Gf40Bit; +type OprfBK = Gf8Bit; +type OprfF = Gf32Bit; + +pub type OPRFInputRow = IPAInputRow; + +#[derive(Debug, Clone, Copy)] +pub struct OPRFShuffleSingleShare { + pub timestamp: OprfF, + pub mk: OprfMK, + pub is_trigger_bit: OprfF, + pub breakdown_key: OprfBK, + pub trigger_value: OprfF, +} + +impl OPRFShuffleSingleShare { + #[must_use] + pub fn from_input_row(input_row: &OPRFInputRow, shared_with: Direction) -> Self { + // Relying on the fact that all SharedValue(s) are Copy + match shared_with { + Direction::Left => Self { + timestamp: input_row.timestamp.as_tuple().1, + mk: input_row.mk_shares.as_tuple().1, + is_trigger_bit: input_row.is_trigger_bit.as_tuple().1, + breakdown_key: input_row.breakdown_key.as_tuple().1, + trigger_value: input_row.trigger_value.as_tuple().1, + }, + + Direction::Right => Self { + timestamp: input_row.timestamp.as_tuple().0, + mk: input_row.mk_shares.as_tuple().0, + is_trigger_bit: input_row.is_trigger_bit.as_tuple().0, + breakdown_key: input_row.breakdown_key.as_tuple().0, + trigger_value: input_row.trigger_value.as_tuple().0, + }, + } + } + + #[must_use] + pub fn to_input_row(self, rhs: Self) -> OPRFInputRow { + OPRFInputRow { + timestamp: (self.timestamp, rhs.timestamp).into(), + mk_shares: (self.mk, rhs.mk).into(), + is_trigger_bit: (self.is_trigger_bit, rhs.is_trigger_bit).into(), + breakdown_key: (self.breakdown_key, rhs.breakdown_key).into(), + trigger_value: (self.trigger_value, rhs.trigger_value).into(), + } + } + + pub fn sample(rng: &mut R) -> Self { + Self { + timestamp: OprfF::truncate_from(rng.gen::()), + mk: OprfMK::truncate_from(rng.gen::()), + is_trigger_bit: OprfF::truncate_from(rng.gen::()), + breakdown_key: OprfBK::truncate_from(rng.gen::()), + trigger_value: OprfF::truncate_from(rng.gen::()), + } + } +} + +impl Add for OPRFShuffleSingleShare { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + timestamp: self.timestamp + rhs.timestamp, + mk: self.mk + rhs.mk, + is_trigger_bit: self.is_trigger_bit + rhs.is_trigger_bit, + breakdown_key: self.breakdown_key + rhs.breakdown_key, + trigger_value: self.trigger_value + rhs.trigger_value, + } + } +} + +impl Add for &OPRFShuffleSingleShare { + type Output = OPRFShuffleSingleShare; + + fn add(self, rhs: Self) -> Self::Output { + *self + *rhs // Relies on Copy + } +} + +impl Serializable for OPRFShuffleSingleShare { + type Size = <::Size as Add< + <::Size as Add< + <::Size as Add< + <::Size as Add<::Size>>::Output, + >>::Output, + >>::Output, + >>::Output; + + fn serialize(&self, buf: &mut GenericArray) { + let mk_sz = ::Size::USIZE; + let bk_sz = ::Size::USIZE; + let f_sz = ::Size::USIZE; + + self.timestamp + .serialize(GenericArray::from_mut_slice(&mut buf[..f_sz])); + self.mk + .serialize(GenericArray::from_mut_slice(&mut buf[f_sz..f_sz + mk_sz])); + self.is_trigger_bit.serialize(GenericArray::from_mut_slice( + &mut buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], + )); + self.breakdown_key.serialize(GenericArray::from_mut_slice( + &mut buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], + )); + self.trigger_value.serialize(GenericArray::from_mut_slice( + &mut buf[f_sz + mk_sz + f_sz + bk_sz..], + )); + } + + fn deserialize(buf: &GenericArray) -> Self { + let mk_sz = ::Size::USIZE; + let bk_sz = ::Size::USIZE; + let f_sz = ::Size::USIZE; + + let timestamp = OprfF::deserialize(GenericArray::from_slice(&buf[..f_sz])); + let mk = OprfMK::deserialize(GenericArray::from_slice(&buf[f_sz..f_sz + mk_sz])); + let is_trigger_bit = OprfF::deserialize(GenericArray::from_slice( + &buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], + )); + let breakdown_key = OprfBK::deserialize(GenericArray::from_slice( + &buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], + )); + let trigger_value = OprfF::deserialize(GenericArray::from_slice( + &buf[f_sz + mk_sz + f_sz + bk_sz..], + )); + Self { + timestamp, + mk, + is_trigger_bit, + breakdown_key, + trigger_value, + } + } +} + +impl Message for OPRFShuffleSingleShare {} + +#[derive(Step)] +pub(crate) enum OPRFShuffleStep { + GenerateAHat, + GenerateBHat, + GeneratePi12, + GeneratePi23, + GeneratePi31, + GenerateZ12, + GenerateZ23, + GenerateZ31, + TransferCHat1, + TransferCHat2, + TransferX2, + TransferY1, +} + +/// # Errors +/// Will propagate errors from transport and a few typecasts +pub async fn oprf_shuffle( + ctx: C, + input_rows: &[OPRFInputRow], + _config: QueryConfig, +) -> Result, Error> { + let role = ctx.role(); + let batch_size = u32::try_from(input_rows.len()).map_err(|_e| { + Error::FieldValueTruncation(format!( + "Cannot truncate the number of input rows {} to u32", + input_rows.len(), + )) + })?; + + let my_shares = split_shares_and_get_left(input_rows); + let shared_with_rhs = split_shares_and_get_right(input_rows); + + match role { + Role::H1 => run_h1(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, + Role::H2 => run_h2(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, + Role::H3 => run_h3(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, + } +} + +async fn run_h1( + ctx: &C, + role: &Role, + batch_size: u32, + my_shares: L, + rhs_shared: R, +) -> Result, Error> +where + C: Context, + L: IntoIterator, + R: IntoIterator, +{ + let a = my_shares; + let b = rhs_shared; + // + // 1. Generate permutations + let pi_12 = generate_pseudorandom_permutation( + batch_size, + ctx, + &OPRFShuffleStep::GeneratePi12, + Direction::Right, + ); + let pi_31 = generate_pseudorandom_permutation( + batch_size, + ctx, + &OPRFShuffleStep::GeneratePi31, + Direction::Left, + ); + // + // 2. Generate random tables + let z_12 = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateZ12, + Direction::Right, + ); + let z_31 = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateZ31, + Direction::Left, + ); + + let a_hat = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateAHat, + Direction::Left, + ); + + let b_hat = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateBHat, + Direction::Right, + ); + + // 3. Run computations + let x_1_arg = add_single_shares(add_single_shares(a, b), z_12); + let x_1 = permute(&pi_12, x_1_arg); + + let x_2_arg = add_single_shares(x_1.iter(), z_31.iter()); + let x_2 = permute(&pi_31, x_2_arg); + + send_to_peer( + ctx, + &OPRFShuffleStep::TransferX2, + role.peer(Direction::Right), + x_2.clone(), + ) + .await?; + + let res = combine_shares(a_hat, b_hat); + Ok(res) +} + +async fn run_h2( + ctx: &C, + role: &Role, + batch_size: u32, + _my_shares: L, + shared_with_rhs: R, +) -> Result, Error> +where + C: Context, + L: IntoIterator, + R: IntoIterator, +{ + let c = shared_with_rhs; + + // 1. Generate permutations + let pi_12 = generate_pseudorandom_permutation( + batch_size, + ctx, + &OPRFShuffleStep::GeneratePi12, + Direction::Left, + ); + + let pi_23 = generate_pseudorandom_permutation( + batch_size, + ctx, + &OPRFShuffleStep::GeneratePi23, + Direction::Right, + ); + + // 2. Generate random tables + let z_12 = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateZ12, + Direction::Left, + ); + + let z_23 = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateZ23, + Direction::Right, + ); + + let b_hat = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateBHat, + Direction::Left, + ); + // + // 3. Run computations + let y_1_arg = add_single_shares(c, z_12.into_iter()); + let y_1 = permute(&pi_12, y_1_arg); + + let ((), x_2) = try_join!( + send_to_peer( + ctx, + &OPRFShuffleStep::TransferY1, + role.peer(Direction::Right), + y_1, + ), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferX2, + role.peer(Direction::Left), + batch_size, + ), + )?; + + let x_3_arg = add_single_shares(x_2.into_iter(), z_23.into_iter()); + let x_3 = permute(&pi_23, x_3_arg); + let c_hat_1 = add_single_shares(x_3.iter(), b_hat.iter()).collect::>(); + + let ((), c_hat_2) = try_join!( + send_to_peer( + ctx, + &OPRFShuffleStep::TransferCHat1, + role.peer(Direction::Right), + c_hat_1.clone(), + ), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferCHat2, + role.peer(Direction::Right), + batch_size, + ) + )?; + + let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); + let res = combine_shares(b_hat, c_hat); + Ok(res) +} + +async fn run_h3( + ctx: &C, + role: &Role, + batch_size: u32, + _my_shares: L, + _shared_with_rhs: R, +) -> Result, Error> +where + C: Context, + L: IntoIterator, + R: IntoIterator, +{ + // H3 does not need any secret shares. + // Its "C" shares are processed by helper2, Its "A" shares are processed by helper 1 + /* + let c = my_shares; + let a = rhs_shared; + */ + + // 1. Generate permutations + let pi_23 = generate_pseudorandom_permutation( + batch_size, + ctx, + &OPRFShuffleStep::GeneratePi23, + Direction::Left, + ); + let pi_31 = generate_pseudorandom_permutation( + batch_size, + ctx, + &OPRFShuffleStep::GeneratePi31, + Direction::Right, + ); + + // 2. Generate random tables + let z_23 = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateZ23, + Direction::Left, + ); + + let z_31 = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateZ31, + Direction::Right, + ); + + let a_hat = generate_random_table_with_peer( + batch_size, + ctx, + &OPRFShuffleStep::GenerateAHat, + Direction::Right, + ); + + // 3. Run computations + let y_1 = receive_from_peer( + ctx, + &OPRFShuffleStep::TransferY1, + role.peer(Direction::Left), + batch_size, + ) + .await?; + + let y_2_arg = add_single_shares(y_1, z_31); + let y_2 = permute(&pi_31, y_2_arg); + let y_3_arg = add_single_shares(y_2, z_23); + let y_3 = permute(&pi_23, y_3_arg); + let c_hat_2 = add_single_shares(y_3, a_hat.clone()).collect::>(); + + let (c_hat_1, ()) = try_join!( + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferCHat1, + role.peer(Direction::Left), + batch_size, + ), + send_to_peer( + ctx, + &OPRFShuffleStep::TransferCHat2, + role.peer(Direction::Left), + c_hat_2.clone(), + ) + )?; + + let c_hat = add_single_shares(c_hat_1, c_hat_2); + let res = combine_shares(c_hat, a_hat); + Ok(res) +} + +// ------------------------------------------------------------------------------------------------------------- // + +fn split_shares_and_get_left( + input_rows: &[OPRFInputRow], +) -> impl Iterator + '_ { + let lhs = input_rows + .iter() + .map(|input_row| OPRFShuffleSingleShare::from_input_row(input_row, Direction::Left)); + lhs +} + +fn split_shares_and_get_right( + input_rows: &[OPRFInputRow], +) -> impl Iterator + '_ { + let rhs = input_rows + .iter() + .map(|input_row| OPRFShuffleSingleShare::from_input_row(input_row, Direction::Right)); + rhs +} + +fn combine_shares(l: L, r: R) -> Vec +where + L: IntoIterator, + R: IntoIterator, +{ + l.into_iter() + .zip(r) + .map(|(l, r)| l.to_input_row(r)) + .collect::>() +} + +fn add_single_shares<'a, T, L, R>(l: L, r: R) -> impl Iterator +where + T: Add + 'a, + L: IntoIterator, + R: IntoIterator, +{ + l.into_iter().zip(r).map(|(a, b)| a + b) +} + +fn generate_random_table_with_peer( + batch_size: u32, + ctx: &C, + step: &OPRFShuffleStep, + peer: Direction, +) -> Vec +where + C: Context, +{ + let narrow_step = ctx.narrow(step); + let rngs = narrow_step.prss_rng(); + let mut rng = match peer { + Direction::Left => rngs.0, + Direction::Right => rngs.1, + }; + + let iter = std::iter::from_fn(move || Some(OPRFShuffleSingleShare::sample(&mut rng))) + .take(batch_size as usize); + + // NOTE: I'd like to return an Iterator from here as there is really no need to allocate batch_size of items. + // It'd be better to just pass the iterator to add_single_shares function. + // But I was unable to figure the return type. The type checker was saying something + // about Box> and that rng is not Send, + // but it is currently beyond by level of knowledge. + // So, any advice is appreciated + iter.collect::>() +} +// +// ---------------------------- helper communication ------------------------------------ // + +async fn send_to_peer>( + ctx: &C, + step: &OPRFShuffleStep, + role: Role, + items: I, +) -> Result<(), Error> { + let send_channel = ctx.narrow(step).send_channel(role); + for (record_id, row) in items.into_iter().enumerate() { + send_channel.send(RecordId::from(record_id), row).await?; + } + Ok(()) +} + +async fn receive_from_peer( + ctx: &C, + step: &OPRFShuffleStep, + role: Role, + batch_size: u32, +) -> Result, Error> { + let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); + + let mut output: Vec = Vec::with_capacity(batch_size as usize); + for record_id in 0..batch_size { + let msg = receive_channel.receive(RecordId::from(record_id)).await?; + output.push(msg); + } + + Ok(output) +} + +// --------------------------- permutation-related function --------------------------------------------- // + +fn generate_pseudorandom_permutation( + batch_size: u32, + ctx: &C, + step: &OPRFShuffleStep, + with_peer_on_the: Direction, +) -> Vec { + let narrow_context = ctx.narrow(step); + let rng = narrow_context.prss_rng(); + let mut rng = match with_peer_on_the { + Direction::Left => rng.0, + Direction::Right => rng.1, + }; + + let mut permutation = (0..batch_size).collect::>(); + permutation.shuffle(&mut rng); + permutation +} + +fn permute( + permutation: &[u32], + input: impl Iterator, +) -> Vec { + let mut rows = input.collect::>(); + apply(permutation, &mut rows); + rows +} + +use bitvec::bitvec; +use embed_doc_image::embed_doc_image; + +#[embed_doc_image("apply", "images/sort/apply.png")] +/// Permutation reorders (1, 2, . . . , m) into (σ(1), σ(2), . . . , σ(m)). +/// For example, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is reordered into (C, D, B, A) by σ. +/// +/// ![Apply steps][apply] +fn apply(permutation: &[u32], values: &mut [T]) { + // NOTE: This is copypasta from crate::protocol::sort + debug_assert!(permutation.len() == values.len()); + let mut permuted = bitvec![0; permutation.len()]; + + for i in 0..permutation.len() { + if !permuted[i] { + let mut pos_i = i; + let mut pos_j = permutation[pos_i] as usize; + while pos_j != i { + values.swap(pos_i, pos_j); + permuted.set(pos_j, true); + pos_i = pos_j; + pos_j = permutation[pos_i] as usize; + } + } + } +} diff --git a/src/query/executor.rs b/src/query/executor.rs index d99e6df4f..7f14a8498 100644 --- a/src/query/executor.rs +++ b/src/query/executor.rs @@ -25,12 +25,12 @@ use crate::{ }, hpke::{KeyPair, KeyRegistry}, protocol::{ - context::{MaliciousContext, SemiHonestContext}, + context::{MaliciousContext, OPRFContext, SemiHonestContext}, prss::Endpoint as PrssEndpoint, step::{Gate, StepNarrow}, }, query::{ - runner::{IpaQuery, QueryResult, SparseAggregateQuery}, + runner::{IpaQuery, OPRFShuffleQuery, QueryResult, SparseAggregateQuery}, state::RunningQuery, }, }; @@ -202,6 +202,18 @@ pub fn execute( }, ) } + (QueryType::OPRFShuffle(oprf_shuffle_config), _) => do_query( + config, + gateway, + input, + move |prss, gateway, config, input| { + let ctx = OPRFContext::new(prss, gateway); + let query = OPRFShuffleQuery::new(oprf_shuffle_config) + .execute(ctx, config.size, input) + .then(|res| ready(res.map(|out| Box::new(out) as Box))); + Box::pin(query) + }, + ), } } diff --git a/src/query/runner/mod.rs b/src/query/runner/mod.rs index bcba34275..3ec724fa5 100644 --- a/src/query/runner/mod.rs +++ b/src/query/runner/mod.rs @@ -1,12 +1,16 @@ mod aggregate; mod ipa; +pub mod oprf_shuffle; + #[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod test_multiply; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use test_multiply::execute_test_multiply; -pub(super) use self::{aggregate::SparseAggregateQuery, ipa::IpaQuery}; +pub(super) use self::{ + aggregate::SparseAggregateQuery, ipa::IpaQuery, oprf_shuffle::OPRFShuffleQuery, +}; use crate::{error::Error, query::ProtocolResult}; pub(super) type QueryResult = Result, Error>; diff --git a/src/query/runner/oprf_shuffle.rs b/src/query/runner/oprf_shuffle.rs new file mode 100644 index 000000000..38c2565f1 --- /dev/null +++ b/src/query/runner/oprf_shuffle.rs @@ -0,0 +1,48 @@ +use futures::{Stream, TryStreamExt}; + +use crate::{ + error::Error, + helpers::{ + query::{oprf_shuffle, QuerySize}, + BodyStream, RecordsStream, + }, + protocol::{ + context::Context, + oprf::{oprf_shuffle, OPRFInputRow}, + }, +}; + +pub struct OPRFShuffleQuery { + config: oprf_shuffle::QueryConfig, +} + +impl OPRFShuffleQuery { + pub fn new(config: oprf_shuffle::QueryConfig) -> Self { + Self { config } + } + + #[tracing::instrument("ipa_query", skip_all, fields(sz=%query_size))] + pub async fn execute<'a, C: Context + Send>( + self, + ctx: C, + query_size: QuerySize, + input_stream: BodyStream, + ) -> Result, Error> { + let input: Vec = + assert_stream_send(RecordsStream::::new(input_stream)) + .try_concat() + .await?; + + oprf_shuffle(ctx, input.as_slice(), self.config).await + } +} + +/// Helps to convince the compiler that things are `Send`. Like `seq_join::assert_send`, but for +/// streams. +/// +/// +pub fn assert_stream_send<'a, T>( + st: impl Stream + Send + 'a, +) -> impl Stream + Send + 'a { + st +} From cd6b5d248e463853e6c4ab0a3d758510500a896d Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Wed, 18 Oct 2023 17:55:25 -0500 Subject: [PATCH 02/15] Removed duplication of generating permutations and "Z" random tables with peers in run_h{1,2,3} functions --- src/protocol/oprf/mod.rs | 259 ++++++++++++++------------------------- 1 file changed, 89 insertions(+), 170 deletions(-) diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index 4d1f0f039..d67e89435 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -175,7 +175,6 @@ pub async fn oprf_shuffle( input_rows: &[OPRFInputRow], _config: QueryConfig, ) -> Result, Error> { - let role = ctx.role(); let batch_size = u32::try_from(input_rows.len()).map_err(|_e| { Error::FieldValueTruncation(format!( "Cannot truncate the number of input rows {} to u32", @@ -183,67 +182,34 @@ pub async fn oprf_shuffle( )) })?; - let my_shares = split_shares_and_get_left(input_rows); - let shared_with_rhs = split_shares_and_get_right(input_rows); + let share_l = split_shares_and_get_left(input_rows); + let share_r = split_shares_and_get_right(input_rows); - match role { - Role::H1 => run_h1(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, - Role::H2 => run_h2(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, - Role::H3 => run_h3(&ctx, &role, batch_size, my_shares, shared_with_rhs).await, + match ctx.role() { + Role::H1 => run_h1(&ctx, batch_size, share_l, share_r).await, + Role::H2 => run_h2(&ctx, batch_size, share_l, share_r).await, + Role::H3 => run_h3(&ctx, batch_size, share_l, share_r).await, } } -async fn run_h1( - ctx: &C, - role: &Role, - batch_size: u32, - my_shares: L, - rhs_shared: R, -) -> Result, Error> +async fn run_h1(ctx: &C, batch_size: u32, a: L, b: R) -> Result, Error> where C: Context, L: IntoIterator, R: IntoIterator, { - let a = my_shares; - let b = rhs_shared; - // // 1. Generate permutations - let pi_12 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi12, - Direction::Right, - ); - let pi_31 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi31, - Direction::Left, - ); - // - // 2. Generate random tables - let z_12 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ12, - Direction::Right, - ); - let z_31 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ31, - Direction::Left, - ); + let (pi_31, pi_12) = generate_permutations_with_peers(batch_size, ctx); - let a_hat = generate_random_table_with_peer( + // 2. Generate random tables + let (z_31, z_12) = generate_random_tables_with_peers(batch_size, ctx); + let a_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateAHat, Direction::Left, ); - - let b_hat = generate_random_table_with_peer( + let b_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateBHat, @@ -253,14 +219,13 @@ where // 3. Run computations let x_1_arg = add_single_shares(add_single_shares(a, b), z_12); let x_1 = permute(&pi_12, x_1_arg); - let x_2_arg = add_single_shares(x_1.iter(), z_31.iter()); let x_2 = permute(&pi_31, x_2_arg); send_to_peer( ctx, &OPRFShuffleStep::TransferX2, - role.peer(Direction::Right), + Direction::Right, x_2.clone(), ) .await?; @@ -269,72 +234,34 @@ where Ok(res) } -async fn run_h2( - ctx: &C, - role: &Role, - batch_size: u32, - _my_shares: L, - shared_with_rhs: R, -) -> Result, Error> +async fn run_h2(ctx: &C, batch_size: u32, _b: L, c: R) -> Result, Error> where C: Context, L: IntoIterator, R: IntoIterator, { - let c = shared_with_rhs; - // 1. Generate permutations - let pi_12 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi12, - Direction::Left, - ); - - let pi_23 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi23, - Direction::Right, - ); + let (pi_12, pi_23) = generate_permutations_with_peers(batch_size, ctx); // 2. Generate random tables - let z_12 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ12, - Direction::Left, - ); - - let z_23 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ23, - Direction::Right, - ); - - let b_hat = generate_random_table_with_peer( + let (z_12, z_23) = generate_random_tables_with_peers(batch_size, ctx); + let b_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateBHat, Direction::Left, ); - // + // 3. Run computations let y_1_arg = add_single_shares(c, z_12.into_iter()); let y_1 = permute(&pi_12, y_1_arg); let ((), x_2) = try_join!( - send_to_peer( - ctx, - &OPRFShuffleStep::TransferY1, - role.peer(Direction::Right), - y_1, - ), + send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), receive_from_peer( ctx, &OPRFShuffleStep::TransferX2, - role.peer(Direction::Left), + Direction::Left, batch_size, ), )?; @@ -342,76 +269,24 @@ where let x_3_arg = add_single_shares(x_2.into_iter(), z_23.into_iter()); let x_3 = permute(&pi_23, x_3_arg); let c_hat_1 = add_single_shares(x_3.iter(), b_hat.iter()).collect::>(); - - let ((), c_hat_2) = try_join!( - send_to_peer( - ctx, - &OPRFShuffleStep::TransferCHat1, - role.peer(Direction::Right), - c_hat_1.clone(), - ), - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferCHat2, - role.peer(Direction::Right), - batch_size, - ) - )?; - + let c_hat_2 = exchange_c_hat(ctx, batch_size, c_hat_1.clone()).await?; let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); let res = combine_shares(b_hat, c_hat); Ok(res) } -async fn run_h3( - ctx: &C, - role: &Role, - batch_size: u32, - _my_shares: L, - _shared_with_rhs: R, -) -> Result, Error> +async fn run_h3(ctx: &C, batch_size: u32, _c: L, _a: R) -> Result, Error> where C: Context, L: IntoIterator, R: IntoIterator, { - // H3 does not need any secret shares. - // Its "C" shares are processed by helper2, Its "A" shares are processed by helper 1 - /* - let c = my_shares; - let a = rhs_shared; - */ - // 1. Generate permutations - let pi_23 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi23, - Direction::Left, - ); - let pi_31 = generate_pseudorandom_permutation( - batch_size, - ctx, - &OPRFShuffleStep::GeneratePi31, - Direction::Right, - ); + let (pi_23, pi_31) = generate_permutations_with_peers(batch_size, ctx); // 2. Generate random tables - let z_23 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ23, - Direction::Left, - ); - - let z_31 = generate_random_table_with_peer( - batch_size, - ctx, - &OPRFShuffleStep::GenerateZ31, - Direction::Right, - ); - - let a_hat = generate_random_table_with_peer( + let (z_23, z_31) = generate_random_tables_with_peers(batch_size, ctx); + let a_hat = generate_random_table( batch_size, ctx, &OPRFShuffleStep::GenerateAHat, @@ -422,7 +297,7 @@ where let y_1 = receive_from_peer( ctx, &OPRFShuffleStep::TransferY1, - role.peer(Direction::Left), + Direction::Left, batch_size, ) .await?; @@ -432,22 +307,7 @@ where let y_3_arg = add_single_shares(y_2, z_23); let y_3 = permute(&pi_23, y_3_arg); let c_hat_2 = add_single_shares(y_3, a_hat.clone()).collect::>(); - - let (c_hat_1, ()) = try_join!( - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferCHat1, - role.peer(Direction::Left), - batch_size, - ), - send_to_peer( - ctx, - &OPRFShuffleStep::TransferCHat2, - role.peer(Direction::Left), - c_hat_2.clone(), - ) - )?; - + let c_hat_1 = exchange_c_hat(ctx, batch_size, c_hat_2.clone()).await?; let c_hat = add_single_shares(c_hat_1, c_hat_2); let res = combine_shares(c_hat, a_hat); Ok(res) @@ -493,7 +353,22 @@ where l.into_iter().zip(r).map(|(a, b)| a + b) } -fn generate_random_table_with_peer( +fn generate_random_tables_with_peers( + batch_size: u32, + ctx: &C, +) -> (Vec, Vec) { + let (step_left, step_right) = match ctx.role() { + Role::H1 => (OPRFShuffleStep::GenerateZ31, OPRFShuffleStep::GenerateZ12), + Role::H2 => (OPRFShuffleStep::GenerateZ12, OPRFShuffleStep::GenerateZ23), + Role::H3 => (OPRFShuffleStep::GenerateZ23, OPRFShuffleStep::GenerateZ12), + }; + + let with_left = generate_random_table(batch_size, ctx, &step_left, Direction::Left); + let with_right = generate_random_table(batch_size, ctx, &step_right, Direction::Right); + (with_left, with_right) +} + +fn generate_random_table( batch_size: u32, ctx: &C, step: &OPRFShuffleStep, @@ -526,9 +401,10 @@ where async fn send_to_peer>( ctx: &C, step: &OPRFShuffleStep, - role: Role, + direction: Direction, items: I, ) -> Result<(), Error> { + let role = ctx.role().peer(direction); let send_channel = ctx.narrow(step).send_channel(role); for (record_id, row) in items.into_iter().enumerate() { send_channel.send(RecordId::from(record_id), row).await?; @@ -539,9 +415,10 @@ async fn send_to_peer async fn receive_from_peer( ctx: &C, step: &OPRFShuffleStep, - role: Role, + direction: Direction, batch_size: u32, ) -> Result, Error> { + let role = ctx.role().peer(direction); let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); let mut output: Vec = Vec::with_capacity(batch_size as usize); @@ -553,8 +430,50 @@ async fn receive_from_peer( Ok(output) } +async fn exchange_c_hat>( + ctx: &C, + batch_size: u32, + part_to_send: I, +) -> Result, Error> { + let (step_send, step_recv, dir) = match ctx.role() { + Role::H2 => ( + OPRFShuffleStep::TransferCHat1, + OPRFShuffleStep::TransferCHat2, + Direction::Right, + ), + Role::H3 => ( + OPRFShuffleStep::TransferCHat2, + OPRFShuffleStep::TransferCHat1, + Direction::Left, + ), + role @ Role::H1 => { + unreachable!("Role {:?} does not participate in C_hat computation", role) + } + }; + + let ((), received_part) = try_join!( + send_to_peer(ctx, &step_send, dir, part_to_send), + receive_from_peer(ctx, &step_recv, dir, batch_size), + )?; + + Ok(received_part) +} + // --------------------------- permutation-related function --------------------------------------------- // +fn generate_permutations_with_peers(batch_size: u32, ctx: &C) -> (Vec, Vec) { + let (step_left, step_right) = match &ctx.role() { + Role::H1 => (OPRFShuffleStep::GeneratePi31, OPRFShuffleStep::GeneratePi12), + Role::H2 => (OPRFShuffleStep::GeneratePi12, OPRFShuffleStep::GeneratePi23), + Role::H3 => (OPRFShuffleStep::GeneratePi23, OPRFShuffleStep::GeneratePi12), + }; + + let with_left = generate_pseudorandom_permutation(batch_size, ctx, &step_left, Direction::Left); + let with_right = + generate_pseudorandom_permutation(batch_size, ctx, &step_right, Direction::Right); + (with_left, with_right) +} + fn generate_pseudorandom_permutation( batch_size: u32, ctx: &C, From f9119bf4003bbcc3db98a1bc5499698a7efdbd67 Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Thu, 19 Oct 2023 12:14:06 -0500 Subject: [PATCH 03/15] swapped bodies for impl Add for &OPRFShuffleSingleShare and impl Add for OPRFShuffleSingleShare --- src/protocol/oprf/mod.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index d67e89435..b7a149ef0 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -76,22 +76,23 @@ impl OPRFShuffleSingleShare { impl Add for OPRFShuffleSingleShare { type Output = Self; + #[allow(clippy::op_ref)] fn add(self, rhs: Self) -> Self::Output { - Self { - timestamp: self.timestamp + rhs.timestamp, - mk: self.mk + rhs.mk, - is_trigger_bit: self.is_trigger_bit + rhs.is_trigger_bit, - breakdown_key: self.breakdown_key + rhs.breakdown_key, - trigger_value: self.trigger_value + rhs.trigger_value, - } + &self + &rhs } } impl Add for &OPRFShuffleSingleShare { type Output = OPRFShuffleSingleShare; - fn add(self, rhs: Self) -> Self::Output { - *self + *rhs // Relies on Copy + fn add(self, &rhs: Self) -> Self::Output { + Self::Output { + timestamp: self.timestamp + rhs.timestamp, + mk: self.mk + rhs.mk, + is_trigger_bit: self.is_trigger_bit + rhs.is_trigger_bit, + breakdown_key: self.breakdown_key + rhs.breakdown_key, + trigger_value: self.trigger_value + rhs.trigger_value, + } } } From 110a7e60a422ef5c85b4a0ed7670b6ec3171efd9 Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Thu, 19 Oct 2023 10:53:23 -0500 Subject: [PATCH 04/15] [oprf][shuffle] Addressed review comments 1. Unified GenerateZ_ij and GeneratePi_ij steps 2. Moved assert_stream_send into a new module create:one_off_fns (and updated a few callsites to use it) 3. Replaced try_join! macro with future:try_join fn calls 4. Unified ExchangeCHat1 and ExchangeCHat2 steps into ExchangeCHat 5. Replaced calls to exchange_c_hat function with inline send/receive 6. Dropped OPRFContext. And used Base context instead. I had to make Base::new public 7. Removed unnecessary x_2.clone() un run_h1 8. impl Distrubution for Standard and usage of rng.sample_iter(Standard) instead of manual OPRFShuffle::sample() 9. Got rid of permute fn. Call apply (shuffle) inline --- src/lib.rs | 2 + src/one_off_fns.rs | 12 ++ src/protocol/context/mod.rs | 4 +- src/protocol/context/oprf.rs | 88 ------------ src/protocol/oprf/mod.rs | 232 +++++++++++++------------------ src/query/executor.rs | 4 +- src/query/runner/aggregate.rs | 2 +- src/query/runner/ipa.rs | 13 +- src/query/runner/oprf_shuffle.rs | 13 +- 9 files changed, 117 insertions(+), 253 deletions(-) create mode 100644 src/one_off_fns.rs delete mode 100644 src/protocol/context/oprf.rs diff --git a/src/lib.rs b/src/lib.rs index 340601adc..b61144136 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,8 @@ pub mod error; pub mod ff; pub mod helpers; pub mod hpke; +pub mod one_off_fns; + #[cfg(feature = "web-app")] pub mod net; pub mod protocol; diff --git a/src/one_off_fns.rs b/src/one_off_fns.rs new file mode 100644 index 000000000..44bf7177e --- /dev/null +++ b/src/one_off_fns.rs @@ -0,0 +1,12 @@ +use futures::Stream; + +/// +/// Helps to convince the compiler that things are `Send`. Like `seq_join::assert_send`, but for +/// streams. +/// +/// +pub fn assert_stream_send<'a, T>( + st: impl Stream + Send + 'a, +) -> impl Stream + Send + 'a { + st +} diff --git a/src/protocol/context/mod.rs b/src/protocol/context/mod.rs index c2d4cc6aa..1a591a8a9 100644 --- a/src/protocol/context/mod.rs +++ b/src/protocol/context/mod.rs @@ -1,5 +1,4 @@ pub mod malicious; -pub mod oprf; pub mod prss; pub mod semi_honest; pub mod upgrade; @@ -9,7 +8,6 @@ use std::{num::NonZeroUsize, sync::Arc}; use async_trait::async_trait; pub use malicious::{Context as MaliciousContext, Upgraded as UpgradedMaliciousContext}; -pub use oprf::Context as OPRFContext; use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness}; pub use semi_honest::{Context as SemiHonestContext, Upgraded as UpgradedSemiHonestContext}; pub use upgrade::{UpgradeContext, UpgradeToMalicious}; @@ -162,7 +160,7 @@ pub struct Base<'a> { } impl<'a> Base<'a> { - fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { + pub fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { Self::new_complete( participant, gateway, diff --git a/src/protocol/context/oprf.rs b/src/protocol/context/oprf.rs deleted file mode 100644 index 12eb29439..000000000 --- a/src/protocol/context/oprf.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::num::NonZeroUsize; - -use crate::{ - helpers::{Gateway, Message, ReceivingEnd, Role, SendingEnd, TotalRecords}, - protocol::{ - context::{ - Base, InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness, - }, - prss::Endpoint as PrssEndpoint, - step::{Gate, Step, StepNarrow}, - }, - seq_join::SeqJoin, -}; - -#[derive(Clone)] -pub struct Context<'a> { - inner: Base<'a>, -} - -impl<'a> Context<'a> { - pub fn new(participant: &'a PrssEndpoint, gateway: &'a Gateway) -> Self { - Self { - inner: Base::new(participant, gateway), - } - } - - #[cfg(test)] - #[must_use] - pub fn from_base(base: Base<'a>) -> Self { - Self { inner: base } - } -} - -impl<'a> super::Context for Context<'a> { - fn role(&self) -> Role { - self.inner.role() - } - - fn gate(&self) -> &Gate { - self.inner.gate() - } - - fn narrow(&self, step: &S) -> Self - where - Gate: StepNarrow, - { - Self { - inner: self.inner.narrow(step), - } - } - - fn set_total_records>(&self, total_records: T) -> Self { - Self { - inner: self.inner.set_total_records(total_records), - } - } - - fn total_records(&self) -> TotalRecords { - self.inner.total_records() - } - - fn prss(&self) -> InstrumentedIndexedSharedRandomness<'_> { - self.inner.prss() - } - - fn prss_rng( - &self, - ) -> ( - InstrumentedSequentialSharedRandomness, - InstrumentedSequentialSharedRandomness, - ) { - self.inner.prss_rng() - } - - fn send_channel(&self, role: Role) -> SendingEnd { - self.inner.send_channel(role) - } - - fn recv_channel(&self, role: Role) -> ReceivingEnd { - self.inner.recv_channel(role) - } -} - -impl<'a> SeqJoin for Context<'a> { - fn active_work(&self) -> NonZeroUsize { - self.inner.active_work() - } -} diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index b7a149ef0..6c3265877 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -1,9 +1,9 @@ use std::ops::Add; -use futures_util::try_join; +use futures::future; use generic_array::GenericArray; use ipa_macros::Step; -use rand::{seq::SliceRandom, Rng}; +use rand::{distributions::Standard, seq::SliceRandom, Rng}; use typenum::Unsigned; use super::{context::Context, ipa::IPAInputRow, RecordId}; @@ -31,7 +31,6 @@ pub struct OPRFShuffleSingleShare { impl OPRFShuffleSingleShare { #[must_use] pub fn from_input_row(input_row: &OPRFInputRow, shared_with: Direction) -> Self { - // Relying on the fact that all SharedValue(s) are Copy match shared_with { Direction::Left => Self { timestamp: input_row.timestamp.as_tuple().1, @@ -61,9 +60,11 @@ impl OPRFShuffleSingleShare { trigger_value: (self.trigger_value, rhs.trigger_value).into(), } } +} - pub fn sample(rng: &mut R) -> Self { - Self { +impl rand::prelude::Distribution for Standard { + fn sample(&self, rng: &mut R) -> OPRFShuffleSingleShare { + OPRFShuffleSingleShare { timestamp: OprfF::truncate_from(rng.gen::()), mk: OprfMK::truncate_from(rng.gen::()), is_trigger_bit: OprfF::truncate_from(rng.gen::()), @@ -157,14 +158,9 @@ impl Message for OPRFShuffleSingleShare {} pub(crate) enum OPRFShuffleStep { GenerateAHat, GenerateBHat, - GeneratePi12, - GeneratePi23, - GeneratePi31, - GenerateZ12, - GenerateZ23, - GenerateZ31, - TransferCHat1, - TransferCHat2, + GeneratePi, + GenerateZ, + TransferCHat, TransferX2, TransferY1, } @@ -183,8 +179,8 @@ pub async fn oprf_shuffle( )) })?; - let share_l = split_shares_and_get_left(input_rows); - let share_r = split_shares_and_get_right(input_rows); + let share_l = split_shares(input_rows, Direction::Left); + let share_r = split_shares(input_rows, Direction::Right); match ctx.role() { Role::H1 => run_h1(&ctx, batch_size, share_l, share_r).await, @@ -204,13 +200,13 @@ where // 2. Generate random tables let (z_31, z_12) = generate_random_tables_with_peers(batch_size, ctx); - let a_hat = generate_random_table( + let a_hat = generate_random_table_solo( batch_size, ctx, &OPRFShuffleStep::GenerateAHat, Direction::Left, ); - let b_hat = generate_random_table( + let b_hat = generate_random_table_solo( batch_size, ctx, &OPRFShuffleStep::GenerateBHat, @@ -218,18 +214,14 @@ where ); // 3. Run computations - let x_1_arg = add_single_shares(add_single_shares(a, b), z_12); - let x_1 = permute(&pi_12, x_1_arg); - let x_2_arg = add_single_shares(x_1.iter(), z_31.iter()); - let x_2 = permute(&pi_31, x_2_arg); + let mut x_1: Vec = + add_single_shares(add_single_shares(a, b), z_12).collect(); + apply(&pi_12, &mut x_1); - send_to_peer( - ctx, - &OPRFShuffleStep::TransferX2, - Direction::Right, - x_2.clone(), - ) - .await?; + let mut x_2: Vec = add_single_shares(x_1.iter(), z_31.iter()).collect(); + apply(&pi_31, &mut x_2); + + send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; let res = combine_shares(a_hat, b_hat); Ok(res) @@ -246,7 +238,7 @@ where // 2. Generate random tables let (z_12, z_23) = generate_random_tables_with_peers(batch_size, ctx); - let b_hat = generate_random_table( + let b_hat = generate_random_table_solo( batch_size, ctx, &OPRFShuffleStep::GenerateBHat, @@ -254,10 +246,10 @@ where ); // 3. Run computations - let y_1_arg = add_single_shares(c, z_12.into_iter()); - let y_1 = permute(&pi_12, y_1_arg); + let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); + apply(&pi_12, &mut y_1); - let ((), x_2) = try_join!( + let ((), x_2) = future::try_join( send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), receive_from_peer( ctx, @@ -265,12 +257,30 @@ where Direction::Left, batch_size, ), - )?; + ) + .await?; + + let mut x_3: Vec = + add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); + apply(&pi_23, &mut x_3); - let x_3_arg = add_single_shares(x_2.into_iter(), z_23.into_iter()); - let x_3 = permute(&pi_23, x_3_arg); let c_hat_1 = add_single_shares(x_3.iter(), b_hat.iter()).collect::>(); - let c_hat_2 = exchange_c_hat(ctx, batch_size, c_hat_1.clone()).await?; + let ((), c_hat_2) = future::try_join( + send_to_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Right, + c_hat_1.clone(), + ), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Right, + batch_size, + ), + ) + .await?; + let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); let res = combine_shares(b_hat, c_hat); Ok(res) @@ -287,7 +297,7 @@ where // 2. Generate random tables let (z_23, z_31) = generate_random_tables_with_peers(batch_size, ctx); - let a_hat = generate_random_table( + let a_hat = generate_random_table_solo( batch_size, ctx, &OPRFShuffleStep::GenerateAHat, @@ -303,35 +313,42 @@ where ) .await?; - let y_2_arg = add_single_shares(y_1, z_31); - let y_2 = permute(&pi_31, y_2_arg); - let y_3_arg = add_single_shares(y_2, z_23); - let y_3 = permute(&pi_23, y_3_arg); + let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); + apply(&pi_31, &mut y_2); + + let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); + apply(&pi_23, &mut y_3); + let c_hat_2 = add_single_shares(y_3, a_hat.clone()).collect::>(); - let c_hat_1 = exchange_c_hat(ctx, batch_size, c_hat_2.clone()).await?; + let ((), c_hat_1) = future::try_join( + send_to_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Left, + c_hat_2.clone(), + ), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Left, + batch_size, + ), + ) + .await?; + let c_hat = add_single_shares(c_hat_1, c_hat_2); let res = combine_shares(c_hat, a_hat); Ok(res) } -// ------------------------------------------------------------------------------------------------------------- // - -fn split_shares_and_get_left( - input_rows: &[OPRFInputRow], -) -> impl Iterator + '_ { - let lhs = input_rows - .iter() - .map(|input_row| OPRFShuffleSingleShare::from_input_row(input_row, Direction::Left)); - lhs -} +// --------------------------------------------------------------------------- // -fn split_shares_and_get_right( +fn split_shares( input_rows: &[OPRFInputRow], + direction: Direction, ) -> impl Iterator + '_ { - let rhs = input_rows - .iter() - .map(|input_row| OPRFShuffleSingleShare::from_input_row(input_row, Direction::Right)); - rhs + let f = move |input_row| OPRFShuffleSingleShare::from_input_row(input_row, direction); + input_rows.iter().map(f) } fn combine_shares(l: L, r: R) -> Vec @@ -354,22 +371,20 @@ where l.into_iter().zip(r).map(|(a, b)| a + b) } +// --------------------------------------------------------------------------- // + fn generate_random_tables_with_peers( batch_size: u32, ctx: &C, ) -> (Vec, Vec) { - let (step_left, step_right) = match ctx.role() { - Role::H1 => (OPRFShuffleStep::GenerateZ31, OPRFShuffleStep::GenerateZ12), - Role::H2 => (OPRFShuffleStep::GenerateZ12, OPRFShuffleStep::GenerateZ23), - Role::H3 => (OPRFShuffleStep::GenerateZ23, OPRFShuffleStep::GenerateZ12), - }; - - let with_left = generate_random_table(batch_size, ctx, &step_left, Direction::Left); - let with_right = generate_random_table(batch_size, ctx, &step_right, Direction::Right); + let narrow_step = ctx.narrow(&OPRFShuffleStep::GenerateZ); + let (rng_l, rng_r) = narrow_step.prss_rng(); + let with_left = sample_iter(rng_l).take(batch_size as usize).collect(); + let with_right = sample_iter(rng_r).take(batch_size as usize).collect(); (with_left, with_right) } -fn generate_random_table( +fn generate_random_table_solo( batch_size: u32, ctx: &C, step: &OPRFShuffleStep, @@ -385,18 +400,15 @@ where Direction::Right => rngs.1, }; - let iter = std::iter::from_fn(move || Some(OPRFShuffleSingleShare::sample(&mut rng))) - .take(batch_size as usize); + sample_iter(&mut rng) + .take(batch_size as usize) + .collect::>() +} - // NOTE: I'd like to return an Iterator from here as there is really no need to allocate batch_size of items. - // It'd be better to just pass the iterator to add_single_shares function. - // But I was unable to figure the return type. The type checker was saying something - // about Box> and that rng is not Send, - // but it is currently beyond by level of knowledge. - // So, any advice is appreciated - iter.collect::>() +fn sample_iter(rng: R) -> impl Iterator { + rng.sample_iter(Standard) } -// + // ---------------------------- helper communication ------------------------------------ // async fn send_to_peer>( @@ -431,77 +443,23 @@ async fn receive_from_peer( Ok(output) } -async fn exchange_c_hat>( - ctx: &C, - batch_size: u32, - part_to_send: I, -) -> Result, Error> { - let (step_send, step_recv, dir) = match ctx.role() { - Role::H2 => ( - OPRFShuffleStep::TransferCHat1, - OPRFShuffleStep::TransferCHat2, - Direction::Right, - ), - Role::H3 => ( - OPRFShuffleStep::TransferCHat2, - OPRFShuffleStep::TransferCHat1, - Direction::Left, - ), - role @ Role::H1 => { - unreachable!("Role {:?} does not participate in C_hat computation", role) - } - }; - - let ((), received_part) = try_join!( - send_to_peer(ctx, &step_send, dir, part_to_send), - receive_from_peer(ctx, &step_recv, dir, batch_size), - )?; - - Ok(received_part) -} - -// --------------------------- permutation-related function --------------------------------------------- // +// ------------------ Pseudorandom permutations functions -------------------- // fn generate_permutations_with_peers(batch_size: u32, ctx: &C) -> (Vec, Vec) { - let (step_left, step_right) = match &ctx.role() { - Role::H1 => (OPRFShuffleStep::GeneratePi31, OPRFShuffleStep::GeneratePi12), - Role::H2 => (OPRFShuffleStep::GeneratePi12, OPRFShuffleStep::GeneratePi23), - Role::H3 => (OPRFShuffleStep::GeneratePi23, OPRFShuffleStep::GeneratePi12), - }; + let narrow_context = ctx.narrow(&OPRFShuffleStep::GeneratePi); + let mut rng = narrow_context.prss_rng(); - let with_left = generate_pseudorandom_permutation(batch_size, ctx, &step_left, Direction::Left); - let with_right = - generate_pseudorandom_permutation(batch_size, ctx, &step_right, Direction::Right); + let with_left = generate_pseudorandom_permutation(batch_size, &mut rng.0); + let with_right = generate_pseudorandom_permutation(batch_size, &mut rng.1); (with_left, with_right) } -fn generate_pseudorandom_permutation( - batch_size: u32, - ctx: &C, - step: &OPRFShuffleStep, - with_peer_on_the: Direction, -) -> Vec { - let narrow_context = ctx.narrow(step); - let rng = narrow_context.prss_rng(); - let mut rng = match with_peer_on_the { - Direction::Left => rng.0, - Direction::Right => rng.1, - }; - +fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Vec { let mut permutation = (0..batch_size).collect::>(); - permutation.shuffle(&mut rng); + permutation.shuffle(rng); permutation } -fn permute( - permutation: &[u32], - input: impl Iterator, -) -> Vec { - let mut rows = input.collect::>(); - apply(permutation, &mut rows); - rows -} - use bitvec::bitvec; use embed_doc_image::embed_doc_image; diff --git a/src/query/executor.rs b/src/query/executor.rs index 7f14a8498..585d2060e 100644 --- a/src/query/executor.rs +++ b/src/query/executor.rs @@ -25,7 +25,7 @@ use crate::{ }, hpke::{KeyPair, KeyRegistry}, protocol::{ - context::{MaliciousContext, OPRFContext, SemiHonestContext}, + context::{Base as BaseContext, MaliciousContext, SemiHonestContext}, prss::Endpoint as PrssEndpoint, step::{Gate, StepNarrow}, }, @@ -207,7 +207,7 @@ pub fn execute( gateway, input, move |prss, gateway, config, input| { - let ctx = OPRFContext::new(prss, gateway); + let ctx = BaseContext::new(prss, gateway); let query = OPRFShuffleQuery::new(oprf_shuffle_config) .execute(ctx, config.size, input) .then(|res| ready(res.map(|out| Box::new(out) as Box))); diff --git a/src/query/runner/aggregate.rs b/src/query/runner/aggregate.rs index e0dee3550..72c40f833 100644 --- a/src/query/runner/aggregate.rs +++ b/src/query/runner/aggregate.rs @@ -2,7 +2,6 @@ use std::marker::PhantomData; use futures_util::TryStreamExt; -use super::ipa::assert_stream_send; use crate::{ error::Error, ff::{Gf2, Gf8Bit, PrimeField, Serializable}, @@ -11,6 +10,7 @@ use crate::{ BodyStream, RecordsStream, }, hpke::{KeyPair, KeyRegistry}, + one_off_fns::assert_stream_send, protocol::{ aggregation::{sparse_aggregate, SparseAggregateInputRow}, basics::{Reshare, ShareKnownValue}, diff --git a/src/query/runner/ipa.rs b/src/query/runner/ipa.rs index 2b19ead23..1541ac0c0 100644 --- a/src/query/runner/ipa.rs +++ b/src/query/runner/ipa.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use futures::{ stream::{iter, repeat}, - Stream, StreamExt, TryStreamExt, + StreamExt, TryStreamExt, }; use crate::{ @@ -13,6 +13,7 @@ use crate::{ BodyStream, LengthDelimitedStream, RecordsStream, }, hpke::{KeyPair, KeyRegistry}, + one_off_fns::assert_stream_send, protocol::{ basics::{Reshare, ShareKnownValue}, context::{UpgradableContext, UpgradeContext, UpgradeToMalicious, UpgradedContext}, @@ -147,16 +148,6 @@ where } } -/// Helps to convince the compiler that things are `Send`. Like `seq_join::assert_send`, but for -/// streams. -/// -/// -pub fn assert_stream_send<'a, T>( - st: impl Stream + Send + 'a, -) -> impl Stream + Send + 'a { - st -} - /// no dependency on `weak-field` feature because it is enabled in tests by default #[cfg(all(test, unit_test))] mod tests { diff --git a/src/query/runner/oprf_shuffle.rs b/src/query/runner/oprf_shuffle.rs index 38c2565f1..515d5de86 100644 --- a/src/query/runner/oprf_shuffle.rs +++ b/src/query/runner/oprf_shuffle.rs @@ -1,4 +1,4 @@ -use futures::{Stream, TryStreamExt}; +use futures::TryStreamExt; use crate::{ error::Error, @@ -6,6 +6,7 @@ use crate::{ query::{oprf_shuffle, QuerySize}, BodyStream, RecordsStream, }, + one_off_fns::assert_stream_send, protocol::{ context::Context, oprf::{oprf_shuffle, OPRFInputRow}, @@ -36,13 +37,3 @@ impl OPRFShuffleQuery { oprf_shuffle(ctx, input.as_slice(), self.config).await } } - -/// Helps to convince the compiler that things are `Send`. Like `seq_join::assert_send`, but for -/// streams. -/// -/// -pub fn assert_stream_send<'a, T>( - st: impl Stream + Send + 'a, -) -> impl Stream + Send + 'a { - st -} From 90afdbb6933c8e8dccde43138fddbb494c6926fd Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Thu, 19 Oct 2023 16:41:53 -0500 Subject: [PATCH 05/15] More improvements 1. Rewrote implt Add for &OPRFShuffleSingleShare in a way that does not spook clippy 2. Hoisted variables for narrow contexts generating random tables into the callers to avoid allocation of them. Instead the addition of random tables are done by combining iterators if possible 3. Extract the pieces of work common for all 3 helprs (generate pis, generate zs) to the calling function --- src/protocol/oprf/mod.rs | 172 ++++++++++++++++++++------------------- 1 file changed, 89 insertions(+), 83 deletions(-) diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index 6c3265877..a4b393454 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -77,16 +77,15 @@ impl rand::prelude::Distribution for Standard { impl Add for OPRFShuffleSingleShare { type Output = Self; - #[allow(clippy::op_ref)] fn add(self, rhs: Self) -> Self::Output { - &self + &rhs + Add::add(&self, &rhs) } } -impl Add for &OPRFShuffleSingleShare { +impl<'a, 'b> Add<&'b OPRFShuffleSingleShare> for &'a OPRFShuffleSingleShare { type Output = OPRFShuffleSingleShare; - fn add(self, &rhs: Self) -> Self::Output { + fn add(self, rhs: &'b OPRFShuffleSingleShare) -> Self::Output { Self::Output { timestamp: self.timestamp + rhs.timestamp, mk: self.mk + rhs.mk, @@ -179,46 +178,54 @@ pub async fn oprf_shuffle( )) })?; - let share_l = split_shares(input_rows, Direction::Left); - let share_r = split_shares(input_rows, Direction::Right); + let shares = ( + split_shares(input_rows, Direction::Left), + split_shares(input_rows, Direction::Right), + ); + + // 1. Generate permutations + let pis = generate_permutations_with_peers(batch_size, &ctx); + + // 2. Generate random tables used by all helpers + let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); + let zs = generate_random_tables_with_peers(batch_size, &ctx_z); match ctx.role() { - Role::H1 => run_h1(&ctx, batch_size, share_l, share_r).await, - Role::H2 => run_h2(&ctx, batch_size, share_l, share_r).await, - Role::H3 => run_h3(&ctx, batch_size, share_l, share_r).await, + Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await, + Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await, + Role::H3 => run_h3(&ctx, batch_size, shares, pis, zs).await, } } -async fn run_h1(ctx: &C, batch_size: u32, a: L, b: R) -> Result, Error> +async fn run_h1( + ctx: &C, + batch_size: u32, + (a, b): (Sl, Sr), + (pi_31, pi_12): (Vec, Vec), + (z_31, z_12): (Zl, Zr), +) -> Result, Error> where C: Context, - L: IntoIterator, - R: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, { - // 1. Generate permutations - let (pi_31, pi_12) = generate_permutations_with_peers(batch_size, ctx); + // 1. Generate helper-specific random tables + let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); + let a_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Left).collect(); - // 2. Generate random tables - let (z_31, z_12) = generate_random_tables_with_peers(batch_size, ctx); - let a_hat = generate_random_table_solo( - batch_size, - ctx, - &OPRFShuffleStep::GenerateAHat, - Direction::Left, - ); - let b_hat = generate_random_table_solo( - batch_size, - ctx, - &OPRFShuffleStep::GenerateBHat, - Direction::Right, - ); + let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); + let b_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); - // 3. Run computations + // 2. Run computations let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); apply(&pi_12, &mut x_1); - let mut x_2: Vec = add_single_shares(x_1.iter(), z_31.iter()).collect(); + let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); apply(&pi_31, &mut x_2); send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; @@ -227,25 +234,26 @@ where Ok(res) } -async fn run_h2(ctx: &C, batch_size: u32, _b: L, c: R) -> Result, Error> +async fn run_h2( + ctx: &C, + batch_size: u32, + (_b, c): (Sl, Sr), + (pi_12, pi_23): (Vec, Vec), + (z_12, z_23): (Zl, Zr), +) -> Result, Error> where C: Context, - L: IntoIterator, - R: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, { - // 1. Generate permutations - let (pi_12, pi_23) = generate_permutations_with_peers(batch_size, ctx); - - // 2. Generate random tables - let (z_12, z_23) = generate_random_tables_with_peers(batch_size, ctx); - let b_hat = generate_random_table_solo( - batch_size, - ctx, - &OPRFShuffleStep::GenerateBHat, - Direction::Left, - ); + // 1. Generate helper-specific random tables + let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); + let b_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); - // 3. Run computations + // 2. Run computations let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); apply(&pi_12, &mut y_1); @@ -260,11 +268,10 @@ where ) .await?; - let mut x_3: Vec = - add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); + let mut x_3: Vec<_> = add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); apply(&pi_23, &mut x_3); - let c_hat_1 = add_single_shares(x_3.iter(), b_hat.iter()).collect::>(); + let c_hat_1: Vec<_> = add_single_shares(x_3.iter(), b_hat.iter()).collect(); let ((), c_hat_2) = future::try_join( send_to_peer( ctx, @@ -286,25 +293,26 @@ where Ok(res) } -async fn run_h3(ctx: &C, batch_size: u32, _c: L, _a: R) -> Result, Error> +async fn run_h3( + ctx: &C, + batch_size: u32, + (_c, _a): (Sl, Sr), + (pi_23, pi_31): (Vec, Vec), + (z_23, z_31): (Zl, Zr), +) -> Result, Error> where C: Context, - L: IntoIterator, - R: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, { - // 1. Generate permutations - let (pi_23, pi_31) = generate_permutations_with_peers(batch_size, ctx); - - // 2. Generate random tables - let (z_23, z_31) = generate_random_tables_with_peers(batch_size, ctx); - let a_hat = generate_random_table_solo( - batch_size, - ctx, - &OPRFShuffleStep::GenerateAHat, - Direction::Right, - ); + // 1. Generate helper-specific random tables + let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); + let a_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Right).collect(); - // 3. Run computations + // 2. Run computations let y_1 = receive_from_peer( ctx, &OPRFShuffleStep::TransferY1, @@ -319,7 +327,7 @@ where let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); apply(&pi_23, &mut y_3); - let c_hat_2 = add_single_shares(y_3, a_hat.clone()).collect::>(); + let c_hat_2 = add_single_shares(y_3.iter(), a_hat.iter()).collect::>(); let ((), c_hat_1) = future::try_join( send_to_peer( ctx, @@ -362,11 +370,11 @@ where .collect::>() } -fn add_single_shares<'a, T, L, R>(l: L, r: R) -> impl Iterator +fn add_single_shares<'i, T, L, R>(l: L, r: R) -> impl Iterator + 'i where - T: Add + 'a, - L: IntoIterator, - R: IntoIterator, + T: Add, + L: IntoIterator + 'i, + R: IntoIterator + 'i, { l.into_iter().zip(r).map(|(a, b)| a + b) } @@ -375,34 +383,32 @@ where fn generate_random_tables_with_peers( batch_size: u32, - ctx: &C, -) -> (Vec, Vec) { - let narrow_step = ctx.narrow(&OPRFShuffleStep::GenerateZ); - let (rng_l, rng_r) = narrow_step.prss_rng(); - let with_left = sample_iter(rng_l).take(batch_size as usize).collect(); - let with_right = sample_iter(rng_r).take(batch_size as usize).collect(); + narrow_ctx: &C, +) -> ( + impl Iterator + '_, + impl Iterator + '_, +) { + let (rng_l, rng_r) = narrow_ctx.prss_rng(); + let with_left = sample_iter(rng_l).take(batch_size as usize); + let with_right = sample_iter(rng_r).take(batch_size as usize); (with_left, with_right) } fn generate_random_table_solo( batch_size: u32, - ctx: &C, - step: &OPRFShuffleStep, + narrow_ctx: &C, peer: Direction, -) -> Vec +) -> impl Iterator + '_ where C: Context, { - let narrow_step = ctx.narrow(step); - let rngs = narrow_step.prss_rng(); - let mut rng = match peer { + let rngs = narrow_ctx.prss_rng(); + let rng = match peer { Direction::Left => rngs.0, Direction::Right => rngs.1, }; - sample_iter(&mut rng) - .take(batch_size as usize) - .collect::>() + sample_iter(rng).take(batch_size as usize) } fn sample_iter(rng: R) -> impl Iterator { From dd448761b1199606200b8adc51c2ee9c4051bde7 Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Fri, 20 Oct 2023 12:12:13 -0500 Subject: [PATCH 06/15] [oprf][shuffle] Renames 1. Renamed OPRFShuffleSingleShare -> OPRFShare 2. Extract OPRFShare and all its impls into oprf_share module 3. Used apply (permutation) from protocol/sort/apply.rs --- src/protocol/oprf/mod.rs | 248 ++++++-------------------------- src/protocol/oprf/oprf_share.rs | 148 +++++++++++++++++++ src/protocol/sort/mod.rs | 2 +- 3 files changed, 191 insertions(+), 207 deletions(-) create mode 100644 src/protocol/oprf/oprf_share.rs diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index a4b393454..249c9e675 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -1,158 +1,22 @@ +pub mod oprf_share; + use std::ops::Add; use futures::future; -use generic_array::GenericArray; use ipa_macros::Step; use rand::{distributions::Standard, seq::SliceRandom, Rng}; -use typenum::Unsigned; -use super::{context::Context, ipa::IPAInputRow, RecordId}; +use self::oprf_share::{OPRFShare, OprfBK, OprfF, OprfMK}; +use super::{ + context::Context, ipa::IPAInputRow, sort::apply::apply as apply_permutation, RecordId, +}; use crate::{ error::Error, - ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, - helpers::{query::oprf_shuffle::QueryConfig, Direction, Message, ReceivingEnd, Role}, + helpers::{query::oprf_shuffle::QueryConfig, Direction, ReceivingEnd, Role}, }; -type OprfMK = Gf40Bit; -type OprfBK = Gf8Bit; -type OprfF = Gf32Bit; - pub type OPRFInputRow = IPAInputRow; -#[derive(Debug, Clone, Copy)] -pub struct OPRFShuffleSingleShare { - pub timestamp: OprfF, - pub mk: OprfMK, - pub is_trigger_bit: OprfF, - pub breakdown_key: OprfBK, - pub trigger_value: OprfF, -} - -impl OPRFShuffleSingleShare { - #[must_use] - pub fn from_input_row(input_row: &OPRFInputRow, shared_with: Direction) -> Self { - match shared_with { - Direction::Left => Self { - timestamp: input_row.timestamp.as_tuple().1, - mk: input_row.mk_shares.as_tuple().1, - is_trigger_bit: input_row.is_trigger_bit.as_tuple().1, - breakdown_key: input_row.breakdown_key.as_tuple().1, - trigger_value: input_row.trigger_value.as_tuple().1, - }, - - Direction::Right => Self { - timestamp: input_row.timestamp.as_tuple().0, - mk: input_row.mk_shares.as_tuple().0, - is_trigger_bit: input_row.is_trigger_bit.as_tuple().0, - breakdown_key: input_row.breakdown_key.as_tuple().0, - trigger_value: input_row.trigger_value.as_tuple().0, - }, - } - } - - #[must_use] - pub fn to_input_row(self, rhs: Self) -> OPRFInputRow { - OPRFInputRow { - timestamp: (self.timestamp, rhs.timestamp).into(), - mk_shares: (self.mk, rhs.mk).into(), - is_trigger_bit: (self.is_trigger_bit, rhs.is_trigger_bit).into(), - breakdown_key: (self.breakdown_key, rhs.breakdown_key).into(), - trigger_value: (self.trigger_value, rhs.trigger_value).into(), - } - } -} - -impl rand::prelude::Distribution for Standard { - fn sample(&self, rng: &mut R) -> OPRFShuffleSingleShare { - OPRFShuffleSingleShare { - timestamp: OprfF::truncate_from(rng.gen::()), - mk: OprfMK::truncate_from(rng.gen::()), - is_trigger_bit: OprfF::truncate_from(rng.gen::()), - breakdown_key: OprfBK::truncate_from(rng.gen::()), - trigger_value: OprfF::truncate_from(rng.gen::()), - } - } -} - -impl Add for OPRFShuffleSingleShare { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - Add::add(&self, &rhs) - } -} - -impl<'a, 'b> Add<&'b OPRFShuffleSingleShare> for &'a OPRFShuffleSingleShare { - type Output = OPRFShuffleSingleShare; - - fn add(self, rhs: &'b OPRFShuffleSingleShare) -> Self::Output { - Self::Output { - timestamp: self.timestamp + rhs.timestamp, - mk: self.mk + rhs.mk, - is_trigger_bit: self.is_trigger_bit + rhs.is_trigger_bit, - breakdown_key: self.breakdown_key + rhs.breakdown_key, - trigger_value: self.trigger_value + rhs.trigger_value, - } - } -} - -impl Serializable for OPRFShuffleSingleShare { - type Size = <::Size as Add< - <::Size as Add< - <::Size as Add< - <::Size as Add<::Size>>::Output, - >>::Output, - >>::Output, - >>::Output; - - fn serialize(&self, buf: &mut GenericArray) { - let mk_sz = ::Size::USIZE; - let bk_sz = ::Size::USIZE; - let f_sz = ::Size::USIZE; - - self.timestamp - .serialize(GenericArray::from_mut_slice(&mut buf[..f_sz])); - self.mk - .serialize(GenericArray::from_mut_slice(&mut buf[f_sz..f_sz + mk_sz])); - self.is_trigger_bit.serialize(GenericArray::from_mut_slice( - &mut buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], - )); - self.breakdown_key.serialize(GenericArray::from_mut_slice( - &mut buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], - )); - self.trigger_value.serialize(GenericArray::from_mut_slice( - &mut buf[f_sz + mk_sz + f_sz + bk_sz..], - )); - } - - fn deserialize(buf: &GenericArray) -> Self { - let mk_sz = ::Size::USIZE; - let bk_sz = ::Size::USIZE; - let f_sz = ::Size::USIZE; - - let timestamp = OprfF::deserialize(GenericArray::from_slice(&buf[..f_sz])); - let mk = OprfMK::deserialize(GenericArray::from_slice(&buf[f_sz..f_sz + mk_sz])); - let is_trigger_bit = OprfF::deserialize(GenericArray::from_slice( - &buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], - )); - let breakdown_key = OprfBK::deserialize(GenericArray::from_slice( - &buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], - )); - let trigger_value = OprfF::deserialize(GenericArray::from_slice( - &buf[f_sz + mk_sz + f_sz + bk_sz..], - )); - Self { - timestamp, - mk, - is_trigger_bit, - breakdown_key, - trigger_value, - } - } -} - -impl Message for OPRFShuffleSingleShare {} - #[derive(Step)] pub(crate) enum OPRFShuffleStep { GenerateAHat, @@ -206,10 +70,10 @@ async fn run_h1( ) -> Result, Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, { // 1. Generate helper-specific random tables let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); @@ -221,12 +85,11 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); // 2. Run computations - let mut x_1: Vec = - add_single_shares(add_single_shares(a, b), z_12).collect(); - apply(&pi_12, &mut x_1); + let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); + apply_permutation(&pi_12, &mut x_1); - let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); - apply(&pi_31, &mut x_2); + let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); + apply_permutation(&pi_31, &mut x_2); send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; @@ -243,10 +106,10 @@ async fn run_h2( ) -> Result, Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, { // 1. Generate helper-specific random tables let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); @@ -254,8 +117,8 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); // 2. Run computations - let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); - apply(&pi_12, &mut y_1); + let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); + apply_permutation(&pi_12, &mut y_1); let ((), x_2) = future::try_join( send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), @@ -269,7 +132,7 @@ where .await?; let mut x_3: Vec<_> = add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); - apply(&pi_23, &mut x_3); + apply_permutation(&pi_23, &mut x_3); let c_hat_1: Vec<_> = add_single_shares(x_3.iter(), b_hat.iter()).collect(); let ((), c_hat_2) = future::try_join( @@ -302,10 +165,10 @@ async fn run_h3( ) -> Result, Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, { // 1. Generate helper-specific random tables let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); @@ -321,11 +184,11 @@ where ) .await?; - let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); - apply(&pi_31, &mut y_2); + let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); + apply_permutation(&pi_31, &mut y_2); - let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); - apply(&pi_23, &mut y_3); + let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); + apply_permutation(&pi_23, &mut y_3); let c_hat_2 = add_single_shares(y_3.iter(), a_hat.iter()).collect::>(); let ((), c_hat_1) = future::try_join( @@ -354,15 +217,15 @@ where fn split_shares( input_rows: &[OPRFInputRow], direction: Direction, -) -> impl Iterator + '_ { - let f = move |input_row| OPRFShuffleSingleShare::from_input_row(input_row, direction); +) -> impl Iterator + '_ { + let f = move |input_row| OPRFShare::from_input_row(input_row, direction); input_rows.iter().map(f) } fn combine_shares(l: L, r: R) -> Vec where - L: IntoIterator, - R: IntoIterator, + L: IntoIterator, + R: IntoIterator, { l.into_iter() .zip(r) @@ -385,8 +248,8 @@ fn generate_random_tables_with_peers( batch_size: u32, narrow_ctx: &C, ) -> ( - impl Iterator + '_, - impl Iterator + '_, + impl Iterator + '_, + impl Iterator + '_, ) { let (rng_l, rng_r) = narrow_ctx.prss_rng(); let with_left = sample_iter(rng_l).take(batch_size as usize); @@ -398,7 +261,7 @@ fn generate_random_table_solo( batch_size: u32, narrow_ctx: &C, peer: Direction, -) -> impl Iterator + '_ +) -> impl Iterator + '_ where C: Context, { @@ -411,13 +274,13 @@ where sample_iter(rng).take(batch_size as usize) } -fn sample_iter(rng: R) -> impl Iterator { +fn sample_iter(rng: R) -> impl Iterator { rng.sample_iter(Standard) } // ---------------------------- helper communication ------------------------------------ // -async fn send_to_peer>( +async fn send_to_peer>( ctx: &C, step: &OPRFShuffleStep, direction: Direction, @@ -436,11 +299,11 @@ async fn receive_from_peer( step: &OPRFShuffleStep, direction: Direction, batch_size: u32, -) -> Result, Error> { +) -> Result, Error> { let role = ctx.role().peer(direction); - let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); + let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); - let mut output: Vec = Vec::with_capacity(batch_size as usize); + let mut output: Vec = Vec::with_capacity(batch_size as usize); for record_id in 0..batch_size { let msg = receive_channel.receive(RecordId::from(record_id)).await?; output.push(msg); @@ -465,30 +328,3 @@ fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Ve permutation.shuffle(rng); permutation } - -use bitvec::bitvec; -use embed_doc_image::embed_doc_image; - -#[embed_doc_image("apply", "images/sort/apply.png")] -/// Permutation reorders (1, 2, . . . , m) into (σ(1), σ(2), . . . , σ(m)). -/// For example, if σ(1) = 2, σ(2) = 3, σ(3) = 1, and σ(4) = 0, an input (A, B, C, D) is reordered into (C, D, B, A) by σ. -/// -/// ![Apply steps][apply] -fn apply(permutation: &[u32], values: &mut [T]) { - // NOTE: This is copypasta from crate::protocol::sort - debug_assert!(permutation.len() == values.len()); - let mut permuted = bitvec![0; permutation.len()]; - - for i in 0..permutation.len() { - if !permuted[i] { - let mut pos_i = i; - let mut pos_j = permutation[pos_i] as usize; - while pos_j != i { - values.swap(pos_i, pos_j); - permuted.set(pos_j, true); - pos_i = pos_j; - pos_j = permutation[pos_i] as usize; - } - } - } -} diff --git a/src/protocol/oprf/oprf_share.rs b/src/protocol/oprf/oprf_share.rs new file mode 100644 index 000000000..d864ea9b9 --- /dev/null +++ b/src/protocol/oprf/oprf_share.rs @@ -0,0 +1,148 @@ +use std::ops::Add; + +use generic_array::GenericArray; +use rand::{distributions::Standard, Rng}; +use typenum::Unsigned; + +use super::OPRFInputRow; +use crate::{ + ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, + helpers::{Direction, Message}, +}; +pub type OprfMK = Gf40Bit; +pub type OprfBK = Gf8Bit; +pub type OprfF = Gf32Bit; + +#[derive(Debug, Clone, Copy)] +pub struct OPRFShare { + pub timestamp: OprfF, + pub mk: OprfMK, + pub is_trigger_bit: OprfF, + pub breakdown_key: OprfBK, + pub trigger_value: OprfF, +} + +impl OPRFShare { + #[must_use] + pub fn from_input_row(input_row: &OPRFInputRow, shared_with: Direction) -> Self { + match shared_with { + Direction::Left => Self { + timestamp: input_row.timestamp.as_tuple().1, + mk: input_row.mk_shares.as_tuple().1, + is_trigger_bit: input_row.is_trigger_bit.as_tuple().1, + breakdown_key: input_row.breakdown_key.as_tuple().1, + trigger_value: input_row.trigger_value.as_tuple().1, + }, + + Direction::Right => Self { + timestamp: input_row.timestamp.as_tuple().0, + mk: input_row.mk_shares.as_tuple().0, + is_trigger_bit: input_row.is_trigger_bit.as_tuple().0, + breakdown_key: input_row.breakdown_key.as_tuple().0, + trigger_value: input_row.trigger_value.as_tuple().0, + }, + } + } + + #[must_use] + pub fn to_input_row(self, rhs: Self) -> OPRFInputRow { + OPRFInputRow { + timestamp: (self.timestamp, rhs.timestamp).into(), + mk_shares: (self.mk, rhs.mk).into(), + is_trigger_bit: (self.is_trigger_bit, rhs.is_trigger_bit).into(), + breakdown_key: (self.breakdown_key, rhs.breakdown_key).into(), + trigger_value: (self.trigger_value, rhs.trigger_value).into(), + } + } +} + +impl rand::prelude::Distribution for Standard { + fn sample(&self, rng: &mut R) -> OPRFShare { + OPRFShare { + timestamp: OprfF::truncate_from(rng.gen::()), + mk: OprfMK::truncate_from(rng.gen::()), + is_trigger_bit: OprfF::truncate_from(rng.gen::()), + breakdown_key: OprfBK::truncate_from(rng.gen::()), + trigger_value: OprfF::truncate_from(rng.gen::()), + } + } +} + +impl Add for OPRFShare { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Add::add(&self, &rhs) + } +} + +impl<'a, 'b> Add<&'b OPRFShare> for &'a OPRFShare { + type Output = OPRFShare; + + fn add(self, rhs: &'b OPRFShare) -> Self::Output { + Self::Output { + timestamp: self.timestamp + rhs.timestamp, + mk: self.mk + rhs.mk, + is_trigger_bit: self.is_trigger_bit + rhs.is_trigger_bit, + breakdown_key: self.breakdown_key + rhs.breakdown_key, + trigger_value: self.trigger_value + rhs.trigger_value, + } + } +} + +impl Serializable for OPRFShare { + type Size = <::Size as Add< + <::Size as Add< + <::Size as Add< + <::Size as Add<::Size>>::Output, + >>::Output, + >>::Output, + >>::Output; + + fn serialize(&self, buf: &mut GenericArray) { + let mk_sz = ::Size::USIZE; + let bk_sz = ::Size::USIZE; + let f_sz = ::Size::USIZE; + + self.timestamp + .serialize(GenericArray::from_mut_slice(&mut buf[..f_sz])); + self.mk + .serialize(GenericArray::from_mut_slice(&mut buf[f_sz..f_sz + mk_sz])); + self.is_trigger_bit.serialize(GenericArray::from_mut_slice( + &mut buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], + )); + self.breakdown_key.serialize(GenericArray::from_mut_slice( + &mut buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], + )); + self.trigger_value.serialize(GenericArray::from_mut_slice( + &mut buf[f_sz + mk_sz + f_sz + bk_sz..], + )); + } + + fn deserialize(buf: &GenericArray) -> Self { + let mk_sz = ::Size::USIZE; + let bk_sz = ::Size::USIZE; + let f_sz = ::Size::USIZE; + + let timestamp = OprfF::deserialize(GenericArray::from_slice(&buf[..f_sz])); + let mk = OprfMK::deserialize(GenericArray::from_slice(&buf[f_sz..f_sz + mk_sz])); + let is_trigger_bit = OprfF::deserialize(GenericArray::from_slice( + &buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], + )); + let breakdown_key = OprfBK::deserialize(GenericArray::from_slice( + &buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], + )); + let trigger_value = OprfF::deserialize(GenericArray::from_slice( + &buf[f_sz + mk_sz + f_sz + bk_sz..], + )); + Self { + timestamp, + mk, + is_trigger_bit, + breakdown_key, + trigger_value, + } + } +} + +impl Message for OPRFShare {} diff --git a/src/protocol/sort/mod.rs b/src/protocol/sort/mod.rs index 8427f0b03..4908a76bf 100644 --- a/src/protocol/sort/mod.rs +++ b/src/protocol/sort/mod.rs @@ -1,9 +1,9 @@ +pub mod apply; pub mod apply_sort; pub mod bit_permutation; pub mod generate_permutation; pub mod generate_permutation_opt; -mod apply; mod compose; mod multi_bit_permutation; mod secureapplyinv; From 7cba0bd0e1e1d5880148ffbfc45dc9315e09a111 Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Fri, 20 Oct 2023 15:09:22 -0500 Subject: [PATCH 07/15] [oprf][shuffle] A bunch of renames 1. Renamed oprf_shuffle module into oprf::shuffle 2. Renamed OPRFShare to ShuffleShare 3. Renamed OPRFShuffle* type aliases (for fields) into ShufleShare* --- src/protocol/oprf/mod.rs | 331 +----------------- src/protocol/oprf/shuffle/mod.rs | 330 +++++++++++++++++ .../oprf/{oprf_share.rs => shuffle/share.rs} | 88 ++--- src/query/runner/oprf_shuffle.rs | 10 +- 4 files changed, 381 insertions(+), 378 deletions(-) create mode 100644 src/protocol/oprf/shuffle/mod.rs rename src/protocol/oprf/{oprf_share.rs => shuffle/share.rs} (55%) diff --git a/src/protocol/oprf/mod.rs b/src/protocol/oprf/mod.rs index 249c9e675..afc1d0476 100644 --- a/src/protocol/oprf/mod.rs +++ b/src/protocol/oprf/mod.rs @@ -1,330 +1 @@ -pub mod oprf_share; - -use std::ops::Add; - -use futures::future; -use ipa_macros::Step; -use rand::{distributions::Standard, seq::SliceRandom, Rng}; - -use self::oprf_share::{OPRFShare, OprfBK, OprfF, OprfMK}; -use super::{ - context::Context, ipa::IPAInputRow, sort::apply::apply as apply_permutation, RecordId, -}; -use crate::{ - error::Error, - helpers::{query::oprf_shuffle::QueryConfig, Direction, ReceivingEnd, Role}, -}; - -pub type OPRFInputRow = IPAInputRow; - -#[derive(Step)] -pub(crate) enum OPRFShuffleStep { - GenerateAHat, - GenerateBHat, - GeneratePi, - GenerateZ, - TransferCHat, - TransferX2, - TransferY1, -} - -/// # Errors -/// Will propagate errors from transport and a few typecasts -pub async fn oprf_shuffle( - ctx: C, - input_rows: &[OPRFInputRow], - _config: QueryConfig, -) -> Result, Error> { - let batch_size = u32::try_from(input_rows.len()).map_err(|_e| { - Error::FieldValueTruncation(format!( - "Cannot truncate the number of input rows {} to u32", - input_rows.len(), - )) - })?; - - let shares = ( - split_shares(input_rows, Direction::Left), - split_shares(input_rows, Direction::Right), - ); - - // 1. Generate permutations - let pis = generate_permutations_with_peers(batch_size, &ctx); - - // 2. Generate random tables used by all helpers - let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); - let zs = generate_random_tables_with_peers(batch_size, &ctx_z); - - match ctx.role() { - Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await, - Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await, - Role::H3 => run_h3(&ctx, batch_size, shares, pis, zs).await, - } -} - -async fn run_h1( - ctx: &C, - batch_size: u32, - (a, b): (Sl, Sr), - (pi_31, pi_12): (Vec, Vec), - (z_31, z_12): (Zl, Zr), -) -> Result, Error> -where - C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, -{ - // 1. Generate helper-specific random tables - let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); - let a_hat: Vec<_> = - generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Left).collect(); - - let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); - let b_hat: Vec<_> = - generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); - - // 2. Run computations - let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); - apply_permutation(&pi_12, &mut x_1); - - let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); - apply_permutation(&pi_31, &mut x_2); - - send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; - - let res = combine_shares(a_hat, b_hat); - Ok(res) -} - -async fn run_h2( - ctx: &C, - batch_size: u32, - (_b, c): (Sl, Sr), - (pi_12, pi_23): (Vec, Vec), - (z_12, z_23): (Zl, Zr), -) -> Result, Error> -where - C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, -{ - // 1. Generate helper-specific random tables - let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); - let b_hat: Vec<_> = - generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); - - // 2. Run computations - let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); - apply_permutation(&pi_12, &mut y_1); - - let ((), x_2) = future::try_join( - send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferX2, - Direction::Left, - batch_size, - ), - ) - .await?; - - let mut x_3: Vec<_> = add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); - apply_permutation(&pi_23, &mut x_3); - - let c_hat_1: Vec<_> = add_single_shares(x_3.iter(), b_hat.iter()).collect(); - let ((), c_hat_2) = future::try_join( - send_to_peer( - ctx, - &OPRFShuffleStep::TransferCHat, - Direction::Right, - c_hat_1.clone(), - ), - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferCHat, - Direction::Right, - batch_size, - ), - ) - .await?; - - let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); - let res = combine_shares(b_hat, c_hat); - Ok(res) -} - -async fn run_h3( - ctx: &C, - batch_size: u32, - (_c, _a): (Sl, Sr), - (pi_23, pi_31): (Vec, Vec), - (z_23, z_31): (Zl, Zr), -) -> Result, Error> -where - C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, -{ - // 1. Generate helper-specific random tables - let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); - let a_hat: Vec<_> = - generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Right).collect(); - - // 2. Run computations - let y_1 = receive_from_peer( - ctx, - &OPRFShuffleStep::TransferY1, - Direction::Left, - batch_size, - ) - .await?; - - let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); - apply_permutation(&pi_31, &mut y_2); - - let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); - apply_permutation(&pi_23, &mut y_3); - - let c_hat_2 = add_single_shares(y_3.iter(), a_hat.iter()).collect::>(); - let ((), c_hat_1) = future::try_join( - send_to_peer( - ctx, - &OPRFShuffleStep::TransferCHat, - Direction::Left, - c_hat_2.clone(), - ), - receive_from_peer( - ctx, - &OPRFShuffleStep::TransferCHat, - Direction::Left, - batch_size, - ), - ) - .await?; - - let c_hat = add_single_shares(c_hat_1, c_hat_2); - let res = combine_shares(c_hat, a_hat); - Ok(res) -} - -// --------------------------------------------------------------------------- // - -fn split_shares( - input_rows: &[OPRFInputRow], - direction: Direction, -) -> impl Iterator + '_ { - let f = move |input_row| OPRFShare::from_input_row(input_row, direction); - input_rows.iter().map(f) -} - -fn combine_shares(l: L, r: R) -> Vec -where - L: IntoIterator, - R: IntoIterator, -{ - l.into_iter() - .zip(r) - .map(|(l, r)| l.to_input_row(r)) - .collect::>() -} - -fn add_single_shares<'i, T, L, R>(l: L, r: R) -> impl Iterator + 'i -where - T: Add, - L: IntoIterator + 'i, - R: IntoIterator + 'i, -{ - l.into_iter().zip(r).map(|(a, b)| a + b) -} - -// --------------------------------------------------------------------------- // - -fn generate_random_tables_with_peers( - batch_size: u32, - narrow_ctx: &C, -) -> ( - impl Iterator + '_, - impl Iterator + '_, -) { - let (rng_l, rng_r) = narrow_ctx.prss_rng(); - let with_left = sample_iter(rng_l).take(batch_size as usize); - let with_right = sample_iter(rng_r).take(batch_size as usize); - (with_left, with_right) -} - -fn generate_random_table_solo( - batch_size: u32, - narrow_ctx: &C, - peer: Direction, -) -> impl Iterator + '_ -where - C: Context, -{ - let rngs = narrow_ctx.prss_rng(); - let rng = match peer { - Direction::Left => rngs.0, - Direction::Right => rngs.1, - }; - - sample_iter(rng).take(batch_size as usize) -} - -fn sample_iter(rng: R) -> impl Iterator { - rng.sample_iter(Standard) -} - -// ---------------------------- helper communication ------------------------------------ // - -async fn send_to_peer>( - ctx: &C, - step: &OPRFShuffleStep, - direction: Direction, - items: I, -) -> Result<(), Error> { - let role = ctx.role().peer(direction); - let send_channel = ctx.narrow(step).send_channel(role); - for (record_id, row) in items.into_iter().enumerate() { - send_channel.send(RecordId::from(record_id), row).await?; - } - Ok(()) -} - -async fn receive_from_peer( - ctx: &C, - step: &OPRFShuffleStep, - direction: Direction, - batch_size: u32, -) -> Result, Error> { - let role = ctx.role().peer(direction); - let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); - - let mut output: Vec = Vec::with_capacity(batch_size as usize); - for record_id in 0..batch_size { - let msg = receive_channel.receive(RecordId::from(record_id)).await?; - output.push(msg); - } - - Ok(output) -} - -// ------------------ Pseudorandom permutations functions -------------------- // - -fn generate_permutations_with_peers(batch_size: u32, ctx: &C) -> (Vec, Vec) { - let narrow_context = ctx.narrow(&OPRFShuffleStep::GeneratePi); - let mut rng = narrow_context.prss_rng(); - - let with_left = generate_pseudorandom_permutation(batch_size, &mut rng.0); - let with_right = generate_pseudorandom_permutation(batch_size, &mut rng.1); - (with_left, with_right) -} - -fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Vec { - let mut permutation = (0..batch_size).collect::>(); - permutation.shuffle(rng); - permutation -} +pub mod shuffle; diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs new file mode 100644 index 000000000..bdf872cc3 --- /dev/null +++ b/src/protocol/oprf/shuffle/mod.rs @@ -0,0 +1,330 @@ +pub mod share; + +use std::ops::Add; + +use futures::future; +use ipa_macros::Step; +use rand::{distributions::Standard, seq::SliceRandom, Rng}; + +use self::share::{ShuffleShare, ShuffleShareBK, ShuffleShareF, ShuffleShareMK}; +use super::super::{ + context::Context, ipa::IPAInputRow, sort::apply::apply as apply_permutation, RecordId, +}; +use crate::{ + error::Error, + helpers::{query::oprf_shuffle::QueryConfig, Direction, ReceivingEnd, Role}, +}; + +pub type ShuffleInputRow = IPAInputRow; + +#[derive(Step)] +pub(crate) enum OPRFShuffleStep { + GenerateAHat, + GenerateBHat, + GeneratePi, + GenerateZ, + TransferCHat, + TransferX2, + TransferY1, +} + +/// # Errors +/// Will propagate errors from transport and a few typecasts +pub async fn shuffle( + ctx: C, + input_rows: &[ShuffleInputRow], + _config: QueryConfig, +) -> Result, Error> { + let batch_size = u32::try_from(input_rows.len()).map_err(|_e| { + Error::FieldValueTruncation(format!( + "Cannot truncate the number of input rows {} to u32", + input_rows.len(), + )) + })?; + + let shares = ( + split_shares(input_rows, Direction::Left), + split_shares(input_rows, Direction::Right), + ); + + // 1. Generate permutations + let pis = generate_permutations_with_peers(batch_size, &ctx); + + // 2. Generate random tables used by all helpers + let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); + let zs = generate_random_tables_with_peers(batch_size, &ctx_z); + + match ctx.role() { + Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await, + Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await, + Role::H3 => run_h3(&ctx, batch_size, shares, pis, zs).await, + } +} + +async fn run_h1( + ctx: &C, + batch_size: u32, + (a, b): (Sl, Sr), + (pi_31, pi_12): (Vec, Vec), + (z_31, z_12): (Zl, Zr), +) -> Result, Error> +where + C: Context, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, +{ + // 1. Generate helper-specific random tables + let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); + let a_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Left).collect(); + + let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); + let b_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); + + // 2. Run computations + let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); + apply_permutation(&pi_12, &mut x_1); + + let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); + apply_permutation(&pi_31, &mut x_2); + + send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; + + let res = combine_shares(a_hat, b_hat); + Ok(res) +} + +async fn run_h2( + ctx: &C, + batch_size: u32, + (_b, c): (Sl, Sr), + (pi_12, pi_23): (Vec, Vec), + (z_12, z_23): (Zl, Zr), +) -> Result, Error> +where + C: Context, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, +{ + // 1. Generate helper-specific random tables + let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); + let b_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); + + // 2. Run computations + let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); + apply_permutation(&pi_12, &mut y_1); + + let ((), x_2) = future::try_join( + send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferX2, + Direction::Left, + batch_size, + ), + ) + .await?; + + let mut x_3: Vec<_> = add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); + apply_permutation(&pi_23, &mut x_3); + + let c_hat_1: Vec<_> = add_single_shares(x_3.iter(), b_hat.iter()).collect(); + let ((), c_hat_2) = future::try_join( + send_to_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Right, + c_hat_1.clone(), + ), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Right, + batch_size, + ), + ) + .await?; + + let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); + let res = combine_shares(b_hat, c_hat); + Ok(res) +} + +async fn run_h3( + ctx: &C, + batch_size: u32, + (_c, _a): (Sl, Sr), + (pi_23, pi_31): (Vec, Vec), + (z_23, z_31): (Zl, Zr), +) -> Result, Error> +where + C: Context, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, +{ + // 1. Generate helper-specific random tables + let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); + let a_hat: Vec<_> = + generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Right).collect(); + + // 2. Run computations + let y_1 = receive_from_peer( + ctx, + &OPRFShuffleStep::TransferY1, + Direction::Left, + batch_size, + ) + .await?; + + let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); + apply_permutation(&pi_31, &mut y_2); + + let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); + apply_permutation(&pi_23, &mut y_3); + + let c_hat_2 = add_single_shares(y_3.iter(), a_hat.iter()).collect::>(); + let ((), c_hat_1) = future::try_join( + send_to_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Left, + c_hat_2.clone(), + ), + receive_from_peer( + ctx, + &OPRFShuffleStep::TransferCHat, + Direction::Left, + batch_size, + ), + ) + .await?; + + let c_hat = add_single_shares(c_hat_1, c_hat_2); + let res = combine_shares(c_hat, a_hat); + Ok(res) +} + +// --------------------------------------------------------------------------- // + +fn split_shares( + input_rows: &[ShuffleInputRow], + direction: Direction, +) -> impl Iterator + '_ { + let f = move |input_row| ShuffleShare::from_input_row(input_row, direction); + input_rows.iter().map(f) +} + +fn combine_shares(l: L, r: R) -> Vec +where + L: IntoIterator, + R: IntoIterator, +{ + l.into_iter() + .zip(r) + .map(|(l, r)| l.to_input_row(r)) + .collect::>() +} + +fn add_single_shares<'i, T, L, R>(l: L, r: R) -> impl Iterator + 'i +where + T: Add, + L: IntoIterator + 'i, + R: IntoIterator + 'i, +{ + l.into_iter().zip(r).map(|(a, b)| a + b) +} + +// --------------------------------------------------------------------------- // + +fn generate_random_tables_with_peers( + batch_size: u32, + narrow_ctx: &C, +) -> ( + impl Iterator + '_, + impl Iterator + '_, +) { + let (rng_l, rng_r) = narrow_ctx.prss_rng(); + let with_left = sample_iter(rng_l).take(batch_size as usize); + let with_right = sample_iter(rng_r).take(batch_size as usize); + (with_left, with_right) +} + +fn generate_random_table_solo( + batch_size: u32, + narrow_ctx: &C, + peer: Direction, +) -> impl Iterator + '_ +where + C: Context, +{ + let rngs = narrow_ctx.prss_rng(); + let rng = match peer { + Direction::Left => rngs.0, + Direction::Right => rngs.1, + }; + + sample_iter(rng).take(batch_size as usize) +} + +fn sample_iter(rng: R) -> impl Iterator { + rng.sample_iter(Standard) +} + +// ---------------------------- helper communication ------------------------------------ // + +async fn send_to_peer>( + ctx: &C, + step: &OPRFShuffleStep, + direction: Direction, + items: I, +) -> Result<(), Error> { + let role = ctx.role().peer(direction); + let send_channel = ctx.narrow(step).send_channel(role); + for (record_id, row) in items.into_iter().enumerate() { + send_channel.send(RecordId::from(record_id), row).await?; + } + Ok(()) +} + +async fn receive_from_peer( + ctx: &C, + step: &OPRFShuffleStep, + direction: Direction, + batch_size: u32, +) -> Result, Error> { + let role = ctx.role().peer(direction); + let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); + + let mut output: Vec = Vec::with_capacity(batch_size as usize); + for record_id in 0..batch_size { + let msg = receive_channel.receive(RecordId::from(record_id)).await?; + output.push(msg); + } + + Ok(output) +} + +// ------------------ Pseudorandom permutations functions -------------------- // + +fn generate_permutations_with_peers(batch_size: u32, ctx: &C) -> (Vec, Vec) { + let narrow_context = ctx.narrow(&OPRFShuffleStep::GeneratePi); + let mut rng = narrow_context.prss_rng(); + + let with_left = generate_pseudorandom_permutation(batch_size, &mut rng.0); + let with_right = generate_pseudorandom_permutation(batch_size, &mut rng.1); + (with_left, with_right) +} + +fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Vec { + let mut permutation = (0..batch_size).collect::>(); + permutation.shuffle(rng); + permutation +} diff --git a/src/protocol/oprf/oprf_share.rs b/src/protocol/oprf/shuffle/share.rs similarity index 55% rename from src/protocol/oprf/oprf_share.rs rename to src/protocol/oprf/shuffle/share.rs index d864ea9b9..a38e2fd2e 100644 --- a/src/protocol/oprf/oprf_share.rs +++ b/src/protocol/oprf/shuffle/share.rs @@ -4,27 +4,27 @@ use generic_array::GenericArray; use rand::{distributions::Standard, Rng}; use typenum::Unsigned; -use super::OPRFInputRow; +use super::ShuffleInputRow; use crate::{ ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, helpers::{Direction, Message}, }; -pub type OprfMK = Gf40Bit; -pub type OprfBK = Gf8Bit; -pub type OprfF = Gf32Bit; +pub type ShuffleShareMK = Gf40Bit; +pub type ShuffleShareBK = Gf8Bit; +pub type ShuffleShareF = Gf32Bit; #[derive(Debug, Clone, Copy)] -pub struct OPRFShare { - pub timestamp: OprfF, - pub mk: OprfMK, - pub is_trigger_bit: OprfF, - pub breakdown_key: OprfBK, - pub trigger_value: OprfF, +pub struct ShuffleShare { + pub timestamp: ShuffleShareF, + pub mk: ShuffleShareMK, + pub is_trigger_bit: ShuffleShareF, + pub breakdown_key: ShuffleShareBK, + pub trigger_value: ShuffleShareF, } -impl OPRFShare { +impl ShuffleShare { #[must_use] - pub fn from_input_row(input_row: &OPRFInputRow, shared_with: Direction) -> Self { + pub fn from_input_row(input_row: &ShuffleInputRow, shared_with: Direction) -> Self { match shared_with { Direction::Left => Self { timestamp: input_row.timestamp.as_tuple().1, @@ -45,8 +45,8 @@ impl OPRFShare { } #[must_use] - pub fn to_input_row(self, rhs: Self) -> OPRFInputRow { - OPRFInputRow { + pub fn to_input_row(self, rhs: Self) -> ShuffleInputRow { + ShuffleInputRow { timestamp: (self.timestamp, rhs.timestamp).into(), mk_shares: (self.mk, rhs.mk).into(), is_trigger_bit: (self.is_trigger_bit, rhs.is_trigger_bit).into(), @@ -56,19 +56,19 @@ impl OPRFShare { } } -impl rand::prelude::Distribution for Standard { - fn sample(&self, rng: &mut R) -> OPRFShare { - OPRFShare { - timestamp: OprfF::truncate_from(rng.gen::()), - mk: OprfMK::truncate_from(rng.gen::()), - is_trigger_bit: OprfF::truncate_from(rng.gen::()), - breakdown_key: OprfBK::truncate_from(rng.gen::()), - trigger_value: OprfF::truncate_from(rng.gen::()), +impl rand::prelude::Distribution for Standard { + fn sample(&self, rng: &mut R) -> ShuffleShare { + ShuffleShare { + timestamp: ShuffleShareF::truncate_from(rng.gen::()), + mk: ShuffleShareMK::truncate_from(rng.gen::()), + is_trigger_bit: ShuffleShareF::truncate_from(rng.gen::()), + breakdown_key: ShuffleShareBK::truncate_from(rng.gen::()), + trigger_value: ShuffleShareF::truncate_from(rng.gen::()), } } } -impl Add for OPRFShare { +impl Add for ShuffleShare { type Output = Self; fn add(self, rhs: Self) -> Self::Output { @@ -76,10 +76,10 @@ impl Add for OPRFShare { } } -impl<'a, 'b> Add<&'b OPRFShare> for &'a OPRFShare { - type Output = OPRFShare; +impl<'a, 'b> Add<&'b ShuffleShare> for &'a ShuffleShare { + type Output = ShuffleShare; - fn add(self, rhs: &'b OPRFShare) -> Self::Output { + fn add(self, rhs: &'b ShuffleShare) -> Self::Output { Self::Output { timestamp: self.timestamp + rhs.timestamp, mk: self.mk + rhs.mk, @@ -90,19 +90,21 @@ impl<'a, 'b> Add<&'b OPRFShare> for &'a OPRFShare { } } -impl Serializable for OPRFShare { - type Size = <::Size as Add< - <::Size as Add< - <::Size as Add< - <::Size as Add<::Size>>::Output, +impl Serializable for ShuffleShare { + type Size = <::Size as Add< + <::Size as Add< + <::Size as Add< + <::Size as Add< + ::Size, + >>::Output, >>::Output, >>::Output, >>::Output; fn serialize(&self, buf: &mut GenericArray) { - let mk_sz = ::Size::USIZE; - let bk_sz = ::Size::USIZE; - let f_sz = ::Size::USIZE; + let mk_sz = ::Size::USIZE; + let bk_sz = ::Size::USIZE; + let f_sz = ::Size::USIZE; self.timestamp .serialize(GenericArray::from_mut_slice(&mut buf[..f_sz])); @@ -120,19 +122,19 @@ impl Serializable for OPRFShare { } fn deserialize(buf: &GenericArray) -> Self { - let mk_sz = ::Size::USIZE; - let bk_sz = ::Size::USIZE; - let f_sz = ::Size::USIZE; + let mk_sz = ::Size::USIZE; + let bk_sz = ::Size::USIZE; + let f_sz = ::Size::USIZE; - let timestamp = OprfF::deserialize(GenericArray::from_slice(&buf[..f_sz])); - let mk = OprfMK::deserialize(GenericArray::from_slice(&buf[f_sz..f_sz + mk_sz])); - let is_trigger_bit = OprfF::deserialize(GenericArray::from_slice( + let timestamp = ShuffleShareF::deserialize(GenericArray::from_slice(&buf[..f_sz])); + let mk = ShuffleShareMK::deserialize(GenericArray::from_slice(&buf[f_sz..f_sz + mk_sz])); + let is_trigger_bit = ShuffleShareF::deserialize(GenericArray::from_slice( &buf[f_sz + mk_sz..f_sz + mk_sz + f_sz], )); - let breakdown_key = OprfBK::deserialize(GenericArray::from_slice( + let breakdown_key = ShuffleShareBK::deserialize(GenericArray::from_slice( &buf[f_sz + mk_sz + f_sz..f_sz + mk_sz + f_sz + bk_sz], )); - let trigger_value = OprfF::deserialize(GenericArray::from_slice( + let trigger_value = ShuffleShareF::deserialize(GenericArray::from_slice( &buf[f_sz + mk_sz + f_sz + bk_sz..], )); Self { @@ -145,4 +147,4 @@ impl Serializable for OPRFShare { } } -impl Message for OPRFShare {} +impl Message for ShuffleShare {} diff --git a/src/query/runner/oprf_shuffle.rs b/src/query/runner/oprf_shuffle.rs index 515d5de86..4d347a6ab 100644 --- a/src/query/runner/oprf_shuffle.rs +++ b/src/query/runner/oprf_shuffle.rs @@ -9,7 +9,7 @@ use crate::{ one_off_fns::assert_stream_send, protocol::{ context::Context, - oprf::{oprf_shuffle, OPRFInputRow}, + oprf::shuffle::{shuffle, ShuffleInputRow}, }, }; @@ -28,12 +28,12 @@ impl OPRFShuffleQuery { ctx: C, query_size: QuerySize, input_stream: BodyStream, - ) -> Result, Error> { - let input: Vec = - assert_stream_send(RecordsStream::::new(input_stream)) + ) -> Result, Error> { + let input: Vec = + assert_stream_send(RecordsStream::::new(input_stream)) .try_concat() .await?; - oprf_shuffle(ctx, input.as_slice(), self.config).await + shuffle(ctx, input.as_slice(), self.config).await } } From cd5987237b66e50258c2f348305b905bfad53fdf Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Fri, 20 Oct 2023 15:26:39 -0500 Subject: [PATCH 08/15] [oprf][shuffle] Made apply permutation module from sort protocoal an individual basic protocol Moved protocols::sort::apply to protocols::basic:apply_permutation --- .../apply.rs => basics/apply_permutation.rs} | 0 src/protocol/basics/mod.rs | 1 + src/protocol/oprf/shuffle/mod.rs | 3 ++- src/protocol/sort/apply_sort/mod.rs | 7 ++----- src/protocol/sort/apply_sort/shuffle.rs | 6 ++++-- src/protocol/sort/compose.rs | 10 ++++------ src/protocol/sort/mod.rs | 1 - src/protocol/sort/secureapplyinv.rs | 10 ++++------ src/protocol/sort/shuffle.rs | 15 ++++++++++----- 9 files changed, 27 insertions(+), 26 deletions(-) rename src/protocol/{sort/apply.rs => basics/apply_permutation.rs} (100%) diff --git a/src/protocol/sort/apply.rs b/src/protocol/basics/apply_permutation.rs similarity index 100% rename from src/protocol/sort/apply.rs rename to src/protocol/basics/apply_permutation.rs diff --git a/src/protocol/basics/mod.rs b/src/protocol/basics/mod.rs index 4e3def06b..a5d0f58a9 100644 --- a/src/protocol/basics/mod.rs +++ b/src/protocol/basics/mod.rs @@ -1,3 +1,4 @@ +pub mod apply_permutation; pub mod check_zero; mod if_else; pub(crate) mod mul; diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index bdf872cc3..1dfe0b1a5 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -8,7 +8,8 @@ use rand::{distributions::Standard, seq::SliceRandom, Rng}; use self::share::{ShuffleShare, ShuffleShareBK, ShuffleShareF, ShuffleShareMK}; use super::super::{ - context::Context, ipa::IPAInputRow, sort::apply::apply as apply_permutation, RecordId, + basics::apply_permutation::apply as apply_permutation, context::Context, ipa::IPAInputRow, + RecordId, }; use crate::{ error::Error, diff --git a/src/protocol/sort/apply_sort/mod.rs b/src/protocol/sort/apply_sort/mod.rs index 10a783c09..131c0db46 100644 --- a/src/protocol/sort/apply_sort/mod.rs +++ b/src/protocol/sort/apply_sort/mod.rs @@ -5,12 +5,9 @@ pub use shuffle::shuffle_shares; use crate::{ error::Error, protocol::{ - basics::Reshare, + basics::{apply_permutation::apply_inv, Reshare}, context::Context, - sort::{ - apply::apply_inv, generate_permutation::RevealedAndRandomPermutations, - ApplyInvStep::ShuffleInputs, - }, + sort::{generate_permutation::RevealedAndRandomPermutations, ApplyInvStep::ShuffleInputs}, RecordId, }, }; diff --git a/src/protocol/sort/apply_sort/shuffle.rs b/src/protocol/sort/apply_sort/shuffle.rs index 9268e40dd..22a354cec 100644 --- a/src/protocol/sort/apply_sort/shuffle.rs +++ b/src/protocol/sort/apply_sort/shuffle.rs @@ -5,10 +5,12 @@ use crate::{ error::Error, helpers::Direction, protocol::{ - basics::Reshare, + basics::{ + apply_permutation::{apply, apply_inv}, + Reshare, + }, context::Context, sort::{ - apply::{apply, apply_inv}, shuffle::{shuffle_for_helper, ShuffleOrUnshuffle}, ShuffleStep::{self, Shuffle1, Shuffle2, Shuffle3}, }, diff --git a/src/protocol/sort/compose.rs b/src/protocol/sort/compose.rs index 2a858af55..0fe7f73d9 100644 --- a/src/protocol/sort/compose.rs +++ b/src/protocol/sort/compose.rs @@ -4,9 +4,9 @@ use crate::{ error::Error, ff::Field, protocol::{ - basics::Reshare, + basics::{apply_permutation::apply, Reshare}, context::Context, - sort::{apply::apply, shuffle::unshuffle_shares, ComposeStep::UnshuffleRho}, + sort::{shuffle::unshuffle_shares, ComposeStep::UnshuffleRho}, RecordId, }, secret_sharing::SecretSharing, @@ -59,11 +59,9 @@ mod tests { use crate::{ ff::{Field, Fp31}, protocol::{ + basics::apply_permutation::apply, context::{Context, SemiHonestContext, UpgradableContext, Validator}, - sort::{ - apply::apply, compose::compose, - generate_permutation::shuffle_and_reveal_permutation, - }, + sort::{compose::compose, generate_permutation::shuffle_and_reveal_permutation}, }, rand::thread_rng, test_fixture::{Reconstruct, Runner, TestWorld}, diff --git a/src/protocol/sort/mod.rs b/src/protocol/sort/mod.rs index 4908a76bf..c285af1b6 100644 --- a/src/protocol/sort/mod.rs +++ b/src/protocol/sort/mod.rs @@ -1,4 +1,3 @@ -pub mod apply; pub mod apply_sort; pub mod bit_permutation; pub mod generate_permutation; diff --git a/src/protocol/sort/secureapplyinv.rs b/src/protocol/sort/secureapplyinv.rs index 8cfb8aca9..ebced0c8f 100644 --- a/src/protocol/sort/secureapplyinv.rs +++ b/src/protocol/sort/secureapplyinv.rs @@ -1,12 +1,9 @@ use crate::{ error::Error, protocol::{ - basics::Reshare, + basics::{apply_permutation::apply_inv, Reshare}, context::Context, - sort::{ - apply::apply_inv, apply_sort::shuffle_shares as shuffle_vectors, - ApplyInvStep::ShuffleInputs, - }, + sort::{apply_sort::shuffle_shares as shuffle_vectors, ApplyInvStep::ShuffleInputs}, RecordId, }, }; @@ -38,9 +35,10 @@ mod tests { use crate::{ ff::{Field, Fp31}, protocol::{ + basics::apply_permutation::apply_inv, context::{Context, SemiHonestContext, UpgradableContext, Validator}, sort::{ - apply::apply_inv, generate_permutation::shuffle_and_reveal_permutation, + generate_permutation::shuffle_and_reveal_permutation, secureapplyinv::secureapplyinv_multi, }, }, diff --git a/src/protocol/sort/shuffle.rs b/src/protocol/sort/shuffle.rs index feff12ee1..d500bf0a2 100644 --- a/src/protocol/sort/shuffle.rs +++ b/src/protocol/sort/shuffle.rs @@ -1,15 +1,20 @@ use embed_doc_image::embed_doc_image; use rand::{seq::SliceRandom, Rng}; -use super::{ - apply::{apply, apply_inv}, - ShuffleStep::{self, Shuffle1, Shuffle2, Shuffle3}, -}; +use super::ShuffleStep::{self, Shuffle1, Shuffle2, Shuffle3}; use crate::{ error::Error, ff::Field, helpers::{Direction, Role}, - protocol::{basics::Reshare, context::Context, step::Step, NoRecord, RecordId}, + protocol::{ + basics::{ + apply_permutation::{apply, apply_inv}, + Reshare, + }, + context::Context, + step::Step, + NoRecord, RecordId, + }, secret_sharing::SecretSharing, }; From 9043ec33d4336d973662e2af992d7e009d70d32a Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Fri, 20 Oct 2023 15:59:38 -0500 Subject: [PATCH 09/15] [oprf][shuffle] Move share splitting and combining logic into runner module This is as prep for making oprf/shuffle protocol generic --- src/protocol/oprf/shuffle/mod.rs | 67 +++++++++----------------------- src/query/runner/oprf_shuffle.rs | 38 ++++++++++++++++-- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index 1dfe0b1a5..c00023ae0 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -31,23 +31,17 @@ pub(crate) enum OPRFShuffleStep { /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle( - ctx: C, - input_rows: &[ShuffleInputRow], +pub async fn shuffle( _config: QueryConfig, -) -> Result, Error> { - let batch_size = u32::try_from(input_rows.len()).map_err(|_e| { - Error::FieldValueTruncation(format!( - "Cannot truncate the number of input rows {} to u32", - input_rows.len(), - )) - })?; - - let shares = ( - split_shares(input_rows, Direction::Left), - split_shares(input_rows, Direction::Right), - ); - + ctx: C, + batch_size: u32, + shares: (Sl, Sr), +) -> Result<(Vec, Vec), Error> +where + C: Context, + Sl: IntoIterator, + Sr: IntoIterator, +{ // 1. Generate permutations let pis = generate_permutations_with_peers(batch_size, &ctx); @@ -68,7 +62,7 @@ async fn run_h1( (a, b): (Sl, Sr), (pi_31, pi_12): (Vec, Vec), (z_31, z_12): (Zl, Zr), -) -> Result, Error> +) -> Result<(Vec, Vec), Error> where C: Context, Sl: IntoIterator, @@ -94,8 +88,7 @@ where send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; - let res = combine_shares(a_hat, b_hat); - Ok(res) + Ok((a_hat, b_hat)) } async fn run_h2( @@ -104,7 +97,7 @@ async fn run_h2( (_b, c): (Sl, Sr), (pi_12, pi_23): (Vec, Vec), (z_12, z_23): (Zl, Zr), -) -> Result, Error> +) -> Result<(Vec, Vec), Error> where C: Context, Sl: IntoIterator, @@ -152,9 +145,8 @@ where ) .await?; - let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); - let res = combine_shares(b_hat, c_hat); - Ok(res) + let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()).collect(); + Ok((b_hat, c_hat)) } async fn run_h3( @@ -163,7 +155,7 @@ async fn run_h3( (_c, _a): (Sl, Sr), (pi_23, pi_31): (Vec, Vec), (z_23, z_31): (Zl, Zr), -) -> Result, Error> +) -> Result<(Vec, Vec), Error> where C: Context, Sl: IntoIterator, @@ -208,30 +200,8 @@ where ) .await?; - let c_hat = add_single_shares(c_hat_1, c_hat_2); - let res = combine_shares(c_hat, a_hat); - Ok(res) -} - -// --------------------------------------------------------------------------- // - -fn split_shares( - input_rows: &[ShuffleInputRow], - direction: Direction, -) -> impl Iterator + '_ { - let f = move |input_row| ShuffleShare::from_input_row(input_row, direction); - input_rows.iter().map(f) -} - -fn combine_shares(l: L, r: R) -> Vec -where - L: IntoIterator, - R: IntoIterator, -{ - l.into_iter() - .zip(r) - .map(|(l, r)| l.to_input_row(r)) - .collect::>() + let c_hat = add_single_shares(c_hat_1, c_hat_2).collect(); + Ok((c_hat, a_hat)) } fn add_single_shares<'i, T, L, R>(l: L, r: R) -> impl Iterator + 'i @@ -242,7 +212,6 @@ where { l.into_iter().zip(r).map(|(a, b)| a + b) } - // --------------------------------------------------------------------------- // fn generate_random_tables_with_peers( diff --git a/src/query/runner/oprf_shuffle.rs b/src/query/runner/oprf_shuffle.rs index 4d347a6ab..369aefe71 100644 --- a/src/query/runner/oprf_shuffle.rs +++ b/src/query/runner/oprf_shuffle.rs @@ -4,12 +4,12 @@ use crate::{ error::Error, helpers::{ query::{oprf_shuffle, QuerySize}, - BodyStream, RecordsStream, + BodyStream, Direction, RecordsStream, }, one_off_fns::assert_stream_send, protocol::{ context::Context, - oprf::shuffle::{shuffle, ShuffleInputRow}, + oprf::shuffle::{share::ShuffleShare, shuffle, ShuffleInputRow}, }, }; @@ -34,6 +34,38 @@ impl OPRFShuffleQuery { .try_concat() .await?; - shuffle(ctx, input.as_slice(), self.config).await + let batch_size = u32::try_from(input.len()).map_err(|_e| { + Error::FieldValueTruncation(format!( + "Cannot truncate the number of input rows {} to u32", + input.len(), + )) + })?; + + let shares = ( + split_shares(input.as_slice(), Direction::Left), + split_shares(input.as_slice(), Direction::Right), + ); + + let (res_l, res_r) = shuffle(self.config, ctx, batch_size, shares).await?; + Ok(combine_shares(res_l, res_r)) } } + +fn split_shares( + input_rows: &[ShuffleInputRow], + direction: Direction, +) -> impl Iterator + '_ { + let f = move |input_row| ShuffleShare::from_input_row(input_row, direction); + input_rows.iter().map(f) +} + +fn combine_shares(l: L, r: R) -> Vec +where + L: IntoIterator, + R: IntoIterator, +{ + l.into_iter() + .zip(r) + .map(|(l, r)| l.to_input_row(r)) + .collect::>() +} From adc46da4e5339fc3080e1d57f90410c98f5d751a Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Fri, 20 Oct 2023 17:49:42 -0500 Subject: [PATCH 10/15] [oprf][shuffle] Generalized shuffle function across any inputs satisfying some trait bounds the shuffle, run_hN functions do not depend on a particular format of input rows. Instead they can receiv IntoIterators consuming anything satisfying a few bounds I have also moved the logic of dealing with specific inputs (deserialize, serialize, split into individual shares, combine back) into a submodule of queiry::runner::oprf, as I expect all this code to be thrown away/reworked, when the actual formats of inputs/outputs are more clear. But for now this change should let me focus on writing unit tests for the protocol, that should remain relevant even if the data input/output formats change --- src/protocol/oprf/shuffle/mod.rs | 159 ++++++++++-------- src/query/runner/mod.rs | 2 +- src/query/runner/oprf_shuffle/mod.rs | 6 + .../query.rs} | 12 +- .../runner/oprf_shuffle}/share.rs | 7 + 5 files changed, 107 insertions(+), 79 deletions(-) create mode 100644 src/query/runner/oprf_shuffle/mod.rs rename src/query/runner/{oprf_shuffle.rs => oprf_shuffle/query.rs} (86%) rename src/{protocol/oprf/shuffle => query/runner/oprf_shuffle}/share.rs (96%) diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index c00023ae0..29faff337 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -1,23 +1,17 @@ -pub mod share; - use std::ops::Add; use futures::future; use ipa_macros::Step; -use rand::{distributions::Standard, seq::SliceRandom, Rng}; +use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom, Rng}; -use self::share::{ShuffleShare, ShuffleShareBK, ShuffleShareF, ShuffleShareMK}; use super::super::{ - basics::apply_permutation::apply as apply_permutation, context::Context, ipa::IPAInputRow, - RecordId, + basics::apply_permutation::apply as apply_permutation, context::Context, RecordId, }; use crate::{ error::Error, - helpers::{query::oprf_shuffle::QueryConfig, Direction, ReceivingEnd, Role}, + helpers::{Direction, Message, ReceivingEnd, Role}, }; -pub type ShuffleInputRow = IPAInputRow; - #[derive(Step)] pub(crate) enum OPRFShuffleStep { GenerateAHat, @@ -31,16 +25,19 @@ pub(crate) enum OPRFShuffleStep { /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle( - _config: QueryConfig, +pub async fn shuffle( ctx: C, batch_size: u32, shares: (Sl, Sr), -) -> Result<(Vec, Vec), Error> +) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + S: Clone + Add + Message, + for<'b> &'b S: Add, + for<'b> &'b S: Add<&'b S, Output = S>, + Standard: Distribution, { // 1. Generate permutations let pis = generate_permutations_with_peers(batch_size, &ctx); @@ -56,19 +53,22 @@ where } } -async fn run_h1( +async fn run_h1( ctx: &C, batch_size: u32, (a, b): (Sl, Sr), (pi_31, pi_12): (Vec, Vec), (z_31, z_12): (Zl, Zr), -) -> Result<(Vec, Vec), Error> +) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, + S: Clone + Add + Message, + for<'a> &'a S: Add, + Standard: Distribution, { // 1. Generate helper-specific random tables let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); @@ -80,10 +80,10 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); // 2. Run computations - let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); + let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); apply_permutation(&pi_12, &mut x_1); - let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); + let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); apply_permutation(&pi_31, &mut x_2); send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; @@ -91,30 +91,34 @@ where Ok((a_hat, b_hat)) } -async fn run_h2( +async fn run_h2( ctx: &C, batch_size: u32, (_b, c): (Sl, Sr), (pi_12, pi_23): (Vec, Vec), (z_12, z_23): (Zl, Zr), -) -> Result<(Vec, Vec), Error> +) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, + S: Clone + Add + Message, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + Standard: Distribution, { // 1. Generate helper-specific random tables let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); - let b_hat: Vec<_> = + let b_hat: Vec = generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); // 2. Run computations - let mut y_1: Vec = add_single_shares(c, z_12.into_iter()).collect(); + let mut y_1: Vec = add_single_shares(c, z_12).collect(); apply_permutation(&pi_12, &mut y_1); - let ((), x_2) = future::try_join( + let ((), x_2): ((), Vec) = future::try_join( send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), receive_from_peer( ctx, @@ -125,10 +129,10 @@ where ) .await?; - let mut x_3: Vec<_> = add_single_shares(x_2.into_iter(), z_23.into_iter()).collect(); + let mut x_3: Vec = add_single_shares(x_2.iter(), z_23).collect(); apply_permutation(&pi_23, &mut x_3); - let c_hat_1: Vec<_> = add_single_shares(x_3.iter(), b_hat.iter()).collect(); + let c_hat_1: Vec = add_single_shares(x_3.iter(), b_hat.iter()).collect(); let ((), c_hat_2) = future::try_join( send_to_peer( ctx, @@ -149,27 +153,31 @@ where Ok((b_hat, c_hat)) } -async fn run_h3( +async fn run_h3( ctx: &C, batch_size: u32, (_c, _a): (Sl, Sr), (pi_23, pi_31): (Vec, Vec), (z_23, z_31): (Zl, Zr), -) -> Result<(Vec, Vec), Error> +) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - Zl: IntoIterator, - Zr: IntoIterator, + Sl: IntoIterator, + Sr: IntoIterator, + Zl: IntoIterator, + Zr: IntoIterator, + S: Clone + Add + Message, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, + Standard: Distribution, { // 1. Generate helper-specific random tables let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); - let a_hat: Vec<_> = + let a_hat: Vec = generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Right).collect(); // 2. Run computations - let y_1 = receive_from_peer( + let y_1: Vec = receive_from_peer( ctx, &OPRFShuffleStep::TransferY1, Direction::Left, @@ -177,14 +185,14 @@ where ) .await?; - let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); + let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); apply_permutation(&pi_31, &mut y_2); - let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); + let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); apply_permutation(&pi_23, &mut y_3); - let c_hat_2 = add_single_shares(y_3.iter(), a_hat.iter()).collect::>(); - let ((), c_hat_1) = future::try_join( + let c_hat_2: Vec = add_single_shares(y_3.iter(), a_hat.iter()).collect(); + let ((), c_hat_1): ((), Vec) = future::try_join( send_to_peer( ctx, &OPRFShuffleStep::TransferCHat, @@ -204,36 +212,40 @@ where Ok((c_hat, a_hat)) } -fn add_single_shares<'i, T, L, R>(l: L, r: R) -> impl Iterator + 'i +fn add_single_shares(l: L, r: R) -> impl Iterator where - T: Add, - L: IntoIterator + 'i, - R: IntoIterator + 'i, + A: Add, + L: IntoIterator, + R: IntoIterator, { l.into_iter().zip(r).map(|(a, b)| a + b) } // --------------------------------------------------------------------------- // -fn generate_random_tables_with_peers( +fn generate_random_tables_with_peers<'a, C, S>( batch_size: u32, - narrow_ctx: &C, -) -> ( - impl Iterator + '_, - impl Iterator + '_, -) { + narrow_ctx: &'a C, +) -> (impl Iterator + 'a, impl Iterator + 'a) +where + C: Context, + Standard: Distribution, + S: 'a, +{ let (rng_l, rng_r) = narrow_ctx.prss_rng(); - let with_left = sample_iter(rng_l).take(batch_size as usize); - let with_right = sample_iter(rng_r).take(batch_size as usize); + let with_left = rng_l.sample_iter(Standard).take(batch_size as usize); + let with_right = rng_r.sample_iter(Standard).take(batch_size as usize); (with_left, with_right) } -fn generate_random_table_solo( +fn generate_random_table_solo<'a, C, S>( batch_size: u32, - narrow_ctx: &C, + narrow_ctx: &'a C, peer: Direction, -) -> impl Iterator + '_ +) -> impl Iterator + 'a where C: Context, + Standard: Distribution, + S: 'a, { let rngs = narrow_ctx.prss_rng(); let rng = match peer { @@ -241,21 +253,22 @@ where Direction::Right => rngs.1, }; - sample_iter(rng).take(batch_size as usize) -} - -fn sample_iter(rng: R) -> impl Iterator { - rng.sample_iter(Standard) + rng.sample_iter(Standard).take(batch_size as usize) } // ---------------------------- helper communication ------------------------------------ // -async fn send_to_peer>( +async fn send_to_peer( ctx: &C, step: &OPRFShuffleStep, direction: Direction, items: I, -) -> Result<(), Error> { +) -> Result<(), Error> +where + C: Context, + I: IntoIterator, + S: Message, +{ let role = ctx.role().peer(direction); let send_channel = ctx.narrow(step).send_channel(role); for (record_id, row) in items.into_iter().enumerate() { @@ -264,16 +277,20 @@ async fn send_to_peer>( Ok(()) } -async fn receive_from_peer( +async fn receive_from_peer( ctx: &C, step: &OPRFShuffleStep, direction: Direction, batch_size: u32, -) -> Result, Error> { +) -> Result, Error> +where + C: Context, + S: Message, +{ let role = ctx.role().peer(direction); - let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); + let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); - let mut output: Vec = Vec::with_capacity(batch_size as usize); + let mut output: Vec = Vec::with_capacity(batch_size as usize); for record_id in 0..batch_size { let msg = receive_channel.receive(RecordId::from(record_id)).await?; output.push(msg); diff --git a/src/query/runner/mod.rs b/src/query/runner/mod.rs index 3ec724fa5..4c83930aa 100644 --- a/src/query/runner/mod.rs +++ b/src/query/runner/mod.rs @@ -9,7 +9,7 @@ mod test_multiply; pub(super) use test_multiply::execute_test_multiply; pub(super) use self::{ - aggregate::SparseAggregateQuery, ipa::IpaQuery, oprf_shuffle::OPRFShuffleQuery, + aggregate::SparseAggregateQuery, ipa::IpaQuery, oprf_shuffle::query::OPRFShuffleQuery, }; use crate::{error::Error, query::ProtocolResult}; diff --git a/src/query/runner/oprf_shuffle/mod.rs b/src/query/runner/oprf_shuffle/mod.rs new file mode 100644 index 000000000..5a8705b3c --- /dev/null +++ b/src/query/runner/oprf_shuffle/mod.rs @@ -0,0 +1,6 @@ +pub(super) mod query; +mod share; + +use self::share::{ShuffleShareBK, ShuffleShareF, ShuffleShareMK}; +use crate::protocol::ipa::IPAInputRow; +pub type ShuffleInputRow = IPAInputRow; diff --git a/src/query/runner/oprf_shuffle.rs b/src/query/runner/oprf_shuffle/query.rs similarity index 86% rename from src/query/runner/oprf_shuffle.rs rename to src/query/runner/oprf_shuffle/query.rs index 369aefe71..cdcac3a31 100644 --- a/src/query/runner/oprf_shuffle.rs +++ b/src/query/runner/oprf_shuffle/query.rs @@ -1,5 +1,6 @@ use futures::TryStreamExt; +use super::{share::ShuffleShare, ShuffleInputRow}; use crate::{ error::Error, helpers::{ @@ -7,19 +8,16 @@ use crate::{ BodyStream, Direction, RecordsStream, }, one_off_fns::assert_stream_send, - protocol::{ - context::Context, - oprf::shuffle::{share::ShuffleShare, shuffle, ShuffleInputRow}, - }, + protocol::{context::Context, oprf::shuffle::shuffle}, }; pub struct OPRFShuffleQuery { - config: oprf_shuffle::QueryConfig, + _config: oprf_shuffle::QueryConfig, } impl OPRFShuffleQuery { pub fn new(config: oprf_shuffle::QueryConfig) -> Self { - Self { config } + Self { _config: config } } #[tracing::instrument("ipa_query", skip_all, fields(sz=%query_size))] @@ -46,7 +44,7 @@ impl OPRFShuffleQuery { split_shares(input.as_slice(), Direction::Right), ); - let (res_l, res_r) = shuffle(self.config, ctx, batch_size, shares).await?; + let (res_l, res_r) = shuffle(ctx, batch_size, shares).await?; Ok(combine_shares(res_l, res_r)) } } diff --git a/src/protocol/oprf/shuffle/share.rs b/src/query/runner/oprf_shuffle/share.rs similarity index 96% rename from src/protocol/oprf/shuffle/share.rs rename to src/query/runner/oprf_shuffle/share.rs index a38e2fd2e..168691a40 100644 --- a/src/protocol/oprf/shuffle/share.rs +++ b/src/query/runner/oprf_shuffle/share.rs @@ -90,6 +90,13 @@ impl<'a, 'b> Add<&'b ShuffleShare> for &'a ShuffleShare { } } +impl<'a> Add for &'a ShuffleShare { + type Output = ShuffleShare; + + fn add(self, rhs: ShuffleShare) -> Self::Output { + Add::add(self, &rhs) + } +} impl Serializable for ShuffleShare { type Size = <::Size as Add< <::Size as Add< From 38ca599bfc7f46b5ced38e57f58052c1d65c17ee Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Mon, 23 Oct 2023 13:23:42 -0500 Subject: [PATCH 11/15] [oprf][shuffle] Reimplemented the input to shuffle function as an interator of ReplicatedSecretShares Instead of accepting a tuple of 2 iterators returning SharedValues accept a single iterator of AdditiveShares: ReplicatedSecretShares This should make interaction with TestWorld easier --- src/protocol/oprf/shuffle/mod.rs | 47 +++++----- src/query/runner/oprf_shuffle/query.rs | 17 ++-- src/query/runner/oprf_shuffle/share.rs | 94 ++++++++++++++++++- .../replicated/semi_honest/additive_share.rs | 1 - 4 files changed, 121 insertions(+), 38 deletions(-) diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index 29faff337..536cb7d14 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -10,6 +10,10 @@ use super::super::{ use crate::{ error::Error, helpers::{Direction, Message, ReceivingEnd, Role}, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + SharedValue, + }, }; #[derive(Step)] @@ -25,16 +29,11 @@ pub(crate) enum OPRFShuffleStep { /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle( - ctx: C, - batch_size: u32, - shares: (Sl, Sr), -) -> Result<(Vec, Vec), Error> +pub async fn shuffle(ctx: C, batch_size: u32, shares: I) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, - S: Clone + Add + Message, + I: IntoIterator>, + S: SharedValue + Add + Message, for<'b> &'b S: Add, for<'b> &'b S: Add<&'b S, Output = S>, Standard: Distribution, @@ -49,24 +48,23 @@ where match ctx.role() { Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await, Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await, - Role::H3 => run_h3(&ctx, batch_size, shares, pis, zs).await, + Role::H3 => run_h3(&ctx, batch_size, pis, zs).await, } } -async fn run_h1( +async fn run_h1( ctx: &C, batch_size: u32, - (a, b): (Sl, Sr), + shares: I, (pi_31, pi_12): (Vec, Vec), (z_31, z_12): (Zl, Zr), ) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + I: IntoIterator>, + S: SharedValue + Add + Message, Zl: IntoIterator, Zr: IntoIterator, - S: Clone + Add + Message, for<'a> &'a S: Add, Standard: Distribution, { @@ -80,7 +78,10 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); // 2. Run computations - let mut x_1: Vec = add_single_shares(add_single_shares(a, b), z_12).collect(); + let a_add_b_iter = shares + .into_iter() + .map(|s: AdditiveShare| s.left().add(s.right())); + let mut x_1: Vec = add_single_shares(a_add_b_iter, z_12).collect(); apply_permutation(&pi_12, &mut x_1); let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); @@ -91,20 +92,19 @@ where Ok((a_hat, b_hat)) } -async fn run_h2( +async fn run_h2( ctx: &C, batch_size: u32, - (_b, c): (Sl, Sr), + shares: I, (pi_12, pi_23): (Vec, Vec), (z_12, z_23): (Zl, Zr), ) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + I: IntoIterator>, + S: SharedValue + Add + Message, Zl: IntoIterator, Zr: IntoIterator, - S: Clone + Add + Message, for<'a> &'a S: Add, for<'a> &'a S: Add<&'a S, Output = S>, Standard: Distribution, @@ -115,6 +115,7 @@ where generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Left).collect(); // 2. Run computations + let c = shares.into_iter().map(|s| s.right()); let mut y_1: Vec = add_single_shares(c, z_12).collect(); apply_permutation(&pi_12, &mut y_1); @@ -153,17 +154,15 @@ where Ok((b_hat, c_hat)) } -async fn run_h3( +async fn run_h3( ctx: &C, batch_size: u32, - (_c, _a): (Sl, Sr), (pi_23, pi_31): (Vec, Vec), (z_23, z_31): (Zl, Zr), ) -> Result<(Vec, Vec), Error> where C: Context, - Sl: IntoIterator, - Sr: IntoIterator, + S: SharedValue + Add + Message, Zl: IntoIterator, Zr: IntoIterator, S: Clone + Add + Message, diff --git a/src/query/runner/oprf_shuffle/query.rs b/src/query/runner/oprf_shuffle/query.rs index cdcac3a31..21365fc0b 100644 --- a/src/query/runner/oprf_shuffle/query.rs +++ b/src/query/runner/oprf_shuffle/query.rs @@ -9,6 +9,7 @@ use crate::{ }, one_off_fns::assert_stream_send, protocol::{context::Context, oprf::shuffle::shuffle}, + secret_sharing::replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, }; pub struct OPRFShuffleQuery { @@ -39,11 +40,7 @@ impl OPRFShuffleQuery { )) })?; - let shares = ( - split_shares(input.as_slice(), Direction::Left), - split_shares(input.as_slice(), Direction::Right), - ); - + let shares = split_shares(&input); let (res_l, res_r) = shuffle(ctx, batch_size, shares).await?; Ok(combine_shares(res_l, res_r)) } @@ -51,9 +48,13 @@ impl OPRFShuffleQuery { fn split_shares( input_rows: &[ShuffleInputRow], - direction: Direction, -) -> impl Iterator + '_ { - let f = move |input_row| ShuffleShare::from_input_row(input_row, direction); +) -> impl Iterator> + '_ { + let f = move |input_row| { + let l = ShuffleShare::from_input_row(input_row, Direction::Left); + let r = ShuffleShare::from_input_row(input_row, Direction::Right); + ReplicatedSecretSharing::new(l, r) + }; + input_rows.iter().map(f) } diff --git a/src/query/runner/oprf_shuffle/share.rs b/src/query/runner/oprf_shuffle/share.rs index 168691a40..8d1edac2b 100644 --- a/src/query/runner/oprf_shuffle/share.rs +++ b/src/query/runner/oprf_shuffle/share.rs @@ -1,4 +1,4 @@ -use std::ops::Add; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use generic_array::GenericArray; use rand::{distributions::Standard, Rng}; @@ -7,13 +7,14 @@ use typenum::Unsigned; use super::ShuffleInputRow; use crate::{ ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, - helpers::{Direction, Message}, + helpers::Direction, + secret_sharing::SharedValue, }; pub type ShuffleShareMK = Gf40Bit; pub type ShuffleShareBK = Gf8Bit; pub type ShuffleShareF = Gf32Bit; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ShuffleShare { pub timestamp: ShuffleShareF, pub mk: ShuffleShareMK, @@ -22,6 +23,91 @@ pub struct ShuffleShare { pub trigger_value: ShuffleShareF, } +impl AddAssign for ShuffleShare { + fn add_assign(&mut self, rhs: Self) { + self.timestamp += rhs.timestamp; + self.mk += rhs.mk; + self.is_trigger_bit += rhs.is_trigger_bit; + self.breakdown_key += rhs.breakdown_key; + self.trigger_value += rhs.trigger_value; + } +} +impl Sub for ShuffleShare { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self::Output { + timestamp: self.timestamp - rhs.timestamp, + mk: self.mk - rhs.mk, + is_trigger_bit: self.is_trigger_bit - rhs.is_trigger_bit, + breakdown_key: self.breakdown_key - rhs.breakdown_key, + trigger_value: self.trigger_value - rhs.trigger_value, + } + } +} +impl SubAssign for ShuffleShare { + fn sub_assign(&mut self, rhs: Self) { + self.timestamp -= rhs.timestamp; + self.mk -= rhs.mk; + self.is_trigger_bit -= rhs.is_trigger_bit; + self.breakdown_key -= rhs.breakdown_key; + self.trigger_value -= rhs.trigger_value; + } +} +impl Neg for ShuffleShare { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::Output { + timestamp: self.timestamp.neg(), + mk: self.mk.neg(), + is_trigger_bit: self.is_trigger_bit.neg(), + breakdown_key: self.breakdown_key.neg(), + trigger_value: self.trigger_value.neg(), + } + } +} +impl Mul for ShuffleShare { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self::Output { + timestamp: self.timestamp * rhs.timestamp, + mk: self.mk * rhs.mk, + is_trigger_bit: self.is_trigger_bit * rhs.is_trigger_bit, + breakdown_key: self.breakdown_key * rhs.breakdown_key, + trigger_value: self.trigger_value * rhs.trigger_value, + } + } +} +impl MulAssign for ShuffleShare { + fn mul_assign(&mut self, rhs: Self) { + self.timestamp *= rhs.timestamp; + self.mk *= rhs.mk; + self.is_trigger_bit *= rhs.is_trigger_bit; + self.breakdown_key *= rhs.breakdown_key; + self.trigger_value *= rhs.trigger_value; + } +} + +impl SharedValue for ShuffleShare { + type Storage = ::Storage; + + const BITS: u32 = ShuffleShareF::BITS + + ShuffleShareMK::BITS + + ShuffleShareF::BITS + + ShuffleShareBK::BITS + + ShuffleShareF::BITS; + + const ZERO: Self = Self { + timestamp: ShuffleShareF::ZERO, + mk: ShuffleShareMK::ZERO, + is_trigger_bit: ShuffleShareF::ZERO, + breakdown_key: ShuffleShareBK::ZERO, + trigger_value: ShuffleShareF::ZERO, + }; +} + impl ShuffleShare { #[must_use] pub fn from_input_row(input_row: &ShuffleInputRow, shared_with: Direction) -> Self { @@ -153,5 +239,3 @@ impl Serializable for ShuffleShare { } } } - -impl Message for ShuffleShare {} diff --git a/src/secret_sharing/replicated/semi_honest/additive_share.rs b/src/secret_sharing/replicated/semi_honest/additive_share.rs index 10b8b2b39..b1eb2a850 100644 --- a/src/secret_sharing/replicated/semi_honest/additive_share.rs +++ b/src/secret_sharing/replicated/semi_honest/additive_share.rs @@ -20,7 +20,6 @@ pub struct AdditiveShare(V, V); impl SecretSharing for AdditiveShare { const ZERO: Self = AdditiveShare::ZERO; } - impl LinearSecretSharing for AdditiveShare {} impl Debug for AdditiveShare { From 2ea0ca1a3a33048b888cac61edee179eb5e13a7a Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Mon, 23 Oct 2023 14:43:33 -0500 Subject: [PATCH 12/15] [oprf][shuffle] Tests --- src/protocol/oprf/shuffle/mod.rs | 65 +++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index 536cb7d14..c5da6fd97 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -166,7 +166,7 @@ where Zl: IntoIterator, Zr: IntoIterator, S: Clone + Add + Message, - for<'a> &'a S: Add, + // for<'a> &'a S: Add, for<'a> &'a S: Add<&'a S, Output = S>, Standard: Distribution, { @@ -314,3 +314,66 @@ fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Ve permutation.shuffle(rng); permutation } + +#[cfg(all(test, any(unit_test, feature = "shuttle")))] +pub mod tests { + use std::ops::Add; + + use crate::secret_sharing::replicated::semi_honest::AdditiveShare; + use crate::secret_sharing::replicated::ReplicatedSecretSharing; + use crate::test_executor::run; + use crate::{ + ff::{Field, Gf40Bit}, + test_fixture::{Reconstruct, Runner, TestWorld}, + }; + + use super::shuffle; + + pub type MatchKey = Gf40Bit; + + impl Add<&MatchKey> for &MatchKey { + type Output = MatchKey; + + fn add(self, rhs: &MatchKey) -> Self::Output { + Add::add(*self, *rhs) + } + } + + impl Add for &MatchKey { + type Output = MatchKey; + + fn add(self, rhs: MatchKey) -> Self::Output { + Add::add(*self, rhs) + } + } + + #[test] + fn added_random_tables_cancel_out() { + run(|| async { + let records = vec![MatchKey::truncate_from(12345 as u128) as MatchKey]; + let expected = records[0].clone(); + + let world = TestWorld::default(); + + let result = world + .semi_honest(records.into_iter(), |ctx, shares| async move { + shuffle(ctx, 1, shares).await.unwrap() + }) + .await; + + let result = result + .into_iter() + .map(|(l, r)| { + l.into_iter() + .zip(r) + .map(|(li, ri)| AdditiveShare::new(li, ri)) + .collect::>() + }) + .collect::>(); + + let result: [Vec<_>; 3] = result.try_into().unwrap(); + let actual = result.reconstruct()[0]; + assert_eq!(actual, expected); + }); + } +} From 1d9138d259b6d157546142c9f469c97433425c1d Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Mon, 23 Oct 2023 16:37:49 -0500 Subject: [PATCH 13/15] [oprf][shuffle] Represent shuffle response as Vec> This makes the code slighlty simpler --- src/helpers/transport/query/oprf_shuffle.rs | 16 +-- src/net/http_serde.rs | 4 +- src/protocol/oprf/shuffle/mod.rs | 150 ++++++++++++-------- src/query/runner/oprf_shuffle/query.rs | 26 ++-- src/query/runner/oprf_shuffle/share.rs | 17 ++- 5 files changed, 112 insertions(+), 101 deletions(-) diff --git a/src/helpers/transport/query/oprf_shuffle.rs b/src/helpers/transport/query/oprf_shuffle.rs index 86a23884f..c39242902 100644 --- a/src/helpers/transport/query/oprf_shuffle.rs +++ b/src/helpers/transport/query/oprf_shuffle.rs @@ -1,17 +1,5 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] #[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))] -pub struct QueryConfig { - pub bk_size: u8, // breakdown key size bits - pub tv_size: u8, // trigger value size bits -} - -impl Default for QueryConfig { - fn default() -> Self { - Self { - bk_size: 40, - tv_size: 40, - } - } -} +pub struct QueryConfig {} diff --git a/src/net/http_serde.rs b/src/net/http_serde.rs index eaa86a65e..4e8a4fddb 100644 --- a/src/net/http_serde.rs +++ b/src/net/http_serde.rs @@ -192,8 +192,8 @@ pub mod query { Ok(()) } - QueryType::OPRFShuffle(config) => { - write!(f, "&bk_size={}&tv_size={}", config.bk_size, config.tv_size)?; + QueryType::OPRFShuffle(_config) => { + write!(f, "")?; Ok(()) } } diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index c5da6fd97..e2d7fea0f 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -29,7 +29,11 @@ pub(crate) enum OPRFShuffleStep { /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn shuffle(ctx: C, batch_size: u32, shares: I) -> Result<(Vec, Vec), Error> +pub async fn shuffle( + ctx: C, + batch_size: usize, + shares: I, +) -> Result>, Error> where C: Context, I: IntoIterator>, @@ -39,7 +43,12 @@ where Standard: Distribution, { // 1. Generate permutations - let pis = generate_permutations_with_peers(batch_size, &ctx); + let permutation_size: u32 = batch_size.try_into().map_err(|e| { + Error::FieldValueTruncation(format!( + "batch size {batch_size} does not fit into u32. Error={e:?}" + )) + })?; + let pis = generate_permutations_with_peers(permutation_size, &ctx); // 2. Generate random tables used by all helpers let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); @@ -54,11 +63,11 @@ where async fn run_h1( ctx: &C, - batch_size: u32, + batch_size: usize, shares: I, (pi_31, pi_12): (Vec, Vec), (z_31, z_12): (Zl, Zr), -) -> Result<(Vec, Vec), Error> +) -> Result>, Error> where C: Context, I: IntoIterator>, @@ -70,12 +79,10 @@ where { // 1. Generate helper-specific random tables let ctx_a_hat = ctx.narrow(&OPRFShuffleStep::GenerateAHat); - let a_hat: Vec<_> = - generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Left).collect(); + let a_hat = generate_random_table_solo(batch_size, &ctx_a_hat, Direction::Left); let ctx_b_hat = ctx.narrow(&OPRFShuffleStep::GenerateBHat); - let b_hat: Vec<_> = - generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right).collect(); + let b_hat = generate_random_table_solo(batch_size, &ctx_b_hat, Direction::Right); // 2. Run computations let a_add_b_iter = shares @@ -89,16 +96,17 @@ where send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; - Ok((a_hat, b_hat)) + let res = combine_single_shares(a_hat, b_hat).collect::>(); + Ok(res) } async fn run_h2( ctx: &C, - batch_size: u32, + batch_size: usize, shares: I, (pi_12, pi_23): (Vec, Vec), (z_12, z_23): (Zl, Zr), -) -> Result<(Vec, Vec), Error> +) -> Result>, Error> where C: Context, I: IntoIterator>, @@ -150,16 +158,17 @@ where ) .await?; - let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()).collect(); - Ok((b_hat, c_hat)) + let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); + let res = combine_single_shares(b_hat, c_hat).collect::>(); + Ok(res) } async fn run_h3( ctx: &C, - batch_size: u32, + batch_size: usize, (pi_23, pi_31): (Vec, Vec), (z_23, z_31): (Zl, Zr), -) -> Result<(Vec, Vec), Error> +) -> Result>, Error> where C: Context, S: SharedValue + Add + Message, @@ -207,8 +216,9 @@ where ) .await?; - let c_hat = add_single_shares(c_hat_1, c_hat_2).collect(); - Ok((c_hat, a_hat)) + let c_hat = add_single_shares(c_hat_1, c_hat_2); + let res = combine_single_shares(c_hat, a_hat).collect::>(); + Ok(res) } fn add_single_shares(l: L, r: R) -> impl Iterator @@ -221,8 +231,19 @@ where } // --------------------------------------------------------------------------- // +fn combine_single_shares(l: Il, r: Ir) -> impl Iterator> +where + S: SharedValue, + Il: IntoIterator, + Ir: IntoIterator, +{ + l.into_iter() + .zip(r) + .map(|(li, ri)| AdditiveShare::new(li, ri)) +} + fn generate_random_tables_with_peers<'a, C, S>( - batch_size: u32, + batch_size: usize, narrow_ctx: &'a C, ) -> (impl Iterator + 'a, impl Iterator + 'a) where @@ -231,13 +252,13 @@ where S: 'a, { let (rng_l, rng_r) = narrow_ctx.prss_rng(); - let with_left = rng_l.sample_iter(Standard).take(batch_size as usize); - let with_right = rng_r.sample_iter(Standard).take(batch_size as usize); + let with_left = rng_l.sample_iter(Standard).take(batch_size); + let with_right = rng_r.sample_iter(Standard).take(batch_size); (with_left, with_right) } fn generate_random_table_solo<'a, C, S>( - batch_size: u32, + batch_size: usize, narrow_ctx: &'a C, peer: Direction, ) -> impl Iterator + 'a @@ -252,24 +273,27 @@ where Direction::Right => rngs.1, }; - rng.sample_iter(Standard).take(batch_size as usize) + rng.sample_iter(Standard).take(batch_size) } // ---------------------------- helper communication ------------------------------------ // -async fn send_to_peer( +async fn send_to_peer( ctx: &C, step: &OPRFShuffleStep, direction: Direction, - items: I, + items: Vec, ) -> Result<(), Error> where C: Context, - I: IntoIterator, S: Message, { let role = ctx.role().peer(direction); - let send_channel = ctx.narrow(step).send_channel(role); + let send_channel = ctx + .narrow(step) + .set_total_records(items.len()) + .send_channel(role); + for (record_id, row) in items.into_iter().enumerate() { send_channel.send(RecordId::from(record_id), row).await?; } @@ -280,16 +304,19 @@ async fn receive_from_peer( ctx: &C, step: &OPRFShuffleStep, direction: Direction, - batch_size: u32, + batch_size: usize, ) -> Result, Error> where C: Context, S: Message, { let role = ctx.role().peer(direction); - let receive_channel: ReceivingEnd = ctx.narrow(step).recv_channel(role); + let receive_channel: ReceivingEnd = ctx + .narrow(step) + .set_total_records(batch_size) + .recv_channel(role); - let mut output: Vec = Vec::with_capacity(batch_size as usize); + let mut output: Vec = Vec::with_capacity(batch_size); for record_id in 0..batch_size { let msg = receive_channel.receive(RecordId::from(record_id)).await?; output.push(msg); @@ -300,17 +327,17 @@ where // ------------------ Pseudorandom permutations functions -------------------- // -fn generate_permutations_with_peers(batch_size: u32, ctx: &C) -> (Vec, Vec) { +fn generate_permutations_with_peers(size: u32, ctx: &C) -> (Vec, Vec) { let narrow_context = ctx.narrow(&OPRFShuffleStep::GeneratePi); let mut rng = narrow_context.prss_rng(); - let with_left = generate_pseudorandom_permutation(batch_size, &mut rng.0); - let with_right = generate_pseudorandom_permutation(batch_size, &mut rng.1); + let with_left = generate_pseudorandom_permutation(size, &mut rng.0); + let with_right = generate_pseudorandom_permutation(size, &mut rng.1); (with_left, with_right) } -fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Vec { - let mut permutation = (0..batch_size).collect::>(); +fn generate_pseudorandom_permutation(size: u32, rng: &mut R) -> Vec { + let mut permutation = (0..size).collect::>(); permutation.shuffle(rng); permutation } @@ -319,16 +346,13 @@ fn generate_pseudorandom_permutation(batch_size: u32, rng: &mut R) -> Ve pub mod tests { use std::ops::Add; - use crate::secret_sharing::replicated::semi_honest::AdditiveShare; - use crate::secret_sharing::replicated::ReplicatedSecretSharing; - use crate::test_executor::run; + use super::shuffle; use crate::{ ff::{Field, Gf40Bit}, + test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld}, }; - use super::shuffle; - pub type MatchKey = Gf40Bit; impl Add<&MatchKey> for &MatchKey { @@ -348,32 +372,36 @@ pub mod tests { } #[test] - fn added_random_tables_cancel_out() { + fn shuffles_the_order() { run(|| async { - let records = vec![MatchKey::truncate_from(12345 as u128) as MatchKey]; - let expected = records[0].clone(); + let mut i: u128 = 0; + let mut records = std::iter::from_fn(move || { + i += 1; + Some(MatchKey::truncate_from(i) as MatchKey) + }) + .take(100) + .collect::>(); + + records.sort(); + + let mut actual = TestWorld::default() + .semi_honest(records.clone().into_iter(), |ctx, shares| async move { + shuffle(ctx, shares.len(), shares).await.unwrap() + }) + .await + .reconstruct(); - let world = TestWorld::default(); + assert_ne!( + actual, records, + "Shuffle should produce a different order of items" + ); - let result = world - .semi_honest(records.into_iter(), |ctx, shares| async move { - shuffle(ctx, 1, shares).await.unwrap() - }) - .await; - - let result = result - .into_iter() - .map(|(l, r)| { - l.into_iter() - .zip(r) - .map(|(li, ri)| AdditiveShare::new(li, ri)) - .collect::>() - }) - .collect::>(); + actual.sort(); - let result: [Vec<_>; 3] = result.try_into().unwrap(); - let actual = result.reconstruct()[0]; - assert_eq!(actual, expected); + assert_eq!( + actual, records, + "Shuffle should not change the items in the set" + ); }); } } diff --git a/src/query/runner/oprf_shuffle/query.rs b/src/query/runner/oprf_shuffle/query.rs index 21365fc0b..5e481b9f4 100644 --- a/src/query/runner/oprf_shuffle/query.rs +++ b/src/query/runner/oprf_shuffle/query.rs @@ -33,16 +33,10 @@ impl OPRFShuffleQuery { .try_concat() .await?; - let batch_size = u32::try_from(input.len()).map_err(|_e| { - Error::FieldValueTruncation(format!( - "Cannot truncate the number of input rows {} to u32", - input.len(), - )) - })?; - + let batch_size = input.len(); let shares = split_shares(&input); - let (res_l, res_r) = shuffle(ctx, batch_size, shares).await?; - Ok(combine_shares(res_l, res_r)) + let res = shuffle(ctx, batch_size, shares).await?; + Ok(combine_shares(res.iter())) } } @@ -58,13 +52,11 @@ fn split_shares( input_rows.iter().map(f) } -fn combine_shares(l: L, r: R) -> Vec -where - L: IntoIterator, - R: IntoIterator, -{ - l.into_iter() - .zip(r) - .map(|(l, r)| l.to_input_row(r)) +fn combine_shares<'a>( + input: impl IntoIterator>, +) -> Vec { + input + .into_iter() + .map(ShuffleShare::to_input_row) .collect::>() } diff --git a/src/query/runner/oprf_shuffle/share.rs b/src/query/runner/oprf_shuffle/share.rs index 8d1edac2b..21cf46da6 100644 --- a/src/query/runner/oprf_shuffle/share.rs +++ b/src/query/runner/oprf_shuffle/share.rs @@ -8,7 +8,10 @@ use super::ShuffleInputRow; use crate::{ ff::{Field, Gf32Bit, Gf40Bit, Gf8Bit, Serializable}, helpers::Direction, - secret_sharing::SharedValue, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + SharedValue, + }, }; pub type ShuffleShareMK = Gf40Bit; pub type ShuffleShareBK = Gf8Bit; @@ -131,13 +134,13 @@ impl ShuffleShare { } #[must_use] - pub fn to_input_row(self, rhs: Self) -> ShuffleInputRow { + pub fn to_input_row(input: &AdditiveShare) -> ShuffleInputRow { ShuffleInputRow { - timestamp: (self.timestamp, rhs.timestamp).into(), - mk_shares: (self.mk, rhs.mk).into(), - is_trigger_bit: (self.is_trigger_bit, rhs.is_trigger_bit).into(), - breakdown_key: (self.breakdown_key, rhs.breakdown_key).into(), - trigger_value: (self.trigger_value, rhs.trigger_value).into(), + timestamp: ReplicatedSecretSharing::map(input, |v| v.timestamp), + mk_shares: ReplicatedSecretSharing::map(input, |v| v.mk), + is_trigger_bit: ReplicatedSecretSharing::map(input, |v| v.is_trigger_bit), + breakdown_key: ReplicatedSecretSharing::map(input, |v| v.breakdown_key), + trigger_value: ReplicatedSecretSharing::map(input, |v| v.trigger_value), } } } From 7791ae041a6cc8e6525365db68daae4a17f70e13 Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Wed, 25 Oct 2023 10:08:02 -0500 Subject: [PATCH 14/15] [oprf][shuffle] Avoid allocation of permutations Instead, shuffle in-place based on shared randomness --- src/protocol/oprf/shuffle/mod.rs | 60 ++++++++++++-------------------- 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index e2d7fea0f..c97ccdf3c 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -4,9 +4,7 @@ use futures::future; use ipa_macros::Step; use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom, Rng}; -use super::super::{ - basics::apply_permutation::apply as apply_permutation, context::Context, RecordId, -}; +use super::super::{context::Context, RecordId}; use crate::{ error::Error, helpers::{Direction, Message, ReceivingEnd, Role}, @@ -18,9 +16,9 @@ use crate::{ #[derive(Step)] pub(crate) enum OPRFShuffleStep { + ApplyPermutations, GenerateAHat, GenerateBHat, - GeneratePi, GenerateZ, TransferCHat, TransferX2, @@ -42,22 +40,13 @@ where for<'b> &'b S: Add<&'b S, Output = S>, Standard: Distribution, { - // 1. Generate permutations - let permutation_size: u32 = batch_size.try_into().map_err(|e| { - Error::FieldValueTruncation(format!( - "batch size {batch_size} does not fit into u32. Error={e:?}" - )) - })?; - let pis = generate_permutations_with_peers(permutation_size, &ctx); - - // 2. Generate random tables used by all helpers let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); let zs = generate_random_tables_with_peers(batch_size, &ctx_z); match ctx.role() { - Role::H1 => run_h1(&ctx, batch_size, shares, pis, zs).await, - Role::H2 => run_h2(&ctx, batch_size, shares, pis, zs).await, - Role::H3 => run_h3(&ctx, batch_size, pis, zs).await, + Role::H1 => run_h1(&ctx, batch_size, shares, zs).await, + Role::H2 => run_h2(&ctx, batch_size, shares, zs).await, + Role::H3 => run_h3(&ctx, batch_size, zs).await, } } @@ -65,7 +54,6 @@ async fn run_h1( ctx: &C, batch_size: usize, shares: I, - (pi_31, pi_12): (Vec, Vec), (z_31, z_12): (Zl, Zr), ) -> Result>, Error> where @@ -89,10 +77,13 @@ where .into_iter() .map(|s: AdditiveShare| s.left().add(s.right())); let mut x_1: Vec = add_single_shares(a_add_b_iter, z_12).collect(); - apply_permutation(&pi_12, &mut x_1); + + let ctx_perm = ctx.narrow(&OPRFShuffleStep::ApplyPermutations); + let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng(); + apply_permutation(&mut rng_perm_r, &mut x_1); let mut x_2: Vec = add_single_shares(x_1, z_31).collect(); - apply_permutation(&pi_31, &mut x_2); + apply_permutation(&mut rng_perm_l, &mut x_2); send_to_peer(ctx, &OPRFShuffleStep::TransferX2, Direction::Right, x_2).await?; @@ -104,7 +95,6 @@ async fn run_h2( ctx: &C, batch_size: usize, shares: I, - (pi_12, pi_23): (Vec, Vec), (z_12, z_23): (Zl, Zr), ) -> Result>, Error> where @@ -125,7 +115,10 @@ where // 2. Run computations let c = shares.into_iter().map(|s| s.right()); let mut y_1: Vec = add_single_shares(c, z_12).collect(); - apply_permutation(&pi_12, &mut y_1); + + let ctx_perm = ctx.narrow(&OPRFShuffleStep::ApplyPermutations); + let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng(); + apply_permutation(&mut rng_perm_l, &mut y_1); let ((), x_2): ((), Vec) = future::try_join( send_to_peer(ctx, &OPRFShuffleStep::TransferY1, Direction::Right, y_1), @@ -139,7 +132,7 @@ where .await?; let mut x_3: Vec = add_single_shares(x_2.iter(), z_23).collect(); - apply_permutation(&pi_23, &mut x_3); + apply_permutation(&mut rng_perm_r, &mut x_3); let c_hat_1: Vec = add_single_shares(x_3.iter(), b_hat.iter()).collect(); let ((), c_hat_2) = future::try_join( @@ -166,7 +159,6 @@ where async fn run_h3( ctx: &C, batch_size: usize, - (pi_23, pi_31): (Vec, Vec), (z_23, z_31): (Zl, Zr), ) -> Result>, Error> where @@ -194,10 +186,13 @@ where .await?; let mut y_2: Vec = add_single_shares(y_1, z_31).collect(); - apply_permutation(&pi_31, &mut y_2); + + let ctx_perm = ctx.narrow(&OPRFShuffleStep::ApplyPermutations); + let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng(); + apply_permutation(&mut rng_perm_r, &mut y_2); let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); - apply_permutation(&pi_23, &mut y_3); + apply_permutation(&mut rng_perm_l, &mut y_3); let c_hat_2: Vec = add_single_shares(y_3.iter(), a_hat.iter()).collect(); let ((), c_hat_1): ((), Vec) = future::try_join( @@ -327,19 +322,8 @@ where // ------------------ Pseudorandom permutations functions -------------------- // -fn generate_permutations_with_peers(size: u32, ctx: &C) -> (Vec, Vec) { - let narrow_context = ctx.narrow(&OPRFShuffleStep::GeneratePi); - let mut rng = narrow_context.prss_rng(); - - let with_left = generate_pseudorandom_permutation(size, &mut rng.0); - let with_right = generate_pseudorandom_permutation(size, &mut rng.1); - (with_left, with_right) -} - -fn generate_pseudorandom_permutation(size: u32, rng: &mut R) -> Vec { - let mut permutation = (0..size).collect::>(); - permutation.shuffle(rng); - permutation +fn apply_permutation(rng: &mut R, items: &mut [S]) { + items.shuffle(rng); } #[cfg(all(test, any(unit_test, feature = "shuttle")))] From b3b09ed4f099c35dbd073f5710466f64d87ef444 Mon Sep 17 00:00:00 2001 From: Artem Ignatyev Date: Wed, 25 Oct 2023 11:57:37 -0500 Subject: [PATCH 15/15] [oprf][shuffle] Removed a few unnecessary allocations If a vec produced by permutation is used only to compute some other "table" by adding some other table to it, we can do that in-place --- src/protocol/oprf/shuffle/mod.rs | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/protocol/oprf/shuffle/mod.rs b/src/protocol/oprf/shuffle/mod.rs index c97ccdf3c..6a9bf5894 100644 --- a/src/protocol/oprf/shuffle/mod.rs +++ b/src/protocol/oprf/shuffle/mod.rs @@ -1,4 +1,4 @@ -use std::ops::Add; +use std::ops::{Add, AddAssign}; use futures::future; use ipa_macros::Step; @@ -36,8 +36,8 @@ where C: Context, I: IntoIterator>, S: SharedValue + Add + Message, - for<'b> &'b S: Add, - for<'b> &'b S: Add<&'b S, Output = S>, + for<'a> &'a S: Add, + for<'a> &'a S: Add<&'a S, Output = S>, Standard: Distribution, { let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); @@ -131,7 +131,7 @@ where ) .await?; - let mut x_3: Vec = add_single_shares(x_2.iter(), z_23).collect(); + let mut x_3: Vec = add_single_shares_in_place(x_2, z_23); apply_permutation(&mut rng_perm_r, &mut x_3); let c_hat_1: Vec = add_single_shares(x_3.iter(), b_hat.iter()).collect(); @@ -167,7 +167,6 @@ where Zl: IntoIterator, Zr: IntoIterator, S: Clone + Add + Message, - // for<'a> &'a S: Add, for<'a> &'a S: Add<&'a S, Output = S>, Standard: Distribution, { @@ -191,7 +190,7 @@ where let (mut rng_perm_l, mut rng_perm_r) = ctx_perm.prss_rng(); apply_permutation(&mut rng_perm_r, &mut y_2); - let mut y_3: Vec = add_single_shares(y_2, z_23).collect(); + let mut y_3: Vec = add_single_shares_in_place(y_2, z_23); apply_permutation(&mut rng_perm_l, &mut y_3); let c_hat_2: Vec = add_single_shares(y_3.iter(), a_hat.iter()).collect(); @@ -224,6 +223,19 @@ where { l.into_iter().zip(r).map(|(a, b)| a + b) } + +fn add_single_shares_in_place(mut items: Vec, r: R) -> Vec +where + S: AddAssign, + R: IntoIterator, +{ + items + .iter_mut() + .zip(r) + .for_each(|(item, rhs)| item.add_assign(rhs)); + items +} + // --------------------------------------------------------------------------- // fn combine_single_shares(l: Il, r: Ir) -> impl Iterator>