From 4bff1d6c7b9e8708756d02ecd8b35904f2b980bb Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Mon, 2 Dec 2024 16:45:17 -0800 Subject: [PATCH] add plumbing for starting a hybrid query --- ipa-core/src/query/executor.rs | 19 +++++++++++-- ipa-core/src/query/runner/hybrid.rs | 41 +++++++++++++++++++++++++---- ipa-core/src/query/runner/mod.rs | 2 +- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index 47a0d3bf5..d626e12bf 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -39,7 +39,7 @@ use crate::{ Gate, }, query::{ - runner::{OprfIpaQuery, QueryResult}, + runner::{execute_hybrid_protocol, OprfIpaQuery, QueryResult}, state::RunningQuery, }, sync::Arc, @@ -165,7 +165,22 @@ pub fn execute( ) }, ), - (QueryType::MaliciousHybrid(_), _) => todo!(), + (QueryType::MaliciousHybrid(ipa_config), _) => do_query( + runtime, + config, + gateway, + input, + move |prss, gateway, config, input| { + Box::pin(execute_hybrid_protocol( + prss, + gateway, + input, + ipa_config, + config, + key_registry, + )) + }, + ), } } diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 199ee385f..9cc5294d8 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -8,31 +8,35 @@ use std::{ use futures::{stream::iter, StreamExt, TryStreamExt}; use generic_array::ArrayLength; +use super::QueryResult; use crate::{ error::{Error, LengthError}, ff::{ boolean::Boolean, - boolean_array::{BooleanArray, BA3, BA8}, + boolean_array::{BooleanArray, BA16, BA3, BA8}, curve_points::RP25519, ec_prime_field::Fp25519, Serializable, U128Conversions, }, helpers::{ - query::{DpMechanism, HybridQueryParams, QuerySize}, - BodyStream, LengthDelimitedStream, + query::{DpMechanism, HybridQueryParams, QueryConfig, QuerySize}, + setup_cross_shard_prss, BodyStream, Gateway, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, protocol::{ basics::{shard_fin::FinalizerContext, BooleanArrayMul, BooleanProtocols, Reveal}, - context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, + context::{ + DZKPUpgraded, MacUpgraded, ShardedContext, ShardedMaliciousContext, UpgradableContext, + }, hybrid::{ hybrid_protocol, oprf::{CONV_CHUNK, PRF_CHUNK}, step::HybridStep, }, ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle}, - prss::FromPrss, + prss::{Endpoint, FromPrss}, step::ProtocolStep::Hybrid, + Gate, }, query::runner::reshard_tag::reshard_aad, report::hybrid::{ @@ -42,6 +46,7 @@ use crate::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, Vectorizable, }, + sharding::{ShardConfiguration, Sharded}, }; #[allow(dead_code)] @@ -165,6 +170,32 @@ where } } +pub async fn execute_hybrid_protocol<'a, R: PrivateKeyRegistry>( + prss: &'a Endpoint, + gateway: &'a Gateway, + input: BodyStream, + ipa_config: HybridQueryParams, + config: &QueryConfig, + key_registry: Arc, +) -> QueryResult { + let gate = Gate::default(); + let cross_shard_prss = + setup_cross_shard_prss(gateway, &gate, prss.indexed(&gate), gateway).await?; + let sharded = Sharded { + shard_id: gateway.shard_id(), + shard_count: gateway.shard_count(), + prss: Arc::new(cross_shard_prss), + }; + + let ctx = ShardedMaliciousContext::new_with_gate(prss, gateway, gate, sharded); + + Ok(Box::new( + Query::<_, BA16, R>::new(ipa_config, key_registry) + .execute(ctx, config.size, input) + .await?, + )) +} + #[cfg(all(test, unit_test, feature = "in-memory-infra"))] mod tests { use std::{ diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 83f033fe4..642e2a846 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -15,7 +15,7 @@ pub(super) use sharded_shuffle::execute_sharded_shuffle; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use test_multiply::execute_test_multiply; -pub use self::oprf_ipa::OprfIpaQuery; +pub use self::{hybrid::execute_hybrid_protocol, oprf_ipa::OprfIpaQuery}; use crate::{error::Error, query::ProtocolResult}; pub(super) type QueryResult = Result, Error>;