Merge pull request #1343 from akoshelev/reshard-stream
Make reshard work with streams too
akoshelev authored Oct 10, 2024
2 parents 7e1c180 + 51a8a28 commit 7eec6c6
Showing 2 changed files with 87 additions and 15 deletions.
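In short, `reshard` becomes `reshard_iter`, and a new `reshard_stream` entry point accepts any stream whose size is known in advance. Below is a rough sketch of the two call shapes, modelled on the tests added in this diff; the helper name `round_robin`, the five-shard count, and the import paths are illustrative assumptions rather than code from this commit:

```rust
use futures::stream;
use ipa_core::{
    error::Error,
    helpers::Message,
    protocol::context::{reshard_iter, reshard_stream, ShardedContext},
    sharding::ShardIndex,
};

/// Sketch: spread records round-robin over an assumed five-shard deployment.
async fn round_robin<C, K>(ctx: C, values: Vec<K>) -> Result<Vec<K>, Error>
where
    C: ShardedContext,
    K: Message + Clone,
{
    // New in this commit: accepts any stream whose size is known up front.
    let from_stream = reshard_stream(ctx.clone(), stream::iter(values.clone()), |_, record_id, _| {
        ShardIndex::from(u32::from(record_id) % 5)
    })
    .await?;

    // Renamed from `reshard`: same behaviour, but for iterators with an exact size.
    let from_iter = reshard_iter(ctx, values, |_, record_id, _| {
        ShardIndex::from(u32::from(record_id) % 5)
    })
    .await?;

    debug_assert_eq!(from_stream.len(), from_iter.len());
    Ok(from_stream)
}
```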
98 changes: 85 additions & 13 deletions ipa-core/src/protocol/context/mod.rs
@@ -35,8 +35,8 @@ pub(crate) use malicious::TEST_DZKP_STEPS;
use crate::{
error::Error,
helpers::{
ChannelId, Direction, Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd,
ShardReceivingEnd, TotalRecords,
stream::ExactSizeStream, ChannelId, Direction, Gateway, Message, MpcMessage,
MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, TotalRecords,
},
protocol::{
context::dzkp_validator::DZKPValidator,
@@ -374,24 +374,44 @@ impl<'a> Inner<'a> {
///
/// [`calculations`]: https://docs.google.com/document/d/1vej6tYgNV3GWcldD4tl7a4Z9EeZwda3F5u7roPGArlU/
///
/// ## Stream size
/// Note that it currently works only for streams whose size is known in advance, mainly because
/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard.
/// Other than that, there is no technical limitation here, and it could be made to work with
/// regular streams if the batching problem is addressed.
///
///
/// ```compile_fail
/// use futures::stream::{self, StreamExt};
/// use ipa_core::protocol::context::reshard_stream;
/// use ipa_core::ff::boolean::Boolean;
/// use ipa_core::secret_sharing::SharedValue;
/// async {
/// let a = [Boolean::ZERO];
/// let mut s = stream::iter(a.into_iter()).cycle();
/// // this should fail to compile:
/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied
/// reshard_stream(todo!(), s, todo!()).await;
/// };
/// ```
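///
/// For contrast, a sketch of valid usage (illustrative only: it assumes a sharded context `ctx`
/// is in scope and at least three shards are configured):
///
/// ```ignore
/// use futures::stream;
/// use ipa_core::ff::boolean::Boolean;
/// use ipa_core::protocol::context::reshard_stream;
/// use ipa_core::secret_sharing::SharedValue;
/// use ipa_core::sharding::ShardIndex;
///
/// async {
///     // A `Vec`-backed stream reports its exact size, so `ExactSizeStream` is satisfied.
///     let values = vec![Boolean::ZERO; 4];
///     let resharded = reshard_stream(ctx, stream::iter(values), |_, record_id, _| {
///         ShardIndex::from(u32::from(record_id) % 3)
///     })
///     .await
///     .unwrap();
/// };
/// ```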
///
/// ## Panics
/// When `shard_picker` returns an out-of-bounds index.
///
/// ## Errors
/// If cross-shard communication fails.
pub async fn reshard<L, K, C, S>(
///
pub async fn reshard_stream<L, K, C, S>(
ctx: C,
input: L,
shard_picker: S,
) -> Result<Vec<K>, crate::error::Error>
where
L: IntoIterator<Item = K>,
L::IntoIter: ExactSizeIterator,
L: ExactSizeStream<Item = K>,
S: Fn(C, RecordId, &K) -> ShardIndex,
K: Message + Clone,
C: ShardedContext,
{
let input = input.into_iter();
let input_len = input.len();

// We set channels capacity to be at least 1 to be able to open send channels to all peers.
@@ -426,6 +446,8 @@ where
})
.fuse();

let input = pin!(input);

// This produces a stream of outcomes of send requests.
// In order to make it compatible with receive stream, it also returns records that must
// stay on this shard, according to `shard_picker`'s decision.
@@ -439,13 +461,15 @@
// tracking per shard to work correctly. If tasks complete out of order, this will cause share
// misplacement on the recipient side.
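// For example, if records 2 and 7 both go to the same destination shard, they occupy consecutive
// slots on that shard; completing the send of 7 before 2 would swap the shares the recipient
// stores in those two slots.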
(
input.enumerate().zip(iter::repeat(ctx.clone())),
input
.enumerate()
.zip(stream::iter(iter::repeat(ctx.clone()))),
&mut send_channels,
),
|(mut input, send_channels)| async {
// Process more data as it comes in, or close the sending channels, if there is nothing
// left.
if let Some(((i, val), ctx)) = input.next() {
if let Some(((i, val), ctx)) = input.next().await {
let dest_shard = shard_picker(ctx, RecordId::from(i), &val);
if dest_shard == my_shard {
Some(((my_shard, Ok(Some(val.clone()))), (input, send_channels)))
@@ -504,6 +528,27 @@ where
Ok(r.into_iter().flatten().collect())
}

/// Same as [`reshard_stream`], but takes an iterator with a known size
/// as input.
///
/// # Errors
/// If cross-shard communication fails.
///
/// # Panics
/// When `shard_picker` returns an out-of-bounds index.
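///
/// A sketch (illustrative; `ctx` is assumed to be a sharded context with at least two shards,
/// and `a`, `b` are values of a type implementing `Message + Clone`):
///
/// ```ignore
/// use ipa_core::protocol::context::reshard_iter;
/// use ipa_core::sharding::ShardIndex;
///
/// async {
///     let resharded = reshard_iter(ctx, vec![a, b], |_, record_id, _| {
///         ShardIndex::from(u32::from(record_id) % 2)
///     })
///     .await
///     .unwrap();
/// };
/// ```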
pub async fn reshard_iter<L, K, C, S>(
ctx: C,
input: L,
shard_picker: S,
) -> Result<Vec<K>, crate::error::Error>
where
L: IntoIterator<Item = K>,
L::IntoIter: ExactSizeIterator,
S: Fn(C, RecordId, &K) -> ShardIndex,
K: Message + Clone,
C: ShardedContext,
{
reshard_stream(ctx, stream::iter(input.into_iter()), shard_picker).await
}

/// trait for contexts that allow MPC multiplications that are protected against a malicious helper by using a DZKP
#[async_trait]
pub trait DZKPContext: Context {
@@ -526,7 +571,7 @@ pub trait DZKPContext: Context {
mod tests {
use std::{iter, iter::repeat};

use futures::{future::join_all, stream::StreamExt, try_join};
use futures::{future::join_all, stream, stream::StreamExt, try_join};
use ipa_step::StepNarrow;
use rand::{
distributions::{Distribution, Standard},
@@ -543,8 +588,8 @@ mod tests {
protocol::{
basics::ShareKnownValue,
context::{
reshard, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable,
Context, ShardedContext, UpgradableContext, Validator,
reshard_iter, reshard_stream, step::MaliciousProtocolStep::MaliciousProtocol,
upgrade::Upgradable, Context, ShardedContext, UpgradableContext, Validator,
},
prss::SharedRandomness,
RecordId,
@@ -822,15 +867,42 @@

/// Ensure global record order across shards is consistent.
#[test]
fn shard_picker() {
fn reshard_stream_test() {
run(|| async move {
const SHARDS: u32 = 5;
let world: TestWorld<WithShards<5, RoundRobinInputDistribution>> =
TestWorld::with_shards(TestWorldConfig::default());

let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect();
let r = world
.semi_honest(input.clone().into_iter(), |ctx, shard_input| async move {
let shard_input = stream::iter(shard_input);
reshard_stream(ctx, shard_input, |_, record_id, _| {
ShardIndex::from(u32::from(record_id) % SHARDS)
})
.await
.unwrap()
})
.await
.into_iter()
.flat_map(|v| v.reconstruct())
.collect::<Vec<_>>();

assert_eq!(input, r);
});
}

/// Ensure global record order across shards is consistent.
#[test]
fn reshard_iter_test() {
run(|| async move {
const SHARDS: u32 = 5;
let world: TestWorld<WithShards<5, RoundRobinInputDistribution>> =
TestWorld::with_shards(TestWorldConfig::default());
let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect();
let r = world
.semi_honest(input.clone().into_iter(), |ctx, shard_input| async move {
reshard(ctx, shard_input, |_, record_id, _| {
reshard_iter(ctx, shard_input, |_, record_id, _| {
ShardIndex::from(u32::from(record_id) % SHARDS)
})
.await
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs
@@ -18,7 +18,7 @@ use crate::{
ff::{boolean_array::BA64, U128Conversions},
helpers::{Direction, Error, Role, TotalRecords},
protocol::{
context::{reshard, ShardedContext},
context::{reshard_iter, ShardedContext},
prss::{FromRandom, FromRandomU128, SharedRandomness},
RecordId,
},
@@ -88,7 +88,7 @@ trait ShuffleContext: ShardedContext {
let data = data.into_iter();
async move {
let masking_ctx = self.narrow(&ShuffleStep::Mask);
let mut resharded = assert_send(reshard(
let mut resharded = assert_send(reshard_iter(
self.clone(),
data.enumerate().map(|(i, item)| {
// FIXME(1029): update PRSS trait to compute only left or right part