diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 49348843d..0081a0a50 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -82,6 +82,7 @@ ipa-step = { version = "*", path = "../ipa-step" } ipa-step-derive = { version = "*", path = "../ipa-step-derive" } aes = "0.8.3" +assertions = "0.1.0" async-trait = "0.1.79" async-scoped = { version = "0.9.0", features = ["use-tokio"], optional = true } axum = { version = "0.7.5", optional = true, features = ["http2", "macros"] } diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index 6fd7406c6..5de0051be 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -1,4 +1,6 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] -pub(crate) enum HybridStep {} +pub(crate) enum HybridStep { + ReshardByTag, +} diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index a6b1dae74..edd1662e4 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -39,7 +39,7 @@ use crate::{ Gate, }, query::{ - runner::{HybridQuery, OprfIpaQuery, QueryResult}, + runner::{OprfIpaQuery, QueryResult}, state::RunningQuery, }, sync::Arc, @@ -164,20 +164,7 @@ pub fn execute( ) }, ), - (QueryType::SemiHonestHybrid(query_params), _) => do_query( - runtime, - config, - gateway, - input, - move |prss, gateway, config, input| { - let ctx = SemiHonestContext::new(prss, gateway); - Box::pin( - HybridQuery::<_, BA32, R>::new(query_params, key_registry) - .execute(ctx, config.size, input) - .then(|res| ready(res.map(|out| Box::new(out) as Box))), - ) - }, - ), + (QueryType::SemiHonestHybrid(_), _) => todo!(), } } diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 6ff7ed7f2..cc3b861c7 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -1,6 +1,6 @@ use std::{marker::PhantomData, sync::Arc}; -use futures::{future, stream::iter, StreamExt, TryStreamExt}; +use futures::{stream::iter, StreamExt, TryStreamExt}; use crate::{ error::Error, @@ -10,16 +10,16 @@ use crate::{ BodyStream, LengthDelimitedStream, }, hpke::PrivateKeyRegistry, - protocol::{context::UpgradableContext, ipa_prf::shuffle::Shuffle, step::ProtocolStep::Hybrid}, - report::hybrid::{EncryptedHybridReport, UniqueBytesValidator}, + protocol::{ + context::{reshard_iter, ShardedContext}, + hybrid::step::HybridStep, + step::ProtocolStep::Hybrid, + }, + report::hybrid::{EncryptedHybridReport, HybridReport, UniqueTag, UniqueTagValidator}, secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue}, }; -pub type BreakdownKey = BA8; -pub type Value = BA3; -// TODO: remove this when encryption/decryption works for HybridReports -pub type Timestamp = BA20; - +#[allow(dead_code)] pub struct Query { config: HybridQueryParams, key_registry: Arc, @@ -28,7 +28,7 @@ pub struct Query { impl Query where - C: UpgradableContext + Shuffle, + C: ShardedContext, { pub fn new(query_params: HybridQueryParams, key_registry: Arc) -> Self { Self { @@ -50,41 +50,56 @@ where key_registry, phantom_data: _, } = self; + tracing::info!("New hybrid query: {config:?}"); - let _ctx = ctx.narrow(&Hybrid); + let ctx = ctx.narrow(&Hybrid); let sz = usize::from(query_size); - let mut unique_encrypted_hybrid_reports = UniqueBytesValidator::new(sz); - if config.plaintext_match_keys { return Err(Error::Unsupported( "Hybrid queries do not currently support plaintext match keys".to_string(), )); } - let _input = LengthDelimitedStream::::new(input_stream) - .map_err(Into::::into) - .and_then(|enc_reports| { - future::ready( - unique_encrypted_hybrid_reports - .check_duplicates(&enc_reports) - .map(|()| enc_reports) - .map_err(Into::::into), + let (_decrypted_reports, tags): (Vec>, Vec) = + LengthDelimitedStream::::new(input_stream) + .map_err(Into::::into) + .map_ok(|enc_reports| { + iter(enc_reports.into_iter().map({ + |enc_report| { + let dec_report = enc_report + .decrypt::(key_registry.as_ref()) + .map_err(Into::::into); + let unique_tag = UniqueTag::from_unique_bytes(&enc_report); + dec_report.map(|dec_report1| (dec_report1, unique_tag)) + } + })) + }) + .try_flatten() + .take(sz) + .try_fold( + (Vec::with_capacity(sz), Vec::with_capacity(sz)), + |mut acc, result| async move { + acc.0.push(result.0); + acc.1.push(result.1); + Ok(acc) + }, ) - }) - .map_ok(|enc_reports| { - iter(enc_reports.into_iter().map({ - |enc_report| { - enc_report - .decrypt::(key_registry.as_ref()) - .map_err(Into::::into) - } - })) - }) - .try_flatten() - .take(sz) - .try_collect::>() - .await?; + .await?; + + let resharded_tags = reshard_iter( + ctx.narrow(&HybridStep::ReshardByTag), + tags, + |ctx, _, tag| tag.shard_picker(ctx.shard_count()), + ) + .await?; + + // this should use ? but until this returns a result, + //we want to capture the panic for the test + let mut unique_encrypted_hybrid_reports = UniqueTagValidator::new(resharded_tags.len()); + unique_encrypted_hybrid_reports + .check_duplicates(&resharded_tags) + .unwrap(); unimplemented!("query::runnner::HybridQuery.execute is not fully implemented") } @@ -107,10 +122,13 @@ mod tests { BodyStream, }, hpke::{KeyPair, KeyRegistry}, - query::runner::HybridQuery, + query::runner::hybrid::Query as HybridQuery, report::{OprfReport, DEFAULT_KEY_ID}, - secret_sharing::IntoShares, - test_fixture::{ipa::TestRawDataRecord, join3v, Reconstruct, TestWorld}, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + test_fixture::{ + flatten3v, ipa::TestRawDataRecord, Reconstruct, RoundRobinInputDistribution, TestWorld, + TestWorldConfig, WithShards, + }, }; const EXPECTED: &[u128] = &[0, 8, 5]; @@ -165,28 +183,44 @@ mod tests { } struct BufferAndKeyRegistry { - buffers: [Vec; 3], + buffers: [Vec>; 3], key_registry: Arc>, + query_sizes: Vec, } - fn build_buffers_from_records(records: &[TestRawDataRecord]) -> BufferAndKeyRegistry { + fn build_buffers_from_records(records: &[TestRawDataRecord], s: usize) -> BufferAndKeyRegistry { let mut rng = StdRng::seed_from_u64(42); let key_id = DEFAULT_KEY_ID; let key_registry = Arc::new(KeyRegistry::::random(1, &mut rng)); - let mut buffers: [_; 3] = std::array::from_fn(|_| Vec::new()); - + let mut buffers: [_; 3] = std::array::from_fn(|_| vec![Vec::new(); s]); let shares: [Vec>; 3] = records.iter().cloned().share(); for (buf, shares) in zip(&mut buffers, shares) { - for share in shares { + for (i, share) in shares.into_iter().enumerate() { share - .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, buf) + .delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, &mut buf[i % s]) .unwrap(); } } + + let total_query_size = records.len(); + let base_size = total_query_size / s; + let remainder = total_query_size % s; + let query_sizes: Vec<_> = (0..s) + .map(|i| { + if i < remainder { + base_size + 1 + } else { + base_size + } + }) + .map(|size| QuerySize::try_from(size).unwrap()) + .collect(); + BufferAndKeyRegistry { buffers, key_registry, + query_sizes, } } @@ -200,37 +234,59 @@ mod tests { // While this test currently checks for an unimplemented panic it is // designed to test for a correct result for a complete implementation. + const SHARDS: usize = 2; let records = build_records(); - let query_size = QuerySize::try_from(records.len()).unwrap(); let BufferAndKeyRegistry { buffers, key_registry, - } = build_buffers_from_records(&records); + query_sizes, + } = build_buffers_from_records(&records, SHARDS); - let world = TestWorld::default(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); let contexts = world.contexts(); + #[allow(clippy::large_futures)] - let results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { - let query_params = HybridQueryParams { - per_user_credit_cap: 8, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 5.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - HybridQuery::<_, BA16, KeyRegistry>::new( - query_params, - Arc::clone(&key_registry), - ) - .execute(ctx, query_size, input) - })) + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) .await; + let results: Vec<[Vec>; 3]> = results + .chunks(3) + .map(|chunk| { + [ + chunk[0].as_ref().unwrap().clone(), + chunk[1].as_ref().unwrap().clone(), + chunk[2].as_ref().unwrap().clone(), + ] + }) + .collect(); + assert_eq!( - results.reconstruct()[0..3] + results.into_iter().next().unwrap().reconstruct()[0..3] .iter() .map(U128Conversions::as_u128) .collect::>(), @@ -240,45 +296,71 @@ mod tests { // cannot test for Err directly because join3v calls unwrap. This should be sufficient. #[tokio::test] - #[should_panic(expected = "DuplicateBytes(3)")] + #[should_panic(expected = "DuplicateBytes")] async fn duplicate_encrypted_hybrid_reports() { - let all_records = build_records(); - let records = &all_records[..2].to_vec(); + const SHARDS: usize = 2; + let records = build_records(); let BufferAndKeyRegistry { mut buffers, key_registry, - } = build_buffers_from_records(records); + query_sizes, + } = build_buffers_from_records(&records, SHARDS); // this is double, since we duplicate the data below - let query_size = QuerySize::try_from(records.len() * 2).unwrap(); - - // duplicate all the data - for buffer in &mut buffers { - let original = buffer.clone(); - buffer.extend(original); + let query_sizes = query_sizes + .into_iter() + .map(|query_size| QuerySize::try_from(usize::from(query_size) * 2).unwrap()) + .collect::>(); + + // duplicate all the data across shards + + for helper_buffers in &mut buffers { + // Get the last shard buffer to use for the first shard buffer extension + let last_shard_buffer = helper_buffers.last().unwrap().clone(); + let len = helper_buffers.len(); + for i in 0..len { + if i > 0 { + let previous = &helper_buffers[i - 1].clone(); + helper_buffers[i].extend_from_slice(previous); + } else { + helper_buffers[i].extend_from_slice(&last_shard_buffer); + } + } } - let world = TestWorld::default(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); let contexts = world.contexts(); + #[allow(clippy::large_futures)] - let _results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { - let query_params = HybridQueryParams { - per_user_credit_cap: 8, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 5.0, - plaintext_match_keys: false, - }; - let input = BodyStream::from(buffer); - - HybridQuery::<_, BA16, KeyRegistry>::new( - query_params, - Arc::clone(&key_registry), - ) - .execute(ctx, query_size, input) - })) + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: false, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) .await; + + results.into_iter().map(|r| r.unwrap()).for_each(drop); } // cannot test for Err directly because join3v calls unwrap. This should be sufficient. @@ -287,34 +369,46 @@ mod tests { expected = "Unsupported(\"Hybrid queries do not currently support plaintext match keys\")" )] async fn unsupported_plaintext_match_keys_hybrid_query() { - let all_records = build_records(); - let records = &all_records[..2].to_vec(); - let query_size = QuerySize::try_from(records.len()).unwrap(); + const SHARDS: usize = 2; + let records = build_records(); let BufferAndKeyRegistry { buffers, key_registry, - } = build_buffers_from_records(records); + query_sizes, + } = build_buffers_from_records(&records, SHARDS); - let world = TestWorld::default(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); let contexts = world.contexts(); + #[allow(clippy::large_futures)] - let _results = join3v(buffers.into_iter().zip(contexts).map(|(buffer, ctx)| { - let query_params = HybridQueryParams { - per_user_credit_cap: 8, - max_breakdown_key: 3, - with_dp: 0, - epsilon: 5.0, - plaintext_match_keys: true, - }; - let input = BodyStream::from(buffer); - - HybridQuery::<_, BA16, KeyRegistry>::new( - query_params, - Arc::clone(&key_registry), - ) - .execute(ctx, query_size, input) - })) + let results = flatten3v(buffers.into_iter().zip(contexts).map( + |(helper_buffers, helper_ctxs)| { + helper_buffers + .into_iter() + .zip(helper_ctxs) + .zip(query_sizes.clone()) + .map(|((buffer, ctx), query_size)| { + let query_params = HybridQueryParams { + per_user_credit_cap: 8, + max_breakdown_key: 3, + with_dp: 0, + epsilon: 5.0, + plaintext_match_keys: true, + }; + let input = BodyStream::from(buffer); + + HybridQuery::<_, BA16, KeyRegistry>::new( + query_params, + Arc::clone(&key_registry), + ) + .execute(ctx, query_size, input) + }) + }, + )) .await; + + results.into_iter().map(|r| r.unwrap()).for_each(drop); } } diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 9e5935c20..9bd739db9 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -7,7 +7,6 @@ mod test_multiply; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use add_in_prime_field::execute as test_add_in_prime_field; -pub use hybrid::Query as HybridQuery; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use test_multiply::execute_test_multiply; diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index d322f92ae..593c6d0e8 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -1,7 +1,8 @@ -use std::{collections::HashSet, ops::Add}; +use std::{collections::HashSet, convert::Infallible, ops::Add}; +use assertions::const_assert; use bytes::Bytes; -use generic_array::ArrayLength; +use generic_array::{ArrayLength, GenericArray}; use rand_core::{CryptoRng, RngCore}; use typenum::{Sum, Unsigned, U16}; @@ -11,6 +12,7 @@ use crate::{ hpke::{EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize}, report::{EncryptedOprfReport, EventType, InvalidReportError, KeyIdentifier}, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, SharedValue}, + sharding::ShardIndex, }; #[derive(Clone, Debug, Eq, PartialEq)] @@ -150,37 +152,92 @@ impl TryFrom for EncryptedHybridReport { } } +const TAG_SIZE: usize = TagSize::USIZE; + +#[derive(Clone, Debug)] +pub struct UniqueTag { + bytes: [u8; TAG_SIZE], +} + pub trait UniqueBytes { - fn unique_bytes(&self) -> Vec; + fn unique_bytes(&self) -> [u8; TAG_SIZE]; +} + +impl UniqueBytes for UniqueTag { + fn unique_bytes(&self) -> [u8; TAG_SIZE] { + self.bytes + } } impl UniqueBytes for EncryptedHybridReport { /// We use the `TagSize` (the first 16 bytes of the ciphertext) for collision-detection /// See [analysis here for uniqueness](https://eprint.iacr.org/2019/624) - fn unique_bytes(&self) -> Vec { - self.mk_ciphertext()[0..TagSize::USIZE].to_vec() + fn unique_bytes(&self) -> [u8; TAG_SIZE] { + let slice = &self.mk_ciphertext()[0..TAG_SIZE]; + let mut array = [0u8; TAG_SIZE]; + array.copy_from_slice(slice); + array + } +} + +impl UniqueTag { + fn _compile_check() { + // This will vaild at compile time if TAG_SIZE doesn't match U16 + // the macro expansion needs to be wrapped in a function + const_assert!(TAG_SIZE == 16); + } + + // Function to attempt to create a UniqueTag from a UniqueBytes implementor + pub fn from_unique_bytes(item: &T) -> Self { + UniqueTag { + bytes: item.unique_bytes(), + } + } + + /// Maps the tag into a consistent shard. + /// + /// ## Panics + /// if the `TAG_SIZE != 16` + /// note: ~10 below this, we have a compile time check that `TAG_SIZE = 16` + #[must_use] + pub fn shard_picker(&self, shard_count: ShardIndex) -> ShardIndex { + let num = u128::from_le_bytes(self.bytes); + let shard_count = u128::from(shard_count); + ShardIndex::try_from(num % shard_count).expect("Modulo a u32 will fit in u32") + } +} + +impl Serializable for UniqueTag { + type Size = U16; // This must match TAG_SIZE + type DeserializationError = Infallible; + + fn serialize(&self, buf: &mut GenericArray) { + buf.copy_from_slice(&self.bytes); + } + fn deserialize(buf: &GenericArray) -> Result { + let mut bytes = [0u8; TAG_SIZE]; + bytes.copy_from_slice(buf.as_slice()); + Ok(UniqueTag { bytes }) } } #[derive(Debug)] -pub struct UniqueBytesValidator { - hash_set: HashSet>, +pub struct UniqueTagValidator { + hash_set: HashSet<[u8; TAG_SIZE]>, check_counter: usize, } -impl UniqueBytesValidator { +impl UniqueTagValidator { #[must_use] pub fn new(size: usize) -> Self { - UniqueBytesValidator { + UniqueTagValidator { hash_set: HashSet::with_capacity(size), check_counter: 0, } } - - fn insert(&mut self, value: Vec) -> bool { + fn insert(&mut self, value: [u8; TAG_SIZE]) -> bool { self.hash_set.insert(value) } - /// Checks that item is unique among all checked thus far /// /// ## Errors @@ -193,7 +250,6 @@ impl UniqueBytesValidator { Err(Error::DuplicateBytes(self.check_counter)) } } - /// Checks that an iter of items is unique among the iter and any other items checked thus far /// /// ## Errors @@ -213,7 +269,7 @@ mod test { use super::{ EncryptedHybridReport, HybridConversionReport, HybridImpressionReport, HybridReport, - UniqueBytes, UniqueBytesValidator, + UniqueTag, UniqueTagValidator, }; use crate::{ error::Error, @@ -239,11 +295,11 @@ mod test { } } - fn generate_random_bytes(size: usize) -> Vec { + fn generate_random_tag() -> UniqueTag { let mut rng = thread_rng(); - let mut bytes = vec![0u8; size]; + let mut bytes = [0u8; 16]; rng.fill(&mut bytes[..]); - bytes + UniqueTag { bytes } } #[test] @@ -305,40 +361,22 @@ mod test { #[test] fn unique_encrypted_hybrid_reports() { - #[derive(Clone)] - pub struct UniqueByteHolder { - bytes: Vec, - } - - impl UniqueByteHolder { - pub fn new(size: usize) -> Self { - let bytes = generate_random_bytes(size); - UniqueByteHolder { bytes } - } - } - - impl UniqueBytes for UniqueByteHolder { - fn unique_bytes(&self) -> Vec { - self.bytes.clone() - } - } - - let bytes1 = UniqueByteHolder::new(4); - let bytes2 = UniqueByteHolder::new(4); - let bytes3 = UniqueByteHolder::new(4); - let bytes4 = UniqueByteHolder::new(4); + let tag1 = generate_random_tag(); + let tag2 = generate_random_tag(); + let tag3 = generate_random_tag(); + let tag4 = generate_random_tag(); - let mut unique_bytes = UniqueBytesValidator::new(4); + let mut unique_bytes = UniqueTagValidator::new(4); - unique_bytes.check_duplicate(&bytes1).unwrap(); + unique_bytes.check_duplicate(&tag1).unwrap(); unique_bytes - .check_duplicates(&[bytes2.clone(), bytes3.clone()]) + .check_duplicates(&[tag2.clone(), tag3.clone()]) .unwrap(); - let expected_err = unique_bytes.check_duplicate(&bytes2); + let expected_err = unique_bytes.check_duplicate(&tag2); assert!(matches!(expected_err, Err(Error::DuplicateBytes(4)))); - let expected_err = unique_bytes.check_duplicates(&[bytes4, bytes3]); + let expected_err = unique_bytes.check_duplicates(&[tag4, tag3]); assert!(matches!(expected_err, Err(Error::DuplicateBytes(6)))); } } diff --git a/ipa-core/src/report/ipa.rs b/ipa-core/src/report/ipa.rs index a9da93454..bceba37e3 100644 --- a/ipa-core/src/report/ipa.rs +++ b/ipa-core/src/report/ipa.rs @@ -827,9 +827,9 @@ mod test { fn check_compatibility_impressionmk_with_ios_encryption() { let enc_report_bytes1 = hex::decode( "12854879d86ef277cd70806a7f6bad269877adc95ee107380381caf15b841a7e995e41\ - 4c63a9d82f834796cdd6c40529189fca82720714d24200d8a916a1e090b123f27eaf24\ - f047f3930a77e5bcd33eeb823b73b0e9546c59d3d6e69383c74ae72b79645698fe1422\ - f83886bd3cbca9fbb63f7019e2139191dd000000007777772e6d6574612e636f6d", + 4c63a9d82f834796cdd6c40529189fca82720714d24200d8a916a1e090b123f27eaf24\ + f047f3930a77e5bcd33eeb823b73b0e9546c59d3d6e69383c74ae72b79645698fe1422\ + f83886bd3cbca9fbb63f7019e2139191dd000000007777772e6d6574612e636f6d", ) .unwrap(); let enc_report_bytes2 = hex::decode( diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 195573f45..64ed9b6d3 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -7,12 +7,68 @@ use std::{ #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ShardIndex(pub u32); +impl ShardIndex { + pub const FIRST: Self = Self(0); + + /// Returns an iterator over all shard indices that precede this one, excluding this one. + pub fn iter(self) -> impl Iterator { + (0..self.0).map(Self) + } +} + +impl From for ShardIndex { + fn from(value: u32) -> Self { + Self(value) + } +} + +impl From for u64 { + fn from(value: ShardIndex) -> Self { + u64::from(value.0) + } +} + +impl From for u128 { + fn from(value: ShardIndex) -> Self { + Self::from(value.0) + } +} + +#[cfg(target_pointer_width = "64")] +impl From for usize { + fn from(value: ShardIndex) -> Self { + usize::try_from(value.0).unwrap() + } +} + impl From for u32 { fn from(value: ShardIndex) -> Self { value.0 } } +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: usize) -> Result { + u32::try_from(value).map(Self) + } +} + +impl TryFrom for ShardIndex { + type Error = TryFromIntError; + + fn try_from(value: u128) -> Result { + u32::try_from(value).map(Self) + } +} + +impl Display for ShardIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + #[derive(Debug, Copy, Clone)] pub struct Sharded { pub shard_id: ShardIndex, @@ -29,12 +85,6 @@ impl ShardConfiguration for Sharded { } } -impl Display for ShardIndex { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - /// Shard-specific configuration required by sharding API. Each shard must know its own index and /// the total number of shards in the system. pub trait ShardConfiguration { @@ -70,48 +120,6 @@ pub struct NotSharded; impl ShardBinding for NotSharded {} impl ShardBinding for Sharded {} -impl ShardIndex { - pub const FIRST: Self = Self(0); - - /// Returns an iterator over all shard indices that precede this one, excluding this one. - pub fn iter(self) -> impl Iterator { - (0..self.0).map(Self) - } -} - -impl From for ShardIndex { - fn from(value: u32) -> Self { - Self(value) - } -} - -impl From for u64 { - fn from(value: ShardIndex) -> Self { - u64::from(value.0) - } -} - -impl From for u128 { - fn from(value: ShardIndex) -> Self { - Self::from(value.0) - } -} - -#[cfg(target_pointer_width = "64")] -impl From for usize { - fn from(value: ShardIndex) -> Self { - usize::try_from(value.0).unwrap() - } -} - -impl TryFrom for ShardIndex { - type Error = TryFromIntError; - - fn try_from(value: usize) -> Result { - u32::try_from(value).map(Self) - } -} - #[cfg(all(test, unit_test))] mod tests { use std::iter::empty; diff --git a/ipa-core/src/test_fixture/mod.rs b/ipa-core/src/test_fixture/mod.rs index 9c6e7995f..38d12eb21 100644 --- a/ipa-core/src/test_fixture/mod.rs +++ b/ipa-core/src/test_fixture/mod.rs @@ -147,6 +147,37 @@ where join3(fut0, fut1, fut2) } +/// Wrapper for flattening 3 vecs of vecs into a single future +/// # Panics +/// If the tasks return `Err` or if `a` is the wrong length. +pub fn flatten3v(a: V) -> impl Future::Output>> +where + V: IntoIterator, + I: IntoIterator, + T: TryFuture, + T::Output: Debug, + T::Ok: Debug, + T::Error: Debug, +{ + let mut it = a.into_iter(); + + let outer0 = it.next().unwrap().into_iter(); + let outer1 = it.next().unwrap().into_iter(); + let outer2 = it.next().unwrap().into_iter(); + + assert!(it.next().is_none()); + + // only used for tests + #[allow(clippy::disallowed_methods)] + futures::future::join_all( + outer0 + .zip(outer1) + .zip(outer2) + .flat_map(|((fut0, fut1), fut2)| vec![fut0, fut1, fut2]) + .collect::>(), + ) +} + /// Take a slice of bits in `{0,1} ⊆ F_p`, and reconstruct the integer in `Z` pub fn bits_to_value(x: &[F]) -> u128 { #[allow(clippy::cast_possible_truncation)] diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 1290e9f98..5c23cb2d8 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -207,6 +207,42 @@ impl TestWorld> { .ok() .unwrap() } + + /// Creates protocol contexts for 3 helpers across all shards + /// + /// # Panics + /// Panics if world has more or less than 3 gateways/participants + #[must_use] + pub fn contexts(&self) -> [Vec>; 3] { + let gate = &self.next_gate(); + self.shards().iter().map(|shard| shard.contexts(gate)).fold( + [Vec::new(), Vec::new(), Vec::new()], + |mut acc, contexts| { + // Distribute contexts into the respective vectors. + for (vec, context) in acc.iter_mut().zip(contexts.iter()) { + vec.push(context.clone()); + } + acc + }, + ) + } + /// Creates malicious protocol contexts for 3 helpers across all shards + /// + /// # Panics + /// Panics if world has more or less than 3 gateways/participants + #[must_use] + pub fn malicious_contexts(&self) -> [Vec>; 3] { + self.shards() + .iter() + .map(|shard| shard.malicious_contexts(&self.next_gate())) + .fold([Vec::new(), Vec::new(), Vec::new()], |mut acc, contexts| { + // Distribute contexts into the respective vectors. + for (vec, context) in acc.iter_mut().zip(contexts.iter()) { + vec.push(context.clone()); + } + acc + }) + } } /// Backward-compatible API for tests that don't use sharding.