diff --git a/benches/oneshot/ipa.rs b/benches/oneshot/ipa.rs index 204624f90..7760793b9 100644 --- a/benches/oneshot/ipa.rs +++ b/benches/oneshot/ipa.rs @@ -1,3 +1,4 @@ +use clap::Parser; use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; use raw_ipa::{ error::Error, @@ -10,18 +11,41 @@ use raw_ipa::{ }; use std::time::Instant; +#[derive(Debug, Parser)] +pub struct BenchArgs { + #[arg( + long, + short = 'i', + help = "Input size, in records", + default_value = "100" + )] + pub query_size: usize, + + #[arg(value_enum, long, short = 'm', help = "Run malicious or semi-honest IPA", default_value_t = IpaSecurityModel::Malicious)] + pub mode: IpaSecurityModel, + + #[arg(long, help = "Capped user contributions", default_value = "1")] + pub user_cap: u32, + + #[arg(short, long, help = "ignored")] + pub bench: bool, +} + #[tokio::main(flavor = "multi_thread", worker_threads = 3)] async fn main() -> Result<(), Error> { const MAX_BREAKDOWN_KEY: u32 = 16; const MAX_TRIGGER_VALUE: u32 = 5; - const QUERY_SIZE: usize = 100; const MAX_RECORDS_PER_USER: usize = 10; const ATTRIBUTION_WINDOW_SECONDS: u32 = 0; type BenchField = Fp32BitPrime; + let args = BenchArgs::parse(); + let query_size = args.query_size; + let mode = args.mode; + let mut config = TestWorldConfig::default(); config.gateway_config = - GatewayConfig::symmetric_buffers::(QUERY_SIZE.clamp(16, 1024)); + GatewayConfig::symmetric_buffers::(query_size.clamp(16, 1024)); let random_seed = thread_rng().gen(); println!("Using random seed: {random_seed}"); @@ -29,15 +53,15 @@ async fn main() -> Result<(), Error> { let mut total_count = 0; - let mut random_user_records = Vec::with_capacity(QUERY_SIZE / MAX_RECORDS_PER_USER); - while total_count < QUERY_SIZE { + let mut random_user_records = Vec::with_capacity(query_size / MAX_RECORDS_PER_USER); + while total_count < query_size { let mut records_for_user = generate_random_user_records_in_reverse_chronological_order( &mut rng, MAX_RECORDS_PER_USER, MAX_BREAKDOWN_KEY, MAX_TRIGGER_VALUE, ); - records_for_user.truncate(QUERY_SIZE - total_count); + records_for_user.truncate(query_size - total_count); total_count += records_for_user.len(); random_user_records.push(records_for_user); @@ -49,8 +73,7 @@ async fn main() -> Result<(), Error> { // This is part of the IPA spec. Callers should do this before sending a batch of records in for processing. raw_data.sort_unstable_by(|a, b| a.timestamp.cmp(&b.timestamp)); - let per_user_cap = 3; - // for per_user_cap in [1, 3] { + let per_user_cap = args.user_cap; let mut expected_results = vec![0_u32; MAX_BREAKDOWN_KEY.try_into().unwrap()]; for records_for_user in &random_user_records { @@ -70,11 +93,13 @@ async fn main() -> Result<(), Error> { per_user_cap, MAX_BREAKDOWN_KEY, ATTRIBUTION_WINDOW_SECONDS, - IpaSecurityModel::Malicious, + mode, ) .await; - // } + let duration = start.elapsed().as_secs_f32(); - println!("IPA benchmark for QUERY_SIZE {QUERY_SIZE} complete successfully after {duration}s"); + println!( + "{mode:?} IPA benchmark for {query_size} records complete successfully after {duration}s" + ); Ok(()) } diff --git a/src/test_fixture/ipa.rs b/src/test_fixture/ipa.rs index faca35c0c..00c94a011 100644 --- a/src/test_fixture/ipa.rs +++ b/src/test_fixture/ipa.rs @@ -19,6 +19,8 @@ use crate::{ use super::TestWorld; +#[derive(Debug, Copy, Clone)] +#[cfg_attr(feature = "cli", derive(clap::ValueEnum))] pub enum IpaSecurityModel { SemiHonest, Malicious,