diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 98461ad7e..2ec06cf0d 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -4,7 +4,7 @@ use std::{ fmt::Debug, fs::{File, OpenOptions}, io, - io::{stdout, BufRead, BufReader, Write}, + io::{stdout, BufReader, Write}, ops::Deref, path::{Path, PathBuf}, }; @@ -34,6 +34,8 @@ use ipa_core::{ }; use rand::{distributions::Alphanumeric, rngs::StdRng, thread_rng, Rng}; use rand_core::SeedableRng; +use ipa_core::cli::playbook::StreamingSubmission; +use ipa_core::helpers::BodyStream; #[derive(Debug, Parser)] #[clap(name = "rc", about = "Report Collector CLI")] @@ -421,18 +423,22 @@ async fn hybrid( ) -> Result<(), Box> { let query_type = QueryType::MaliciousHybrid(hybrid_query_config); - let files = [ + let [h1_streams, h2_streams, h3_streams] = [ &encrypted_inputs.enc_input_file1, &encrypted_inputs.enc_input_file2, &encrypted_inputs.enc_input_file3, - ]; - - let submissions = files - .iter() - .map(|path| { - let file = - File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}")); - RoundRobinSubmission::new(BufReader::new(file)) + ].map(|path| { + let file = + File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}")); + RoundRobinSubmission::new(BufReader::new(file)) + }).map(|s| s.into_byte_streams(args.shard_count)); + + // create byte streams for each shard + let submissions = h1_streams.into_iter() + .zip(h2_streams.into_iter()) + .zip(h3_streams.into_iter()) + .map(|((s1, s2), s3)| { + [BodyStream::from_bytes_stream(s1), BodyStream::from_bytes_stream(s2), BodyStream::from_bytes_stream(s3)] }) .collect::>(); @@ -452,12 +458,9 @@ async fn hybrid( // implementation, otherwise a runtime reconstruct error will be generated. // see ipa-core/src/query/executor.rs - let actual = run_hybrid_query_and_validate::( + let actual = run_hybrid_query_and_validate::( submissions, count, - args.shard_count - .try_into() - .expect("u32 should fit in usize"), helper_clients, query_id, hybrid_query_config, diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs index 1fb9b293c..b4c358bda 100644 --- a/ipa-core/src/cli/playbook/hybrid.rs +++ b/ipa-core/src/cli/playbook/hybrid.rs @@ -1,16 +1,14 @@ #![cfg(all(feature = "web-app", feature = "cli"))] use std::{ cmp::min, - io::BufRead, time::{Duration, Instant}, }; - +use std::iter::zip; use futures_util::future::try_join_all; use serde::{Deserialize, Serialize}; use tokio::time::sleep; use crate::{ - cli::playbook::{RoundRobinSubmission, StreamingSubmission}, ff::{Serializable, U128Conversions}, helpers::{ query::{HybridQueryParams, QueryInput, QuerySize}, @@ -26,10 +24,9 @@ use crate::{ /// # Panics /// if results are invalid #[allow(clippy::disallowed_methods)] // allow try_join_all -pub async fn run_hybrid_query_and_validate( - inputs: [RoundRobinSubmission; 3], +pub async fn run_hybrid_query_and_validate( + inputs: Vec<[BodyStream; 3]>, query_size: usize, - shard_count: usize, clients: Vec<[IpaHttpClient; 3]>, query_id: QueryId, query_config: HybridQueryParams, @@ -37,43 +34,33 @@ pub async fn run_hybrid_query_and_validate( where HV: SharedValue + U128Conversions, AdditiveShare: Serializable, - R: BufRead + Send, { let mpc_time = Instant::now(); - let leader_clients = clients[0].clone(); - - let transposed_inputs = inputs[0] - .into_byte_streams(shard_count) - .iter() - .zip( - inputs[1] - .into_byte_streams(shard_count) + assert_eq!(clients.len(), inputs.len()); + // submit inputs to each shard + let _ = try_join_all(zip(clients.iter(), inputs.into_iter()) + .map(|(shard_clients, shard_inputs)| { + try_join_all(shard_clients .iter() - .zip(inputs[2].into_byte_streams(shard_count).iter()), - ) - .map(|(i1, (i2, i3))| [i1, i2, i3]) - .collect::>(); + .zip(shard_inputs.into_iter()) + .map(|(client, input)| + { + client.query_input(QueryInput { + query_id, + input_stream: input + }) + } + ) + ) + })).await.unwrap(); - try_join_all( - transposed_inputs - .into_iter() - .flatten() - .zip(clients.into_iter().flatten()) - .map(|(stream, client)| { - client.query_input(QueryInput { - query_id, - input_stream: BodyStream::from_bytes_stream(*stream), - }) - }), - ) - .await - .unwrap(); + let leader_clients = &clients[0]; let mut delay = Duration::from_millis(125); loop { if try_join_all( leader_clients - .iter() + .each_ref() .map(|client| client.query_status(query_id)), ) .await diff --git a/ipa-core/tests/hybrid.rs b/ipa-core/tests/hybrid.rs index c57783286..819854ac6 100644 --- a/ipa-core/tests/hybrid.rs +++ b/ipa-core/tests/hybrid.rs @@ -88,9 +88,9 @@ fn test_hybrid() { .args(["--network".into(), config_path.join("network.toml")]) .args(["--output-file".as_ref(), output_file.as_os_str()]) .args(["--shard-count", SHARDS.to_string().as_str()]) - .args(["--count", INPUT_SIZE.to_string().as_str()]) .args(["--wait", "2"]) .arg("malicious-hybrid") + .args(["--count", INPUT_SIZE.to_string().as_str()]) .args(["--enc-input-file1".as_ref(), enc1.as_os_str()]) .args(["--enc-input-file2".as_ref(), enc2.as_os_str()]) .args(["--enc-input-file3".as_ref(), enc3.as_os_str()])