From b27a4b304d316a9732d6bed470814733a1b279d6 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sat, 7 Dec 2024 11:24:26 -0800 Subject: [PATCH] Improve performance of hybrid encrypt CLI (#1483) * Improve performance of hybrid encrypt CLI. There were quite a few bottlenecks here: * Writes were done serially, writing one file at a time. * Shares were encrypted on a single CPU core. I almost used `rayon` to parallelize encryption, but the problem is that we need to get the output sorted to maintain total order across files. Rayon can do that, but requires collecting `ParallelIterator` which would be bad for generating 100M+ reports. Our goal is to be able to share and encrypt 1B reports, so streaming and manual fiddling with thread pools is justified imo. The way this CLI works right now: it keeps a compute pool for encryption (thread-per-core) and a separate pool of 3 threads to write data for each helper in parallel. I also made a few tweaks to improve code re-usability in this module. ## Benchmarks Done locally on M1 Mac Pro (10 cores). Before this change: Encryption process is completed. 442.15834075s After this change: Encryption process is completed. 
55.63269625s * Feedback --- ipa-core/Cargo.toml | 5 +- ipa-core/src/bin/crypto_util.rs | 10 +- ipa-core/src/bin/report_collector.rs | 4 +- ipa-core/src/cli/crypto/encrypt.rs | 4 +- ipa-core/src/cli/crypto/hybrid_encrypt.rs | 231 ++++++++++++++++++++-- ipa-core/src/config.rs | 14 +- ipa-core/src/hpke/registry.rs | 9 + ipa-core/src/report/hybrid.rs | 13 +- 8 files changed, 249 insertions(+), 41 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 6c8df1f33..496699e34 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -17,13 +17,11 @@ default = [ # by default remove all TRACE, DEBUG spans from release builds "tracing/max_level_trace", "tracing/release_max_level_info", - "aggregate-circuit", "stall-detection", - "aggregate-circuit", "ipa-prf", "descriptive-gate", ] -cli = ["comfy-table", "clap"] +cli = ["comfy-table", "clap", "num_cpus"] # Enable compact gate optimization compact-gate = [] # mutually exclusive with compact-gate and disables compact gate optimization. @@ -130,6 +128,7 @@ hyper-util = { version = "0.1.3", optional = true, features = ["http2"] } http-body-util = { version = "0.1.1", optional = true } http-body = { version = "1", optional = true } iai = { version = "0.1.1", optional = true } +num_cpus = { version = "1.0", optional = true } once_cell = "1.18" pin-project = "1.0" rand = "0.8" diff --git a/ipa-core/src/bin/crypto_util.rs b/ipa-core/src/bin/crypto_util.rs index c884560f5..8d0de6870 100644 --- a/ipa-core/src/bin/crypto_util.rs +++ b/ipa-core/src/bin/crypto_util.rs @@ -2,7 +2,10 @@ use std::fmt::Debug; use clap::{Parser, Subcommand}; use ipa_core::{ - cli::crypto::{DecryptArgs, EncryptArgs, HybridDecryptArgs, HybridEncryptArgs}, + cli::{ + crypto::{DecryptArgs, EncryptArgs, HybridDecryptArgs, HybridEncryptArgs}, + Verbosity, + }, error::BoxError, }; @@ -10,6 +13,10 @@ use ipa_core::{ #[clap(name = "crypto-util", about = "Crypto Util CLI")] #[command(about)] struct Args { + // Configure logging. 
+ #[clap(flatten)] + logging: Verbosity, + #[command(subcommand)] action: CryptoUtilCommand, } @@ -25,6 +32,7 @@ enum CryptoUtilCommand { #[tokio::main] async fn main() -> Result<(), BoxError> { let args = Args::parse(); + let _handle = args.logging.setup_logging(); match args.action { CryptoUtilCommand::Encrypt(encrypt_args) => encrypt_args.encrypt()?, CryptoUtilCommand::HybridEncrypt(hybrid_encrypt_args) => hybrid_encrypt_args.encrypt()?, diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 9866fb62a..861e2f0c5 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -573,7 +573,7 @@ async fn ipa_test( r }; - let mut key_registries = KeyRegistries::default(); + let key_registries = KeyRegistries::default(); let Some(key_registries) = key_registries.init_from(network) else { panic!("could not load network file") }; @@ -585,7 +585,7 @@ async fn ipa_test( helper_clients, query_id, ipa_query_config, - Some((DEFAULT_KEY_ID, key_registries)), + Some((DEFAULT_KEY_ID, key_registries.each_ref())), ) .await; diff --git a/ipa-core/src/cli/crypto/encrypt.rs b/ipa-core/src/cli/crypto/encrypt.rs index 1b29b4dcf..9a63499d3 100644 --- a/ipa-core/src/cli/crypto/encrypt.rs +++ b/ipa-core/src/cli/crypto/encrypt.rs @@ -55,7 +55,7 @@ impl EncryptArgs { let input = InputSource::from_file(&self.input_file); let mut rng = thread_rng(); - let mut key_registries = KeyRegistries::default(); + let key_registries = KeyRegistries::default(); let network = NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| { @@ -84,7 +84,7 @@ impl EncryptArgs { for share in shares { let output = share - .encrypt(DEFAULT_KEY_ID, key_registry, &mut rng) + .encrypt(DEFAULT_KEY_ID, &key_registry, &mut rng) .unwrap(); let hex_output = hex::encode(&output); writeln!(writer, "{hex_output}")?; diff --git a/ipa-core/src/cli/crypto/hybrid_encrypt.rs b/ipa-core/src/cli/crypto/hybrid_encrypt.rs index 
7db024681..e7d903e20 100644 --- a/ipa-core/src/cli/crypto/hybrid_encrypt.rs +++ b/ipa-core/src/cli/crypto/hybrid_encrypt.rs @@ -1,8 +1,13 @@ use std::{ - fs::{read_to_string, OpenOptions}, - io::Write, - iter::zip, + array, + collections::BTreeMap, + fs::{read_to_string, File, OpenOptions}, + io::{BufWriter, Write}, path::{Path, PathBuf}, + sync::mpsc::{channel, Sender}, + thread, + thread::JoinHandle, + time::Instant, }; use clap::Parser; @@ -15,11 +20,23 @@ use crate::{ }, config::{KeyRegistries, NetworkConfig}, error::BoxError, + hpke::{KeyRegistry, PublicKeyOnly}, report::hybrid::{HybridReport, DEFAULT_KEY_ID}, secret_sharing::IntoShares, test_fixture::hybrid::TestHybridRecord, }; +/// Encryptor takes 3 arguments: `report_id`, helper that the shares must be encrypted towards +/// and the actual share ([`HybridReport`]) to encrypt. +type EncryptorInput = (usize, usize, HybridReport); +/// Encryptor sends report id and encrypted bytes down to file worker to write those bytes +/// down +type EncryptorOutput = (usize, Vec); +type FileWorkerInput = EncryptorOutput; + +/// This type is used quite often in this module +type UnitResult = Result<(), BoxError>; + #[derive(Debug, Parser)] #[clap(name = "test_hybrid_encrypt", about = "Test Hybrid Encrypt")] #[command(about)] @@ -51,11 +68,12 @@ impl HybridEncryptArgs { /// if input file or network file are not correctly formatted /// # Errors /// if it cannot open the files - pub fn encrypt(&self) -> Result<(), BoxError> { + pub fn encrypt(&self) -> UnitResult { + tracing::info!("encrypting input from {:?}", self.input_file); + let start = Instant::now(); let input = InputSource::from_file(&self.input_file); - let mut rng = thread_rng(); - let mut key_registries = KeyRegistries::default(); + let key_registries = KeyRegistries::default(); let network = NetworkConfig::from_toml_str_sharded(&read_to_string(&self.network).unwrap_or_else( @@ -71,28 +89,199 @@ impl HybridEncryptArgs { panic!("could not load network file") }; - 
let shares: [Vec>; 3] = - input.iter::().share(); + let mut worker_pool = ReportWriter::new(key_registries, &self.output_dir); + for (report_id, record) in input.iter::().enumerate() { + worker_pool.submit(report_id, record.share())?; + } + + worker_pool.join()?; + + let elapsed = start.elapsed(); + tracing::info!( + "Encryption process is completed. {}s", + elapsed.as_secs_f64() + ); + + Ok(()) + } +} - for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() { - let output_filename = format!("helper{}.enc", index + 1); - let mut writer = OpenOptions::new() +/// A thread-per-core pool responsible for encrypting reports in parallel. +/// This pool is shared across all writers to reduce the number of context switches. +struct EncryptorPool { + pool: Vec<(Sender, JoinHandle)>, + next_worker: usize, +} + +impl EncryptorPool { + pub fn with_worker_threads( + thread_count: usize, + file_writer: [Sender; 3], + key_registries: [KeyRegistry; 3], + ) -> Self { + Self { + pool: (0..thread_count) + .map(move |i| { + let (tx, rx) = channel::(); + let key_registries = key_registries.clone(); + let file_writer = file_writer.clone(); + ( + tx, + std::thread::Builder::new() + .name(format!("encryptor-{i}")) + .spawn(move || { + for (i, helper_id, report) in rx { + let key_registry = &key_registries[helper_id]; + let output = report.encrypt( + DEFAULT_KEY_ID, + key_registry, + &mut thread_rng(), + )?; + file_writer[helper_id].send((i, output))?; + } + + Ok(()) + }) + .unwrap(), + ) + }) + .collect(), + next_worker: 0, + } + } + + pub fn encrypt_share(&mut self, report: EncryptorInput) -> UnitResult { + let tx = &self.pool[self.next_worker].0; + tx.send(report)?; + self.next_worker = (self.next_worker + 1) % self.pool.len(); + + Ok(()) + } + + pub fn stop(self) -> UnitResult { + for (tx, handle) in self.pool { + drop(tx); + handle.join().unwrap()?; + } + + Ok(()) + } +} + +/// Performs end-to-end encryption, taking individual shares as input +/// (see 
[`ReportWriter::submit`]), encrypting them in parallel and writing +/// encrypted shares into 3 separate files. This optimizes for memory usage, +/// and maximizes CPU utilization. +struct ReportWriter { + encryptor_pool: EncryptorPool, + workers: Option<[FileWriteWorker; 3]>, +} + +impl ReportWriter { + pub fn new(key_registries: [KeyRegistry; 3], output_dir: &Path) -> Self { + // create 3 worker threads to write data into 3 files + let workers = array::from_fn(|i| { + let output_filename = format!("helper{}.enc", i + 1); + let file = OpenOptions::new() .write(true) .create_new(true) - .open(self.output_dir.join(&output_filename)) - .unwrap_or_else(|e| panic!("unable write to {}. {}", &output_filename, e)); - - for share in shares { - let output = share - .encrypt(DEFAULT_KEY_ID, key_registry, &mut rng) - .unwrap(); - let hex_output = hex::encode(&output); - writeln!(writer, "{hex_output}")?; - } + .open(output_dir.join(&output_filename)) + .unwrap_or_else(|e| panic!("unable write to {:?}. {}", &output_filename, e)); + + FileWriteWorker::new(file) + }); + let encryptor_pool = EncryptorPool::with_worker_threads( + num_cpus::get(), + workers.each_ref().map(|x| x.sender.clone()), + key_registries, + ); + + Self { + encryptor_pool, + workers: Some(workers), + } + } + + pub fn submit( + &mut self, + report_id: usize, + shares: [HybridReport; 3], + ) -> UnitResult { + for (i, share) in shares.into_iter().enumerate() { + self.encryptor_pool.encrypt_share((report_id, i, share))?; } Ok(()) } + + pub fn join(mut self) -> UnitResult { + self.encryptor_pool.stop()?; + self.workers + .take() + .unwrap() + .map(|worker| { + let FileWriteWorker { handle, sender } = worker; + drop(sender); + handle.join().unwrap() + }) + .into_iter() + .collect() + } +} + +/// This takes a file and writes all encrypted reports to it, +/// ensuring the same total order based on `report_id`. 
Report id is +/// just the index of file input row that guarantees consistency +/// of shares written across 3 files +struct FileWriteWorker { + sender: Sender, + handle: JoinHandle, +} + +impl FileWriteWorker { + pub fn new(file: File) -> Self { + let (tx, rx) = std::sync::mpsc::channel(); + Self { + sender: tx, + handle: thread::spawn(move || { + fn write_report(writer: &mut W, report: &[u8]) -> Result<(), BoxError> { + let hex_output = hex::encode(report); + writeln!(writer, "{hex_output}")?; + Ok(()) + } + + // write low watermark. All reports below this line have been written + let mut lw = 0; + let mut pending_reports = BTreeMap::new(); + + // Buffered writes should improve IO, but it is likely not the bottleneck here. + let mut writer = BufWriter::new(file); + for (report_id, report) in rx { + // Because reports are encrypted in parallel, it is possible + // to receive report_id = X+1 before X. To mitigate that, we keep + // a buffer, ordered by report_id and always write from low watermark. + // This ensures consistent order of reports written to files. Any misalignment + // will result in broken shares and garbage output. 
+ assert!( + report_id >= lw, + "Internal error: received a report {report_id} below low watermark" + ); + assert!( + pending_reports.insert(report_id, report).is_none(), + "Internal error: received a duplicate report {report_id}" + ); + while let Some(report) = pending_reports.remove(&lw) { + write_report(&mut writer, &report)?; + lw += 1; + if lw % 1_000_000 == 0 { + tracing::info!("Encrypted {}M reports", lw / 1_000_000); + } + } + } + Ok(()) + }), + } + } } #[cfg(all(test, unit_test))] diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 132f80a64..2060320ac 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -466,10 +466,11 @@ pub struct KeyRegistries(Vec>); impl KeyRegistries { /// # Panics /// If network file is improperly formatted + #[must_use] pub fn init_from( - &mut self, + mut self, network: &NetworkConfig, - ) -> Option<[&KeyRegistry; 3]> { + ) -> Option<[KeyRegistry; 3]> { // Get the configs, if all three peers have one let peers = network.peers(); let configs = peers.iter().try_fold(Vec::new(), |acc, peer| { @@ -487,7 +488,14 @@ impl KeyRegistries { .map(|hpke| KeyRegistry::from_keys([PublicKeyOnly(hpke.public_key.clone())])) .collect::>>(); - Some(self.0.iter().collect::>().try_into().ok().unwrap()) + Some( + self.0 + .into_iter() + .collect::>() + .try_into() + .ok() + .unwrap(), + ) } } diff --git a/ipa-core/src/hpke/registry.rs b/ipa-core/src/hpke/registry.rs index 283d1fbd0..2811ca439 100644 --- a/ipa-core/src/hpke/registry.rs +++ b/ipa-core/src/hpke/registry.rs @@ -48,6 +48,7 @@ impl KeyPair { // The coherence rules prohibit us from implementing `PublicKeyRegistry` both for our concrete type // `KeyPair` and for `IpaPublicKey`, because the impls would overlap if hpke chose to define // `IpaPublicKey` to be the same as `KeyPair`. 
+#[derive(Clone)] pub struct PublicKeyOnly(pub IpaPublicKey); impl Deref for PublicKeyOnly { @@ -85,6 +86,14 @@ pub struct KeyRegistry { keys: Box<[K]>, } +impl Clone for KeyRegistry { + fn clone(&self) -> Self { + Self { + keys: self.keys.clone(), + } + } +} + impl KeyRegistry { /// Create a key registry with no keys. Since the registry is immutable, it is useless, /// but this avoids `Option` when the registry is ultimately not optional. diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index 6b01e3bde..eeada7823 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -27,7 +27,7 @@ //! all secret sharings (including the sharings of zero), making the collection of reports //! cryptographically indistinguishable. -use std::{collections::HashSet, convert::Infallible, iter::once, marker::PhantomData, ops::Add}; +use std::{collections::HashSet, convert::Infallible, marker::PhantomData, ops::Add}; use bytes::{Buf, BufMut, Bytes}; use generic_array::{ArrayLength, GenericArray}; @@ -477,14 +477,9 @@ where key_registry: &impl PublicKeyRegistry, rng: &mut R, ) -> Result, InvalidHybridReportError> { - match self { - HybridReport::Impression(impression_report) => { - impression_report.encrypt(key_id, key_registry, rng).map(|v| once(HybridEventType::Impression as u8).chain(v).collect()) - }, - HybridReport::Conversion(conversion_report) => { - conversion_report.encrypt(key_id, key_registry, rng).map(|v| once(HybridEventType::Conversion as u8).chain(v).collect()) - }, - } + let mut buf = Vec::new(); + self.encrypt_to(key_id, key_registry, rng, &mut buf)?; + Ok(buf) } /// # Errors