Skip to content

Commit

Permalink
Make reshard work with streams too
Browse files Browse the repository at this point in the history
Internally, reshard used streams already, so it is only a matter of changing the API and connecting things together
  • Loading branch information
akoshelev committed Oct 9, 2024
1 parent 7e1c180 commit 0194eb8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
65 changes: 55 additions & 10 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ pub(crate) use malicious::TEST_DZKP_STEPS;
use crate::{
error::Error,
helpers::{
ChannelId, Direction, Gateway, Message, MpcMessage, MpcReceivingEnd, Role, SendingEnd,
ShardReceivingEnd, TotalRecords,
stream::ExactSizeStream, ChannelId, Direction, Gateway, Message, MpcMessage,
MpcReceivingEnd, Role, SendingEnd, ShardReceivingEnd, TotalRecords,
},
protocol::{
context::dzkp_validator::DZKPValidator,
Expand Down Expand Up @@ -374,24 +374,44 @@ impl<'a> Inner<'a> {
///
/// [`calculations`]: https://docs.google.com/document/d/1vej6tYgNV3GWcldD4tl7a4Z9EeZwda3F5u7roPGArlU/
///
/// ## Stream size
/// Note that it currently works for streams where size is known in advance. Mainly because
/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard.
/// Other than that, there are no technical limitation here, and it could be possible to make it
/// work with regular streams if the batching problem is somehow addressed.
///
///
/// ```compile_fail
/// use futures::stream::{self, StreamExt};
/// use ipa_core::protocol::context::reshard_stream;
/// use ipa_core::ff::boolean::Boolean;
/// use ipa_core::secret_sharing::SharedValue;
/// async {
/// let a = [Boolean::ZERO];
/// let mut s = stream::iter(a.into_iter()).cycle();
/// // this should fail to compile:
/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied
/// reshard_stream(todo!(), s, todo!()).await;
/// };
/// ```
///
/// ## Panics
/// When `shard_picker` returns an out-of-bounds index.
///
/// ## Errors
/// If cross-shard communication fails
pub async fn reshard<L, K, C, S>(
///
pub async fn reshard_stream<L, K, C, S>(
ctx: C,
input: L,
shard_picker: S,
) -> Result<Vec<K>, crate::error::Error>
where
L: IntoIterator<Item = K>,
L::IntoIter: ExactSizeIterator,
L: ExactSizeStream<Item = K>,
S: Fn(C, RecordId, &K) -> ShardIndex,
K: Message + Clone,
C: ShardedContext,
{
let input = input.into_iter();
let input_len = input.len();

// We set channels capacity to be at least 1 to be able to open send channels to all peers.
Expand Down Expand Up @@ -426,6 +446,8 @@ where
})
.fuse();

let input = pin!(input);

// This produces a stream of outcomes of send requests.
// In order to make it compatible with receive stream, it also returns records that must
// stay on this shard, according to `shard_picker`'s decision.
Expand All @@ -439,13 +461,15 @@ where
// tracking per shard to work correctly. If tasks complete out of order, this will cause share
// misplacement on the recipient side.
(
input.enumerate().zip(iter::repeat(ctx.clone())),
input
.enumerate()
.zip(stream::iter(iter::repeat(ctx.clone()))),
&mut send_channels,
),
|(mut input, send_channels)| async {
// Process more data as it comes in, or close the sending channels, if there is nothing
// left.
if let Some(((i, val), ctx)) = input.next() {
if let Some(((i, val), ctx)) = input.next().await {
let dest_shard = shard_picker(ctx, RecordId::from(i), &val);
if dest_shard == my_shard {
Some(((my_shard, Ok(Some(val.clone()))), (input, send_channels)))
Expand Down Expand Up @@ -504,6 +528,27 @@ where
Ok(r.into_iter().flatten().collect())
}

/// Same as [`reshard_stream`] but takes an iterator with the known size
/// as input.
///
/// # Errors
///
/// # Panics
pub async fn reshard_iter<L, K, C, S>(
ctx: C,
input: L,
shard_picker: S,
) -> Result<Vec<K>, crate::error::Error>
where
L: IntoIterator<Item = K>,
L::IntoIter: ExactSizeIterator,
S: Fn(C, RecordId, &K) -> ShardIndex,
K: Message + Clone,
C: ShardedContext,
{
reshard_stream(ctx, stream::iter(input.into_iter()), shard_picker).await
}

/// trait for contexts that allow MPC multiplications that are protected against a malicious helper by using a DZKP
#[async_trait]
pub trait DZKPContext: Context {
Expand Down Expand Up @@ -543,7 +588,7 @@ mod tests {
protocol::{
basics::ShareKnownValue,
context::{
reshard, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable,
reshard_iter, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable,
Context, ShardedContext, UpgradableContext, Validator,
},
prss::SharedRandomness,
Expand Down Expand Up @@ -830,7 +875,7 @@ mod tests {
let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect();
let r = world
.semi_honest(input.clone().into_iter(), |ctx, shard_input| async move {
reshard(ctx, shard_input, |_, record_id, _| {
reshard_iter(ctx, shard_input, |_, record_id, _| {
ShardIndex::from(u32::from(record_id) % SHARDS)
})
.await
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
ff::{boolean_array::BA64, U128Conversions},
helpers::{Direction, Error, Role, TotalRecords},
protocol::{
context::{reshard, ShardedContext},
context::{reshard_iter, ShardedContext},
prss::{FromRandom, FromRandomU128, SharedRandomness},
RecordId,
},
Expand Down Expand Up @@ -88,7 +88,7 @@ trait ShuffleContext: ShardedContext {
let data = data.into_iter();
async move {
let masking_ctx = self.narrow(&ShuffleStep::Mask);
let mut resharded = assert_send(reshard(
let mut resharded = assert_send(reshard_iter(
self.clone(),
data.enumerate().map(|(i, item)| {
// FIXME(1029): update PRSS trait to compute only left or right part
Expand Down

0 comments on commit 0194eb8

Please sign in to comment.