Skip to content

Commit

Permalink
Per shard submission from report collector
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Dec 5, 2024
1 parent 8effb67 commit 1e6418c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 49 deletions.
31 changes: 17 additions & 14 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
fmt::Debug,
fs::{File, OpenOptions},
io,
io::{stdout, BufRead, BufReader, Write},
io::{stdout, BufReader, Write},
ops::Deref,
path::{Path, PathBuf},
};
Expand Down Expand Up @@ -34,6 +34,8 @@ use ipa_core::{
};
use rand::{distributions::Alphanumeric, rngs::StdRng, thread_rng, Rng};
use rand_core::SeedableRng;
use ipa_core::cli::playbook::StreamingSubmission;
use ipa_core::helpers::BodyStream;

#[derive(Debug, Parser)]
#[clap(name = "rc", about = "Report Collector CLI")]
Expand Down Expand Up @@ -421,18 +423,22 @@ async fn hybrid(
) -> Result<(), Box<dyn Error>> {
let query_type = QueryType::MaliciousHybrid(hybrid_query_config);

let files = [
let [h1_streams, h2_streams, h3_streams] = [
&encrypted_inputs.enc_input_file1,
&encrypted_inputs.enc_input_file2,
&encrypted_inputs.enc_input_file3,
];

let submissions = files
.iter()
.map(|path| {
let file =
File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
RoundRobinSubmission::new(BufReader::new(file))
].map(|path| {
let file =
File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
RoundRobinSubmission::new(BufReader::new(file))
}).map(|s| s.into_byte_streams(args.shard_count));

// create byte streams for each shard
let submissions = h1_streams.into_iter()
.zip(h2_streams.into_iter())
.zip(h3_streams.into_iter())
.map(|((s1, s2), s3)| {
[BodyStream::from_bytes_stream(s1), BodyStream::from_bytes_stream(s2), BodyStream::from_bytes_stream(s3)]
})
.collect::<Vec<_>>();

Expand All @@ -452,12 +458,9 @@ async fn hybrid(
// implementation, otherwise a runtime reconstruct error will be generated.
// see ipa-core/src/query/executor.rs

let actual = run_hybrid_query_and_validate::<BA32, BufReader>(
let actual = run_hybrid_query_and_validate::<BA32>(
submissions,
count,
args.shard_count
.try_into()
.expect("u32 should fit in usize"),
helper_clients,
query_id,
hybrid_query_config,
Expand Down
55 changes: 21 additions & 34 deletions ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
#![cfg(all(feature = "web-app", feature = "cli"))]
use std::{
cmp::min,
io::BufRead,
time::{Duration, Instant},
};

use std::iter::zip;
use futures_util::future::try_join_all;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;

use crate::{
cli::playbook::{RoundRobinSubmission, StreamingSubmission},
ff::{Serializable, U128Conversions},
helpers::{
query::{HybridQueryParams, QueryInput, QuerySize},
Expand All @@ -26,54 +24,43 @@ use crate::{
/// # Panics
/// if results are invalid
#[allow(clippy::disallowed_methods)] // allow try_join_all
pub async fn run_hybrid_query_and_validate<HV, R>(
inputs: [RoundRobinSubmission<R>; 3],
pub async fn run_hybrid_query_and_validate<HV>(
inputs: Vec<[BodyStream; 3]>,
query_size: usize,
shard_count: usize,
clients: Vec<[IpaHttpClient<Helper>; 3]>,
query_id: QueryId,
query_config: HybridQueryParams,
) -> HybridQueryResult
where
HV: SharedValue + U128Conversions,
AdditiveShare<HV>: Serializable,
R: BufRead + Send,
{
let mpc_time = Instant::now();
let leader_clients = clients[0].clone();

let transposed_inputs = inputs[0]
.into_byte_streams(shard_count)
.iter()
.zip(
inputs[1]
.into_byte_streams(shard_count)
assert_eq!(clients.len(), inputs.len());
// submit inputs to each shard
let _ = try_join_all(zip(clients.iter(), inputs.into_iter())
.map(|(shard_clients, shard_inputs)| {
try_join_all(shard_clients
.iter()
.zip(inputs[2].into_byte_streams(shard_count).iter()),
)
.map(|(i1, (i2, i3))| [i1, i2, i3])
.collect::<Vec<_>>();
.zip(shard_inputs.into_iter())
.map(|(client, input)|
{
client.query_input(QueryInput {
query_id,
input_stream: input
})
}
)
)
})).await.unwrap();

try_join_all(
transposed_inputs
.into_iter()
.flatten()
.zip(clients.into_iter().flatten())
.map(|(stream, client)| {
client.query_input(QueryInput {
query_id,
input_stream: BodyStream::from_bytes_stream(*stream),
})
}),
)
.await
.unwrap();
let leader_clients = &clients[0];

let mut delay = Duration::from_millis(125);
loop {
if try_join_all(
leader_clients
.iter()
.each_ref()
.map(|client| client.query_status(query_id)),
)
.await
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/tests/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ fn test_hybrid() {
.args(["--network".into(), config_path.join("network.toml")])
.args(["--output-file".as_ref(), output_file.as_os_str()])
.args(["--shard-count", SHARDS.to_string().as_str()])
.args(["--count", INPUT_SIZE.to_string().as_str()])
.args(["--wait", "2"])
.arg("malicious-hybrid")
.args(["--count", INPUT_SIZE.to_string().as_str()])
.args(["--enc-input-file1".as_ref(), enc1.as_os_str()])
.args(["--enc-input-file2".as_ref(), enc2.as_os_str()])
.args(["--enc-input-file3".as_ref(), enc3.as_os_str()])
Expand Down

0 comments on commit 1e6418c

Please sign in to comment.