Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of hybrid encrypt CLI #1483

Merged
merged 2 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions ipa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ default = [
# by default remove all TRACE, DEBUG spans from release builds
"tracing/max_level_trace",
"tracing/release_max_level_info",
"aggregate-circuit",
"stall-detection",
"aggregate-circuit",
"ipa-prf",
"descriptive-gate",
]
cli = ["comfy-table", "clap"]
cli = ["comfy-table", "clap", "num_cpus"]
# Enable compact gate optimization
compact-gate = []
# mutually exclusive with compact-gate and disables compact gate optimization.
Expand Down Expand Up @@ -130,6 +128,7 @@ hyper-util = { version = "0.1.3", optional = true, features = ["http2"] }
http-body-util = { version = "0.1.1", optional = true }
http-body = { version = "1", optional = true }
iai = { version = "0.1.1", optional = true }
num_cpus = { version = "1.0", optional = true }
once_cell = "1.18"
pin-project = "1.0"
rand = "0.8"
Expand Down
10 changes: 9 additions & 1 deletion ipa-core/src/bin/crypto_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@ use std::fmt::Debug;

use clap::{Parser, Subcommand};
use ipa_core::{
cli::crypto::{DecryptArgs, EncryptArgs, HybridDecryptArgs, HybridEncryptArgs},
cli::{
crypto::{DecryptArgs, EncryptArgs, HybridDecryptArgs, HybridEncryptArgs},
Verbosity,
},
error::BoxError,
};

#[derive(Debug, Parser)]
#[clap(name = "crypto-util", about = "Crypto Util CLI")]
#[command(about)]
struct Args {
// Configure logging.
#[clap(flatten)]
logging: Verbosity,

#[command(subcommand)]
action: CryptoUtilCommand,
}
Expand All @@ -25,6 +32,7 @@ enum CryptoUtilCommand {
#[tokio::main]
async fn main() -> Result<(), BoxError> {
let args = Args::parse();
let _handle = args.logging.setup_logging();
match args.action {
CryptoUtilCommand::Encrypt(encrypt_args) => encrypt_args.encrypt()?,
CryptoUtilCommand::HybridEncrypt(hybrid_encrypt_args) => hybrid_encrypt_args.encrypt()?,
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ async fn ipa_test(
r
};

let mut key_registries = KeyRegistries::default();
let key_registries = KeyRegistries::default();
let Some(key_registries) = key_registries.init_from(network) else {
eriktaubeneck marked this conversation as resolved.
Show resolved Hide resolved
panic!("could not load network file")
};
Expand All @@ -546,7 +546,7 @@ async fn ipa_test(
helper_clients,
query_id,
ipa_query_config,
Some((DEFAULT_KEY_ID, key_registries)),
Some((DEFAULT_KEY_ID, key_registries.each_ref())),
)
.await;

Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/cli/crypto/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl EncryptArgs {
let input = InputSource::from_file(&self.input_file);

let mut rng = thread_rng();
let mut key_registries = KeyRegistries::default();
let key_registries = KeyRegistries::default();

let network =
NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| {
Expand Down Expand Up @@ -84,7 +84,7 @@ impl EncryptArgs {

for share in shares {
let output = share
.encrypt(DEFAULT_KEY_ID, key_registry, &mut rng)
.encrypt(DEFAULT_KEY_ID, &key_registry, &mut rng)
.unwrap();
let hex_output = hex::encode(&output);
writeln!(writer, "{hex_output}")?;
Expand Down
231 changes: 210 additions & 21 deletions ipa-core/src/cli/crypto/hybrid_encrypt.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use std::{
fs::{read_to_string, OpenOptions},
io::Write,
iter::zip,
array,
collections::BTreeMap,
fs::{read_to_string, File, OpenOptions},
io::{BufWriter, Write},
path::{Path, PathBuf},
sync::mpsc::{channel, Sender},
thread,
thread::JoinHandle,
time::Instant,
};

use clap::Parser;
Expand All @@ -15,11 +20,23 @@
},
config::{KeyRegistries, NetworkConfig},
error::BoxError,
hpke::{KeyRegistry, PublicKeyOnly},
report::hybrid::{HybridReport, DEFAULT_KEY_ID},
secret_sharing::IntoShares,
test_fixture::hybrid::TestHybridRecord,
};

/// Encryptor takes 3 arguments: `report_id`, helper that the shares must be encrypted towards
/// and the actual share to encrypt.
eriktaubeneck marked this conversation as resolved.
Show resolved Hide resolved
type EncryptorInput = (usize, usize, HybridReport<BreakdownKey, TriggerValue>);
/// Encryptor sends report id and encrypted bytes down to file worker to write those bytes
/// down
type EncryptorOutput = (usize, Vec<u8>);
type FileWorkerInput = EncryptorOutput;

/// This type is used quite often in this module
type UnitResult = Result<(), BoxError>;

#[derive(Debug, Parser)]
#[clap(name = "test_hybrid_encrypt", about = "Test Hybrid Encrypt")]
#[command(about)]
Expand Down Expand Up @@ -51,11 +68,12 @@
/// if input file or network file are not correctly formatted
/// # Errors
/// if it cannot open the files
pub fn encrypt(&self) -> Result<(), BoxError> {
pub fn encrypt(&self) -> UnitResult {
tracing::info!("encrypting input from {:?}", self.input_file);
let start = Instant::now();
let input = InputSource::from_file(&self.input_file);

let mut rng = thread_rng();
let mut key_registries = KeyRegistries::default();
let key_registries = KeyRegistries::default();

let network =
NetworkConfig::from_toml_str(&read_to_string(&self.network).unwrap_or_else(|e| {
Expand All @@ -71,28 +89,199 @@
panic!("could not load network file")
};

let shares: [Vec<HybridReport<BreakdownKey, TriggerValue>>; 3] =
input.iter::<TestHybridRecord>().share();
let mut worker_pool = ReportWriter::new(key_registries, &self.output_dir);
for (report_id, record) in input.iter::<TestHybridRecord>().enumerate() {
worker_pool.submit(report_id, record.share())?;
}

worker_pool.join()?;

let elapsed = start.elapsed();
tracing::info!(
"Encryption process is completed. {}s",
elapsed.as_secs_f64()

Check warning on line 102 in ipa-core/src/cli/crypto/hybrid_encrypt.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/crypto/hybrid_encrypt.rs#L101-L102

Added lines #L101 - L102 were not covered by tests
);

Ok(())
}
}

for (index, (shares, key_registry)) in zip(shares, key_registries).enumerate() {
let output_filename = format!("helper{}.enc", index + 1);
let mut writer = OpenOptions::new()
/// A thread-per-core pool responsible for encrypting reports in parallel.
/// This pool is shared across all writers to reduce the number of context switches.
struct EncryptorPool {
pool: Vec<(Sender<EncryptorInput>, JoinHandle<UnitResult>)>,
next_worker: usize,
}

impl EncryptorPool {
pub fn with_worker_threads(
thread_count: usize,
file_writer: [Sender<EncryptorOutput>; 3],
key_registries: [KeyRegistry<PublicKeyOnly>; 3],
) -> Self {
Self {
pool: (0..thread_count)
.map(move |i| {
let (tx, rx) = channel::<EncryptorInput>();
eriktaubeneck marked this conversation as resolved.
Show resolved Hide resolved
let key_registries = key_registries.clone();
let file_writer = file_writer.clone();
(
tx,
std::thread::Builder::new()
.name(format!("encryptor-{i}"))
.spawn(move || {
for (i, helper_id, report) in rx {
let key_registry = &key_registries[helper_id];
let output = report.encrypt(
DEFAULT_KEY_ID,
key_registry,
&mut thread_rng(),
)?;
file_writer[helper_id].send((i, output))?;
}

Ok(())
})
.unwrap(),
)
})
.collect(),
next_worker: 0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is next_worker always 0? what's the point of it if it's constant?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mutation occurs inside encrypt_share

}
}

pub fn encrypt_share(&mut self, report: EncryptorInput) -> UnitResult {
let tx = &self.pool[self.next_worker].0;
tx.send(report)?;
self.next_worker = (self.next_worker + 1) % self.pool.len();

Ok(())
}

pub fn stop(self) -> UnitResult {
for (tx, handle) in self.pool {
drop(tx);
handle.join().unwrap()?;
}

Ok(())
}
}

/// Performs end-to-end encryption, taking individual shares as input
/// (see [`ReportWriter::submit`]), encrypting them in parallel and writing
/// encrypted shares into 3 separate files. This optimizes for memory usage,
/// and maximizes CPU utilization.
struct ReportWriter {
encryptor_pool: EncryptorPool,
workers: Option<[FileWriteWorker; 3]>,
}

impl ReportWriter {
pub fn new(key_registries: [KeyRegistry<PublicKeyOnly>; 3], output_dir: &Path) -> Self {
// create 3 worker threads to write data into 3 files
let workers = array::from_fn(|i| {
let output_filename = format!("helper{}.enc", i + 1);
let file = OpenOptions::new()
.write(true)
.create_new(true)
.open(self.output_dir.join(&output_filename))
.unwrap_or_else(|e| panic!("unable write to {}. {}", &output_filename, e));

for share in shares {
let output = share
.encrypt(DEFAULT_KEY_ID, key_registry, &mut rng)
.unwrap();
let hex_output = hex::encode(&output);
writeln!(writer, "{hex_output}")?;
}
.open(output_dir.join(&output_filename))
.unwrap_or_else(|e| panic!("unable write to {:?}. {}", &output_filename, e));

FileWriteWorker::new(file)
});
let encryptor_pool = EncryptorPool::with_worker_threads(
num_cpus::get(),
workers.each_ref().map(|x| x.sender.clone()),
key_registries,
);

Self {
encryptor_pool,
workers: Some(workers),
}
}

pub fn submit(
&mut self,
report_id: usize,
shares: [HybridReport<BreakdownKey, TriggerValue>; 3],
) -> UnitResult {
for (i, share) in shares.into_iter().enumerate() {
self.encryptor_pool.encrypt_share((report_id, i, share))?;
}

Ok(())
}

pub fn join(mut self) -> UnitResult {
self.encryptor_pool.stop()?;
self.workers
.take()
.unwrap()
.map(|worker| {
let FileWriteWorker { handle, sender } = worker;
drop(sender);
handle.join().unwrap()
})
.into_iter()
.collect()
}
}

/// This takes a file and writes all encrypted reports to it,
/// ensuring the same total order based on `report_id`. Report id is
/// just the index of file input row that guarantees consistency
/// of shares written across 3 files
struct FileWriteWorker {
sender: Sender<FileWorkerInput>,
handle: JoinHandle<UnitResult>,
}

impl FileWriteWorker {
pub fn new(file: File) -> Self {
let (tx, rx) = std::sync::mpsc::channel();
Self {
sender: tx,
handle: thread::spawn(move || {
fn write_report<W: Write>(writer: &mut W, report: &[u8]) -> Result<(), BoxError> {
let hex_output = hex::encode(report);
writeln!(writer, "{hex_output}")?;
Ok(())
}

// write low watermark. All reports below this line have been written
let mut lw = 0;
let mut pending_reports = BTreeMap::new();

// Buffered writes should improve IO, but it is likely not the bottleneck here.
let mut writer = BufWriter::new(file);
for (report_id, report) in rx {
eriktaubeneck marked this conversation as resolved.
Show resolved Hide resolved
// Because reports are encrypted in parallel, it is possible
// to receive report_id = X+1 before X. To mitigate that, we keep
// a buffer, ordered by report_id and always write from low watermark.
// This ensures consistent order of reports written to files. Any misalignment
// will result in broken shares and garbage output.
assert!(
report_id >= lw,
"Internal error: received a report {report_id} below low watermark"

Check warning on line 267 in ipa-core/src/cli/crypto/hybrid_encrypt.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/crypto/hybrid_encrypt.rs#L267

Added line #L267 was not covered by tests
);
assert!(
pending_reports.insert(report_id, report).is_none(),
"Internal error: received a duplicate report {report_id}"

Check warning on line 271 in ipa-core/src/cli/crypto/hybrid_encrypt.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/crypto/hybrid_encrypt.rs#L271

Added line #L271 was not covered by tests
);
while let Some(report) = pending_reports.remove(&lw) {
write_report(&mut writer, &report)?;
lw += 1;
if lw % 1_000_000 == 0 {
tracing::info!("Encrypted {}M reports", lw / 1_000_000);

Check warning on line 277 in ipa-core/src/cli/crypto/hybrid_encrypt.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/crypto/hybrid_encrypt.rs#L277

Added line #L277 was not covered by tests
}
}
}
Ok(())
}),
}
}
}

#[cfg(all(test, unit_test))]
Expand Down
14 changes: 11 additions & 3 deletions ipa-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,11 @@ pub struct KeyRegistries(Vec<KeyRegistry<PublicKeyOnly>>);
impl KeyRegistries {
/// # Panics
/// If network file is improperly formatted
#[must_use]
pub fn init_from(
&mut self,
mut self,
network: &NetworkConfig<Helper>,
) -> Option<[&KeyRegistry<PublicKeyOnly>; 3]> {
) -> Option<[KeyRegistry<PublicKeyOnly>; 3]> {
// Get the configs, if all three peers have one
let peers = network.peers();
let configs = peers.iter().try_fold(Vec::new(), |acc, peer| {
Expand All @@ -487,7 +488,14 @@ impl KeyRegistries {
.map(|hpke| KeyRegistry::from_keys([PublicKeyOnly(hpke.public_key.clone())]))
.collect::<Vec<KeyRegistry<PublicKeyOnly>>>();

Some(self.0.iter().collect::<Vec<_>>().try_into().ok().unwrap())
Some(
self.0
.into_iter()
.collect::<Vec<_>>()
.try_into()
.ok()
.unwrap(),
)
}
}

Expand Down
Loading
Loading