Skip to content

Commit

Permalink
add test for reshard-stream
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktaubeneck committed Oct 9, 2024
1 parent 0194eb8 commit 51a8a28
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ pub trait DZKPContext: Context {
mod tests {
use std::{iter, iter::repeat};

use futures::{future::join_all, stream::StreamExt, try_join};
use futures::{future::join_all, stream, stream::StreamExt, try_join};
use ipa_step::StepNarrow;
use rand::{
distributions::{Distribution, Standard},
Expand All @@ -588,8 +588,8 @@ mod tests {
protocol::{
basics::ShareKnownValue,
context::{
reshard_iter, step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable,
Context, ShardedContext, UpgradableContext, Validator,
reshard_iter, reshard_stream, step::MaliciousProtocolStep::MaliciousProtocol,
upgrade::Upgradable, Context, ShardedContext, UpgradableContext, Validator,
},
prss::SharedRandomness,
RecordId,
Expand Down Expand Up @@ -867,7 +867,34 @@ mod tests {

/// Ensure global record order across shards is consistent.
#[test]
fn shard_picker() {
fn reshard_stream_test() {
run(|| async move {
const SHARDS: u32 = 5;
let world: TestWorld<WithShards<5, RoundRobinInputDistribution>> =
TestWorld::with_shards(TestWorldConfig::default());

let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect();
let r = world
.semi_honest(input.clone().into_iter(), |ctx, shard_input| async move {
let shard_input = stream::iter(shard_input);
reshard_stream(ctx, shard_input, |_, record_id, _| {
ShardIndex::from(u32::from(record_id) % SHARDS)
})
.await
.unwrap()
})
.await
.into_iter()
.flat_map(|v| v.reconstruct())
.collect::<Vec<_>>();

assert_eq!(input, r);
});
}

/// Ensure global record order across shards is consistent.
#[test]
fn reshard_iter_test() {
run(|| async move {
const SHARDS: u32 = 5;
let world: TestWorld<WithShards<5, RoundRobinInputDistribution>> =
Expand Down

0 comments on commit 51a8a28

Please sign in to comment.