Skip to content

Commit

Permalink
url_file_list parameter for report collector
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Dec 19, 2024
1 parent 9148d17 commit 76238ac
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 47 deletions.
148 changes: 116 additions & 32 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use std::{
fmt::Debug,
fs::{File, OpenOptions},
io,
io::{stdout, BufReader, Write},
io::{stdout, BufRead, BufReader, Write},
iter::zip,
ops::Deref,
path::{Path, PathBuf},
};

use clap::{Parser, Subcommand};
use hyper::http::uri::Scheme;
use hyper::{http::uri::Scheme, Uri};
use ipa_core::{
cli::{
playbook::{
Expand All @@ -24,11 +25,13 @@ use ipa_core::{
ff::{boolean_array::BA32, FieldType},
helpers::{
query::{
DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType,
DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QueryInput, QuerySize,
QueryType,
},
BodyStream,
},
net::{Helper, IpaHttpClient},
protocol::QueryId,
report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID},
test_fixture::{
ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord},
Expand Down Expand Up @@ -145,6 +148,13 @@ enum ReportCollectorCommand {
#[clap(flatten)]
encrypted_inputs: EncryptedInputs,

#[arg(
long,
help = "Read the list of URLs that contain the input from the provided file",
conflicts_with_all = ["enc_input_file1", "enc_input_file2", "enc_input_file3"]
)]
url_file_list: Option<PathBuf>,

#[clap(flatten)]
hybrid_query_config: HybridQueryParams,

Expand Down Expand Up @@ -267,6 +277,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
ReportCollectorCommand::MaliciousHybrid {
ref encrypted_inputs,
ref url_file_list,
hybrid_query_config,
count,
set_fixed_polling_ms,
Expand All @@ -275,7 +286,18 @@ async fn main() -> Result<(), Box<dyn Error>> {
&args,
hybrid_query_config,
clients,
encrypted_inputs,
|query_id| {
if let Some(ref url_file_list) = url_file_list {
inputs_from_url_file(&url_file_list, query_id, args.shard_count)
} else {
Ok(inputs_from_encrypted_inputs(
encrypted_inputs,
query_id,
args.shard_count,
))
}
},
// encrypted_inputs,
count.try_into().expect("u32 should fit into usize"),
set_fixed_polling_ms,
)
Expand All @@ -286,6 +308,93 @@ async fn main() -> Result<(), Box<dyn Error>> {
Ok(())
}

fn inputs_from_url_file(
url_file_path: &Path,
query_id: QueryId,
shard_count: usize,
) -> Result<Vec<[QueryInput; 3]>, Box<dyn Error>> {
let mut file = BufReader::new(File::open(url_file_path)?);
let mut buf = String::new();
let mut inputs = [Vec::new(), Vec::new(), Vec::new()];
for helper_id in 0..3 {
for _ in 0..shard_count {
buf.clear();
if file.read_line(&mut buf)? == 0 {
break;
}
inputs[helper_id]
.push(Uri::try_from(buf.trim()).map_err(|e| format!("Invalid URL {buf:?}: {e}"))?);
}
if inputs[helper_id].len() != shard_count {
return Err(format!(
"Helper {helper_id} does not have enough input. Expected {shard_count}, got {}",
inputs[helper_id].len()
)
.into());
}
}

let [h1, h2, h3] = inputs;
Ok(zip(zip(h1, h2), h3)
.map(|((h1, h2), h3)| {
[
QueryInput::FromUrl {
url: h1.to_string(),
query_id,
},
QueryInput::FromUrl {
url: h2.to_string(),
query_id,
},
QueryInput::FromUrl {
url: h3.to_string(),
query_id,
},
]
})
.collect())
}

fn inputs_from_encrypted_inputs(
encrypted_inputs: &EncryptedInputs,
query_id: QueryId,
shard_count: usize,
) -> Vec<[QueryInput; 3]> {
let [h1_streams, h2_streams, h3_streams] = [
&encrypted_inputs.enc_input_file1,
&encrypted_inputs.enc_input_file2,
&encrypted_inputs.enc_input_file3,
]
.map(|path| {
let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
RoundRobinSubmission::new(BufReader::new(file))
})
.map(|s| s.into_byte_streams(shard_count));

// create byte streams for each shard
h1_streams
.into_iter()
.zip(h2_streams.into_iter())
.zip(h3_streams.into_iter())
.map(|((s1, s2), s3)| {
[
QueryInput::Inline {
input_stream: BodyStream::from_bytes_stream(s1),
query_id,
},
QueryInput::Inline {
input_stream: BodyStream::from_bytes_stream(s2),
query_id,
},
QueryInput::Inline {
input_stream: BodyStream::from_bytes_stream(s3),
query_id,
},
]
})
.collect::<Vec<_>>()
}

fn gen_hybrid_inputs(
count: u32,
seed: Option<u64>,
Expand Down Expand Up @@ -422,41 +531,16 @@ fn write_hybrid_output_file(
Ok(())
}

async fn hybrid(
async fn hybrid<F: FnOnce(QueryId) -> Result<Vec<[QueryInput; 3]>, Box<dyn Error>>>(
args: &Args,
hybrid_query_config: HybridQueryParams,
helper_clients: Vec<[IpaHttpClient<Helper>; 3]>,
encrypted_inputs: &EncryptedInputs,
make_inputs_fn: F,
count: usize,
set_fixed_polling_ms: Option<u64>,
) -> Result<(), Box<dyn Error>> {
let query_type = QueryType::MaliciousHybrid(hybrid_query_config);

let [h1_streams, h2_streams, h3_streams] = [
&encrypted_inputs.enc_input_file1,
&encrypted_inputs.enc_input_file2,
&encrypted_inputs.enc_input_file3,
]
.map(|path| {
let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}"));
RoundRobinSubmission::new(BufReader::new(file))
})
.map(|s| s.into_byte_streams(args.shard_count));

// create byte streams for each shard
let submissions = h1_streams
.into_iter()
.zip(h2_streams.into_iter())
.zip(h3_streams.into_iter())
.map(|((s1, s2), s3)| {
[
BodyStream::from_bytes_stream(s1),
BodyStream::from_bytes_stream(s2),
BodyStream::from_bytes_stream(s3),
]
})
.collect::<Vec<_>>();

let query_config = QueryConfig {
size: QuerySize::try_from(count).unwrap(),
field_type: FieldType::Fp32BitPrime,
Expand All @@ -469,6 +553,7 @@ async fn hybrid(
.expect("Unable to create query!");

tracing::info!("Starting query for OPRF");
let submissions = make_inputs_fn(query_id)?;

// the value for histogram values (BA32) must be kept in sync with the server-side
// implementation, otherwise a runtime reconstruct error will be generated.
Expand All @@ -477,7 +562,6 @@ async fn hybrid(
submissions,
count,
helper_clients,
query_id,
hybrid_query_config,
set_fixed_polling_ms,
)
Expand Down
27 changes: 12 additions & 15 deletions ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,8 @@ use tokio::time::sleep;

use crate::{
ff::{Serializable, U128Conversions},
helpers::{
query::{HybridQueryParams, QueryInput, QuerySize},
BodyStream,
},
helpers::query::{HybridQueryParams, QueryInput, QuerySize},
net::{Helper, IpaHttpClient},
protocol::QueryId,
query::QueryStatus,
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
test_fixture::Reconstruct,
Expand All @@ -26,30 +22,31 @@ use crate::{
/// if results are invalid
#[allow(clippy::disallowed_methods)] // allow try_join_all
pub async fn run_hybrid_query_and_validate<HV>(
inputs: Vec<[BodyStream; 3]>,
inputs: Vec<[QueryInput; 3]>,
query_size: usize,
clients: Vec<[IpaHttpClient<Helper>; 3]>,
query_id: QueryId,
query_config: HybridQueryParams,
set_fixed_polling_ms: Option<u64>,
) -> HybridQueryResult
where
HV: SharedValue + U128Conversions,
AdditiveShare<HV>: Serializable,
{
let query_id = inputs
.first()
.map(|v| v[0].query_id())
.expect("At least one shard must be used to run a Hybrid query");
let mpc_time = Instant::now();
assert_eq!(clients.len(), inputs.len());
// submit inputs to each shard
let _ = try_join_all(zip(clients.iter(), inputs.into_iter()).map(
|(shard_clients, shard_inputs)| {
try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map(
|(client, input)| {
client.query_input(QueryInput::Inline {
query_id,
input_stream: input,
})
},
))
try_join_all(
shard_clients
.iter()
.zip(shard_inputs.into_iter())
.map(|(client, input)| client.query_input(input)),
)
},
))
.await
Expand Down

0 comments on commit 76238ac

Please sign in to comment.