From 6a6517107bccb6bc0521d32ca1cc88d43d0f5abf Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Thu, 19 Dec 2024 11:54:11 -0800 Subject: [PATCH] Parallelize decryption of reports --- ipa-core/src/query/runner/hybrid.rs | 33 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index d9d2099fe..e41fbc152 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -5,7 +5,7 @@ use std::{ sync::Arc, }; -use futures::{stream::iter, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use generic_array::ArrayLength; use super::QueryResult; @@ -20,7 +20,9 @@ use crate::{ }, helpers::{ query::{DpMechanism, HybridQueryParams, QueryConfig, QuerySize}, - setup_cross_shard_prss, BodyStream, Gateway, LengthDelimitedStream, + setup_cross_shard_prss, + stream::TryFlattenItersExt, + BodyStream, Gateway, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, protocol::{ @@ -105,7 +107,7 @@ where config, key_registry, phantom_data: _, - } = self; + } = &self; tracing::info!("New hybrid query: {config:?}"); let ctx = ctx.narrow(&Hybrid); @@ -118,21 +120,18 @@ where } let stream = LengthDelimitedStream::, _>::new(input_stream) - .map_err(Into::::into) - .map_ok(|enc_reports| { - iter(enc_reports.into_iter().map({ - |enc_report| { - let dec_report = enc_report - .decrypt(key_registry.as_ref()) - .map_err(Into::::into); - let unique_tag = UniqueTag::from_unique_bytes(&enc_report); - dec_report.map(|dec_report1| (dec_report1, unique_tag)) - } - })) + .map_err(Into::into) + .try_flatten_iters() + .map(|enc_report_res| async move { + enc_report_res.and_then(|enc_report| { + let dec_report = enc_report + .decrypt(key_registry.as_ref()) + .map_err(Into::::into); + let unique_tag = UniqueTag::from_unique_bytes(&enc_report); + dec_report.map(|dec_report1| (dec_report1, unique_tag)) + }) }) - .try_flatten() - .take(sz) - .map(|v| async move { v }); + .take(sz); let (decrypted_reports, resharded_tags) = reshard_aad( ctx.narrow(&HybridStep::ReshardByTag),