Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A draft API for validation of replicated shares #936

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ipa-core/src/protocol/basics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod reshare;
mod reveal;
mod share_known_value;
pub mod sum_of_product;
pub mod validate;

pub use check_zero::check_zero;
pub use if_else::if_else;
Expand Down
284 changes: 284 additions & 0 deletions ipa-core/src/protocol/basics/validate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
use std::{
convert::Infallible,
marker::PhantomData,
pin::Pin,
task::{Context as TaskContext, Poll},
};

use futures::{
future::try_join,
stream::{Fuse, Stream, StreamExt},
Future, FutureExt,
};
use generic_array::GenericArray;
use pin_project::pin_project;
use sha2::{
digest::{typenum::Unsigned, FixedOutput, OutputSizeUser},
Digest, Sha256,
};

use crate::{
error::Error,
ff::Serializable,
helpers::{Direction, Message},
protocol::{context::Context, RecordId},
secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue},
seq_join::assert_send,
};

type HashFunction = Sha256;
type HashSize = <HashFunction as OutputSizeUser>::OutputSize;
type HashOutputArray = [u8; <HashSize as Unsigned>::USIZE];

#[derive(Debug, Clone, PartialEq, Eq)]
struct HashValue(GenericArray<u8, HashSize>);

impl Serializable for HashValue {
type Size = HashSize;
type DeserializationError = Infallible;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf.copy_from_slice(self.0.as_slice());
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
Ok(Self(buf.to_owned()))
}
}

impl Message for HashValue {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have also defined a hash struct that implements message in my draft. Should we move it to a more general crate since we seem to need it in several places?


impl From<HashFunction> for HashValue {
fn from(value: HashFunction) -> Self {
// Ugh: The version of sha2 we currently use doesn't use the same GenericArray version as we do.
HashValue(GenericArray::from(<HashOutputArray>::from(
value.finalize_fixed(),
)))
}
}

struct ReplicatedValidatorFinalization {
f: Pin<Box<(dyn Future<Output = Result<(), Error>> + Send)>>,
}

impl ReplicatedValidatorFinalization {
fn new<C: Context + 'static>(active: ReplicatedValidatorActive<C>) -> Self {
let ReplicatedValidatorActive {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be nicer from API's perspective to let ReplicatedValidatorActive to turn itself into a pair of hashes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is just internal, I didn't do that. Though I did implement From for HashValue, which made this function a little less ugly.

ctx,
left_hash,
right_hash,
} = active;
let left_hash = HashValue::from(left_hash);
let right_hash = HashValue::from(right_hash);
let left_peer = ctx.role().peer(Direction::Left);
let right_peer = ctx.role().peer(Direction::Right);

let f = Box::pin(assert_send(async move {
try_join(
ctx.send_channel(left_peer)
.send(RecordId::FIRST, left_hash.clone()),
ctx.send_channel(right_peer)
.send(RecordId::FIRST, right_hash.clone()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to send it to both, left and right. It seems sufficient to me if each party verifies one hash. It doesn't add significant costs if a party checks both hashes, however it also doesn't seem to add anything either.

)
.await?;
let (left_recvd, right_recvd) = try_join(
ctx.recv_channel(left_peer).receive(RecordId::FIRST),
ctx.recv_channel(right_peer).receive(RecordId::FIRST),
)
.await?;
if left_hash == left_recvd && right_hash == right_recvd {
Ok(())
} else {
Err(Error::Internal) // TODO add a code
}
}));
Self { f }
}

fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Error>> {
self.f.poll_unpin(cx)
}
}

struct ReplicatedValidatorActive<C> {
ctx: C,
left_hash: Sha256,
right_hash: Sha256,
}

impl<C: Context + 'static> ReplicatedValidatorActive<C> {
fn new(ctx: C) -> Self {
Self {
ctx,
left_hash: HashFunction::new(),
right_hash: HashFunction::new(),
}
}

fn update<S, V>(&mut self, s: &S)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For debugging, it might be useful if the update API takes both a Step and data, and creates a trace of the validator inputs. Then we can diagnose where things went wrong if there is a mismatch. (Obviously, we would want a flag so we only pay the cost of the detailed tracing when needed.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context carries a step, so we could smuggle that in somewhere.

I'm more concerned that I've been unable to instantiate this code :)

where
S: ReplicatedSecretSharing<V>,
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
V: SharedValue,
{
let mut buf = GenericArray::default(); // ::<u8, <V as Serializable>::Size>
s.left().serialize(&mut buf);
self.left_hash.update(buf.as_slice());
s.right().serialize(&mut buf);
self.right_hash.update(buf.as_slice());
}

fn finalize(self) -> ReplicatedValidatorFinalization {
ReplicatedValidatorFinalization::new(self)
}
}

enum ReplicatedValidatorState<C> {
/// While the validator is waiting, it holds a context reference.
Pending(Option<Box<ReplicatedValidatorActive<C>>>),
/// After the validator has taken all of its inputs, it holds a future.
Finalizing(ReplicatedValidatorFinalization),
}

impl<C: Context + 'static> ReplicatedValidatorState<C> {
/// # Panics
/// This panics if it is called after `finalize()`.
fn update<S, V>(&mut self, s: &S)
where
S: ReplicatedSecretSharing<V>,
V: SharedValue,
{
if let Self::Pending(Some(a)) = self {
a.update(s);
} else {
panic!();
}
}

fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Error>> {
match self {
Self::Pending(ref mut active) => {
let mut f = active.take().unwrap().finalize();
let res = f.poll(cx);
*self = ReplicatedValidatorState::Finalizing(f);
res
}
Self::Finalizing(f) => f.poll(cx),
}
}
}

#[pin_project]
struct ReplicatedValidator<C, T: Stream, S, V> {
#[pin]
input: Fuse<T>,
state: ReplicatedValidatorState<C>,
_marker: PhantomData<(S, V)>,
}

impl<C: Context + 'static, T: Stream, S, V> ReplicatedValidator<C, T, S, V> {
pub fn new(ctx: C, s: T) -> Self {
Self {
input: s.fuse(),
state: ReplicatedValidatorState::Pending(Some(Box::new(
ReplicatedValidatorActive::new(ctx),
))),
_marker: PhantomData,
}
}
}

impl<C, T, S, V> Stream for ReplicatedValidator<C, T, S, V>
where
C: Context + 'static,
T: Stream<Item = Result<S, Error>>,
S: ReplicatedSecretSharing<V>,
V: SharedValue,
{
type Item = Result<S, Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
match this.input.poll_next(cx) {
Poll::Ready(Some(v)) => match v {
Ok(v) => {
this.state.update(&v);
Poll::Ready(Some(Ok(v)))
}
Err(e) => Poll::Ready(Some(Err(e))),
},
Poll::Ready(None) => match this.state.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
},
Poll::Pending => Poll::Pending,
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.input.size_hint()
}
}

#[cfg(test)]
mod test {
use std::iter::repeat_with;

use futures::stream::{iter as stream_iter, Stream, StreamExt, TryStreamExt};

use crate::{
error::Error,
ff::Fp31,
helpers::Direction,
protocol::{basics::validate::ReplicatedValidator, context::Context, RecordId},
rand::{thread_rng, Rng},
secret_sharing::{
replicated::{
semi_honest::AdditiveShare as SemiHonestReplicated, ReplicatedSecretSharing,
},
SharedValue,
},
test_fixture::{Reconstruct, Runner, TestWorld},
};

fn assert_stream<S: Stream<Item = Result<T, Error>>, T>(s: S) -> S {
s
}

/// Successfully validate some shares.
#[tokio::test]
pub async fn simple() {
let mut rng = thread_rng();
let world = TestWorld::default();

let input = repeat_with(|| rng.gen::<Fp31>())
.take(10)
.collect::<Vec<_>>();
let result = world
.semi_honest(input.into_iter(), |ctx, shares| async move {
let ctx = ctx.set_total_records(shares.len());
let s = stream_iter(shares).map(|x| Ok(x));
let vs = ReplicatedValidator::new(ctx.narrow("validate"), s);
let sum = assert_stream(vs)
.try_fold(Fp31::ZERO, |sum, value| async move {
Ok(sum + value.left() - value.right())
})
.await?;
// This value should sum to zero now, so replicate the value.
// (We don't care here that this reveals our share to other helpers, it's just a test.)
ctx.send_channel(ctx.role().peer(Direction::Right))
.send(RecordId::FIRST, sum)
.await?;
let left = ctx
.recv_channel(ctx.role().peer(Direction::Left))
.receive(RecordId::FIRST)
.await?;
Ok(SemiHonestReplicated::new(left, sum))
})
.await
.map(Result::<_, Error>::unwrap)
.reconstruct();

assert_eq!(Fp31::ZERO, result);
}
}
14 changes: 6 additions & 8 deletions ipa-macros/src/derive_step/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use syn::{parse_macro_input, DeriveInput};

use crate::{
parser::{group_by_modules, ipa_state_transition_map, StepMetaData},
tree::Node,
tree::{self, Node},
};

const MAX_DYNAMIC_STEPS: usize = 1024;
Expand Down Expand Up @@ -115,7 +115,7 @@ fn impl_as_ref(ident: &syn::Ident, data: &syn::DataEnum) -> Result<TokenStream2,
let mut const_arrays = Vec::new();
let mut arms = Vec::new();

for v in data.variants.iter() {
for v in &data.variants {
let ident = &v.ident;
let ident_snake_case = ident.to_string().to_snake_case();
let ident_upper_case = ident_snake_case.to_uppercase();
Expand All @@ -128,7 +128,7 @@ fn impl_as_ref(ident: &syn::Ident, data: &syn::DataEnum) -> Result<TokenStream2,

// create an array of `num_steps` strings and use the variant index as array index
let steps = (0..num_steps)
.map(|i| format!("{}{}", ident_snake_case, i))
.map(|i| format!("{ident_snake_case}{i}"))
.collect::<Vec<_>>();
let steps_array_ident = format_ident!("{}_DYNAMIC_STEP", ident_upper_case);
const_arrays.extend(quote!(
Expand Down Expand Up @@ -272,9 +272,8 @@ fn get_meta_data_for(
1 => {
Ok(target_steps[0]
.iter()
.map(|s|
// we want to retain the references to the parents, so we use `upgrade()`
s.upgrade())
// we want to retain the references to the parents, so we use `upgrade()`
.map(tree::Node::upgrade)
.collect::<Vec<_>>())
}
_ => Err(syn::Error::new_spanned(
Expand Down Expand Up @@ -314,8 +313,7 @@ fn get_dynamic_step_count(variant: &syn::Variant) -> Result<usize, syn::Error> {
dynamic_attr,
format!(
"ipa_macros::step \"dynamic\" attribute expects a number of steps \
(<= {}) in parentheses: #[dynamic(...)].",
MAX_DYNAMIC_STEPS,
(<= {MAX_DYNAMIC_STEPS}) in parentheses: #[dynamic(...)].",
),
)),
}
Expand Down
16 changes: 12 additions & 4 deletions ipa-macros/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ pub(crate) fn read_steps_file(file_path: &str) -> Vec<String> {
let mut file = std::fs::File::open(path).expect("Could not open the steps file");
let mut contents = String::new();
file.read_to_string(&mut contents).unwrap();
contents.lines().map(|s| s.to_owned()).collect::<Vec<_>>()
contents
.lines()
.map(std::borrow::ToOwned::to_owned)
.collect::<Vec<_>>()
}

/// Constructs a tree structure with nodes that contain the `Step` instances.
Expand Down Expand Up @@ -109,10 +112,15 @@ pub(crate) fn construct_tree(steps: Vec<StepMetaData>) -> Node<StepMetaData> {
/// Split a single substep full path into the module path and the step's name.
///
/// # Example
/// ```ignore
/// input = "ipa::protocol::modulus_conversion::convert_shares::Step::xor1"
/// output = ("ipa::protocol::modulus_conversion::convert_shares::Step", "xor1")
/// ```
pub(crate) fn split_step_module_and_name(input: &str) -> (String, String) {
let mod_parts = input.split("::").map(|s| s.to_owned()).collect::<Vec<_>>();
let mod_parts = input
.split("::")
.map(std::borrow::ToOwned::to_owned)
.collect::<Vec<_>>();
let (substep_name, path) = mod_parts.split_last().unwrap();
(path.join("::"), substep_name.to_owned())
}
Expand All @@ -123,8 +131,8 @@ pub(crate) fn split_step_module_and_name(input: &str) -> (String, String) {
/// # Example
/// Let say we have the following steps:
///
/// - StepA::A1
/// - StepC::C1/StepD::D1/StepA::A2
/// - `StepA::A1`
/// - `StepC::C1/StepD::D1/StepA::A2`
///
/// If we generate code for each node while traversing, we will end up with the following:
///
Expand Down
Loading