diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index fb41fcb9a..1a0d6755c 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -4,13 +4,14 @@ use std::{ fmt::Debug, fs::{File, OpenOptions}, io, - io::{stdout, BufReader, Write}, + io::{stdout, BufRead, BufReader, Write}, + iter::zip, ops::Deref, path::{Path, PathBuf}, }; use clap::{Parser, Subcommand}; -use hyper::http::uri::Scheme; +use hyper::{http::uri::Scheme, Uri}; use ipa_core::{ cli::{ playbook::{ @@ -24,11 +25,13 @@ use ipa_core::{ ff::{boolean_array::BA32, FieldType}, helpers::{ query::{ - DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType, + DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QueryInput, QuerySize, + QueryType, }, BodyStream, }, net::{Helper, IpaHttpClient}, + protocol::QueryId, report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ ipa::{ipa_in_the_clear, CappingOrder, IpaSecurityModel, TestRawDataRecord}, @@ -143,7 +146,14 @@ enum ReportCollectorCommand { }, MaliciousHybrid { #[clap(flatten)] - encrypted_inputs: EncryptedInputs, + encrypted_inputs: Option, + + #[arg( + long, + help = "Read the list of URLs that contain the input from the provided file", + conflicts_with_all = ["enc_input_file1", "enc_input_file2", "enc_input_file3"] + )] + url_file_list: Option, #[clap(flatten)] hybrid_query_config: HybridQueryParams, @@ -267,6 +277,7 @@ async fn main() -> Result<(), Box> { } ReportCollectorCommand::MaliciousHybrid { ref encrypted_inputs, + ref url_file_list, hybrid_query_config, count, set_fixed_polling_ms, @@ -275,7 +286,19 @@ async fn main() -> Result<(), Box> { &args, hybrid_query_config, clients, - encrypted_inputs, + |query_id| { + if let Some(ref url_file_list) = url_file_list { + inputs_from_url_file(url_file_list, query_id, args.shard_count) + } else if let Some(ref encrypted_inputs) = encrypted_inputs { + Ok(inputs_from_encrypted_inputs( + 
encrypted_inputs, + query_id, + args.shard_count, + )) + } else { + panic!("Either --url-file-list or --enc-input-file1, --enc-input-file2, and --enc-input-file3 must be provided"); + } + }, count.try_into().expect("u32 should fit into usize"), set_fixed_polling_ms, ) @@ -286,6 +309,95 @@ async fn main() -> Result<(), Box> { Ok(()) } +fn inputs_from_url_file( + url_file_path: &Path, + query_id: QueryId, + shard_count: usize, +) -> Result, Box> { + let mut file = BufReader::new(File::open(url_file_path)?); + let mut buf = String::new(); + let mut inputs = [Vec::new(), Vec::new(), Vec::new()]; + for helper_input in inputs.iter_mut() { + for _ in 0..shard_count { + buf.clear(); + if file.read_line(&mut buf)? == 0 { + break; + } + helper_input + .push(Uri::try_from(buf.trim()).map_err(|e| format!("Invalid URL {buf:?}: {e}"))?); + } + } + + // make sure all helpers have the expected number of inputs (one per shard) + let all_rows = inputs.iter().map(|v| v.len()).sum::(); + if all_rows != 3 * shard_count { + return Err(format!( + "The number of URLs in {url_file_path:?} '{all_rows}' is less than 3*{shard_count}." + ) + .into()); + } + + let [h1, h2, h3] = inputs; + Ok(zip(zip(h1, h2), h3) + .map(|((h1, h2), h3)| { + [ + QueryInput::FromUrl { + url: h1.to_string(), + query_id, + }, + QueryInput::FromUrl { + url: h2.to_string(), + query_id, + }, + QueryInput::FromUrl { + url: h3.to_string(), + query_id, + }, + ] + }) + .collect()) +} + +fn inputs_from_encrypted_inputs( + encrypted_inputs: &EncryptedInputs, + query_id: QueryId, + shard_count: usize, +) -> Vec<[QueryInput; 3]> { + let [h1_streams, h2_streams, h3_streams] = [ + &encrypted_inputs.enc_input_file1, + &encrypted_inputs.enc_input_file2, + &encrypted_inputs.enc_input_file3, + ] + .map(|path| { + let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. 
{e}")); + RoundRobinSubmission::new(BufReader::new(file)) + }) + .map(|s| s.into_byte_streams(shard_count)); + + // create byte streams for each shard + h1_streams + .into_iter() + .zip(h2_streams) + .zip(h3_streams) + .map(|((s1, s2), s3)| { + [ + QueryInput::Inline { + input_stream: BodyStream::from_bytes_stream(s1), + query_id, + }, + QueryInput::Inline { + input_stream: BodyStream::from_bytes_stream(s2), + query_id, + }, + QueryInput::Inline { + input_stream: BodyStream::from_bytes_stream(s3), + query_id, + }, + ] + }) + .collect::>() +} + fn gen_hybrid_inputs( count: u32, seed: Option, @@ -422,41 +534,16 @@ fn write_hybrid_output_file( Ok(()) } -async fn hybrid( +async fn hybrid Result, Box>>( args: &Args, hybrid_query_config: HybridQueryParams, helper_clients: Vec<[IpaHttpClient; 3]>, - encrypted_inputs: &EncryptedInputs, + make_inputs_fn: F, count: usize, set_fixed_polling_ms: Option, ) -> Result<(), Box> { let query_type = QueryType::MaliciousHybrid(hybrid_query_config); - let [h1_streams, h2_streams, h3_streams] = [ - &encrypted_inputs.enc_input_file1, - &encrypted_inputs.enc_input_file2, - &encrypted_inputs.enc_input_file3, - ] - .map(|path| { - let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. 
{e}")); - RoundRobinSubmission::new(BufReader::new(file)) - }) - .map(|s| s.into_byte_streams(args.shard_count)); - - // create byte streams for each shard - let submissions = h1_streams - .into_iter() - .zip(h2_streams.into_iter()) - .zip(h3_streams.into_iter()) - .map(|((s1, s2), s3)| { - [ - BodyStream::from_bytes_stream(s1), - BodyStream::from_bytes_stream(s2), - BodyStream::from_bytes_stream(s3), - ] - }) - .collect::>(); - let query_config = QueryConfig { size: QuerySize::try_from(count).unwrap(), field_type: FieldType::Fp32BitPrime, @@ -469,6 +556,7 @@ async fn hybrid( .expect("Unable to create query!"); tracing::info!("Starting query for OPRF"); + let submissions = make_inputs_fn(query_id)?; // the value for histogram values (BA32) must be kept in sync with the server-side // implementation, otherwise a runtime reconstruct error will be generated. @@ -477,7 +565,6 @@ async fn hybrid( submissions, count, helper_clients, - query_id, hybrid_query_config, set_fixed_polling_ms, ) diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs index 53bfc6c28..ff6f7d3a4 100644 --- a/ipa-core/src/cli/playbook/hybrid.rs +++ b/ipa-core/src/cli/playbook/hybrid.rs @@ -11,12 +11,8 @@ use tokio::time::sleep; use crate::{ ff::{Serializable, U128Conversions}, - helpers::{ - query::{HybridQueryParams, QueryInput, QuerySize}, - BodyStream, - }, + helpers::query::{HybridQueryParams, QueryInput, QuerySize}, net::{Helper, IpaHttpClient}, - protocol::QueryId, query::QueryStatus, secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, test_fixture::Reconstruct, @@ -26,10 +22,9 @@ use crate::{ /// if results are invalid #[allow(clippy::disallowed_methods)] // allow try_join_all pub async fn run_hybrid_query_and_validate( - inputs: Vec<[BodyStream; 3]>, + inputs: Vec<[QueryInput; 3]>, query_size: usize, clients: Vec<[IpaHttpClient; 3]>, - query_id: QueryId, query_config: HybridQueryParams, set_fixed_polling_ms: Option, ) -> HybridQueryResult 
@@ -37,19 +32,21 @@ where HV: SharedValue + U128Conversions, AdditiveShare: Serializable, { + let query_id = inputs + .first() + .map(|v| v[0].query_id()) + .expect("At least one shard must be used to run a Hybrid query"); let mpc_time = Instant::now(); assert_eq!(clients.len(), inputs.len()); // submit inputs to each shard let _ = try_join_all(zip(clients.iter(), inputs.into_iter()).map( |(shard_clients, shard_inputs)| { - try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map( - |(client, input)| { - client.query_input(QueryInput::Inline { - query_id, - input_stream: input, - }) - }, - )) + try_join_all( + shard_clients + .iter() + .zip(shard_inputs.into_iter()) + .map(|(client, input)| client.query_input(input)), + ) }, )) .await diff --git a/ipa-core/tests/hybrid.rs b/ipa-core/tests/hybrid.rs index b40f524f0..50251f8bc 100644 --- a/ipa-core/tests/hybrid.rs +++ b/ipa-core/tests/hybrid.rs @@ -4,22 +4,34 @@ mod common; use std::{ fs::File, + io::{BufReader, Read, Write}, + iter::once, + net::TcpListener, + os::fd::AsRawFd, + path::{Path, PathBuf}, process::{Command, Stdio}, }; +use bytes::Bytes; +use command_fds::CommandFdExt; use common::{ spawn_shards, tempdir::TempDir, test_sharded_setup, CommandExt, TerminateOnDropExt, UnwrapStatusExt, CRYPTO_UTIL_BIN, TEST_RC_BIN, }; -use ipa_core::{cli::playbook::HybridQueryResult, helpers::query::HybridQueryParams}; +use futures_util::{StreamExt, TryStreamExt}; +use ipa_core::{ + cli::playbook::HybridQueryResult, + error::BoxError, + helpers::{query::HybridQueryParams, LengthDelimitedStream}, +}; use rand::thread_rng; use rand_core::RngCore; use serde_json::from_reader; +use crate::common::TEST_MPC_BIN; + pub const IN_THE_CLEAR_BIN: &str = env!("CARGO_BIN_EXE_in_the_clear"); -// this currently only generates data and runs in the clear -// eventaully we'll want to add the MPC as well #[test] fn test_hybrid() { const INPUT_SIZE: usize = 100; @@ -134,3 +146,212 @@ fn test_hybrid() { .zip(expected_result.iter()) 
.all(|(a, b)| a == b)); } + +#[test] +fn test_hybrid_poll() { + const INPUT_SIZE: usize = 100; + const SHARDS: usize = 5; + const MAX_CONVERSION_VALUE: usize = 5; + + let config = HybridQueryParams { + max_breakdown_key: 5, + with_dp: 0, + epsilon: 0.0, + // only encrypted inputs are supported + plaintext_match_keys: false, + }; + + let dir = TempDir::new_delete_on_drop(); + + // Gen inputs + let input_file = dir.path().join("ipa_inputs.txt"); + let in_the_clear_output_file = dir.path().join("ipa_output_in_the_clear.json"); + let output_file = dir.path().join("ipa_output.json"); + + let mut command = Command::new(TEST_RC_BIN); + command + .args(["--output-file".as_ref(), input_file.as_os_str()]) + .arg("gen-hybrid-inputs") + .args(["--count", &INPUT_SIZE.to_string()]) + .args(["--max-conversion-value", &MAX_CONVERSION_VALUE.to_string()]) + .args(["--max-breakdown-key", &config.max_breakdown_key.to_string()]) + .args(["--seed", &thread_rng().next_u64().to_string()]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); + + let mut command = Command::new(IN_THE_CLEAR_BIN); + command + .args(["--input-file".as_ref(), input_file.as_os_str()]) + .args([ + "--output-file".as_ref(), + in_the_clear_output_file.as_os_str(), + ]) + .silent() + .stdin(Stdio::piped()); + command.status().unwrap_status(); + + let config_path = dir.path().join("config"); + let sockets = test_sharded_setup::(&config_path); + let _helpers = spawn_shards(&config_path, &sockets, true); + + // encrypt input + let mut command = Command::new(CRYPTO_UTIL_BIN); + command + .arg("hybrid-encrypt") + .args(["--input-file".as_ref(), input_file.as_os_str()]) + .args(["--output-dir".as_ref(), dir.path().as_os_str()]) + .args(["--length-delimited"]) + .args(["--network".into(), config_path.join("network.toml")]) + .stdin(Stdio::piped()); + command.status().unwrap_status(); + let enc1 = dir.path().join("helper1.enc"); + let enc2 = dir.path().join("helper2.enc"); + let enc3 = 
dir.path().join("helper3.enc"); + + let poll_port = TcpListener::bind("127.0.0.1:0").unwrap(); + + // split encryption into N shards and create a metadata file that contains + // all files + let upload_metadata = create_upload_files::( + &enc1, + &enc2, + &enc3, + poll_port.local_addr().unwrap().port(), + dir.path(), + ) + .unwrap(); + + // spawn HTTP server to serve the uploaded files + let mut command = Command::new(TEST_MPC_BIN); + command + .arg("serve-input") + .preserved_fds(vec![poll_port.as_raw_fd()]) + .args(["--fd", &poll_port.as_raw_fd().to_string()]) + .args([ + "--dir".as_ref(), + upload_metadata.parent().unwrap().as_os_str(), + ]) + .silent(); + + let _server_handle = command.spawn().unwrap().terminate_on_drop(); + + // Run Hybrid + let mut command = Command::new(TEST_RC_BIN); + command + .args(["--network".into(), config_path.join("network.toml")]) + .args(["--output-file".as_ref(), output_file.as_os_str()]) + .args(["--shard-count", SHARDS.to_string().as_str()]) + .args(["--wait", "2"]) + .arg("malicious-hybrid") + .silent() + .args(["--count", INPUT_SIZE.to_string().as_str()]) + .args(["--url-file-list".into(), upload_metadata]) + .args(["--max-breakdown-key", &config.max_breakdown_key.to_string()]); + + match config.with_dp { + 0 => { + command.args(["--with-dp", &config.with_dp.to_string()]); + } + _ => { + command + .args(["--with-dp", &config.with_dp.to_string()]) + .args(["--epsilon", &config.epsilon.to_string()]); + } + } + command.stdin(Stdio::piped()); + + let test_mpc = command.spawn().unwrap().terminate_on_drop(); + test_mpc.wait().unwrap_status(); + + // basic output checks - output should have the exact size as number of breakdowns + let output = serde_json::from_str::( + &std::fs::read_to_string(&output_file).expect("IPA results file should exist"), + ) + .expect("IPA results file is valid JSON"); + + assert_eq!( + usize::try_from(config.max_breakdown_key).unwrap(), + output.breakdowns.len(), + "Number of breakdowns does not match the 
expected", + ); + assert_eq!(INPUT_SIZE, usize::from(output.input_size)); + + let expected_result: Vec = from_reader( + File::open(in_the_clear_output_file) + .expect("file should exist as it's created above in the test"), + ) + .expect("should match hard coded format from in_the_clear"); + assert!(output + .breakdowns + .iter() + .zip(expected_result.iter()) + .all(|(a, b)| a == b)); +} + +fn create_upload_files( + enc_file1: &Path, + enc_file2: &Path, + enc_file3: &Path, + port: u16, + dest: &Path, +) -> Result { + let manifest_path = dest.join("manifest.txt"); + let mut manifest_file = File::create_new(&manifest_path)?; + create_upload_file::("h1", enc_file1, port, dest, &mut manifest_file)?; + create_upload_file::("h2", enc_file2, port, dest, &mut manifest_file)?; + create_upload_file::("h3", enc_file3, port, dest, &mut manifest_file)?; + + manifest_file.flush()?; + + Ok(manifest_path) +} + +fn create_upload_file( + prefix: &str, + enc_file: &Path, + port: u16, + dest_dir: &Path, + metadata_file: &mut File, +) -> Result<(), BoxError> { + let mut files = (0..SHARDS) + .map(|i| { + let path = dest_dir.join(format!("{prefix}_shard_{i}.enc")); + let file = File::create_new(&path)?; + Ok((path, file)) + }) + .collect::>>()?; + + // we assume files are tiny for the integration tests + let mut input = BufReader::new(File::open(enc_file)?); + let mut buf = Vec::new(); + if input.read_to_end(&mut buf)? 
== 0 { + panic!("{:?} file is empty", enc_file); + } + + // read length delimited data and write it to each file + let stream = + LengthDelimitedStream::::new(futures::stream::iter(once(Ok::<_, BoxError>( + buf.into(), + )))) + .map_ok(|v| futures::stream::iter(v).map(Ok::<_, BoxError>)) + .try_flatten(); + + for (i, next_bytes) in futures::executor::block_on_stream(stream).enumerate() { + let next_bytes = next_bytes?; + let file = &mut files[i % SHARDS].1; + let len = u16::try_from(next_bytes.len()) + .map_err(|_| format!("record is too big: {} > 65535", next_bytes.len()))?; + file.write(&len.to_le_bytes())?; + file.write_all(&next_bytes)?; + } + + // update manifest file + for (path, mut file) in files { + file.flush()?; + let path = path.file_name().and_then(|p| p.to_str()).unwrap(); + writeln!(metadata_file, "http://localhost:{port}/{path}")?; + } + + Ok(()) +}