Concurrently process messages in protocol execution loop #1136
Changes from all commits: c329ae5, 6a7d3d8, 18ae464, 43a5a27, 11b6eff, a4d864d, d0c5f67, 0221607, 08dafa8, b79f36f, c65d47e, 2a2acc5

@@ -15,10 +15,11 @@
 //! A wrapper for the threshold signing library to handle sending and receiving messages.

+use futures::future::try_join_all;
 use num::bigint::BigUint;
 use rand_core::{CryptoRngCore, OsRng};
 use sp_core::{sr25519, Pair};
-use std::collections::VecDeque;
+use std::sync::Arc;
 use subxt::utils::AccountId32;
 use synedrion::{
     ecdsa::VerifyingKey,

@@ -69,11 +70,15 @@ impl RandomizedPrehashSigner<sr25519::Signature> for PairWrapper {
     }
 }

-pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult>(
+pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult + 'static>(
     mut chans: Channels,
     session: Session<Res, sr25519::Signature, PairWrapper, PartyId>,
     session_id_hash: [u8; 32],
-) -> Result<(Res::Success, Channels), GenericProtocolError<Res>> {
+) -> Result<(Res::Success, Channels), GenericProtocolError<Res>>
+where
+    <Res as synedrion::ProtocolResult>::ProvableError: std::marker::Send,
+    <Res as synedrion::ProtocolResult>::CorrectnessProof: std::marker::Send,
+{
     let session_id = synedrion::SessionId::from_seed(&session_id_hash);
     let tx = &chans.0;
     let rx = &mut chans.1;

Review thread (on lines +79 to +80, the new `Send` bounds):
  Comment: Do you really need the full path here? Isn't `Send` enough?
  Reply: Yes, you are right, we can just do `Send`.
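
For readers wondering what forces these bounds, here is a minimal, self-contained sketch. The trait, error enum, and function names below are hypothetical stand-ins rather than the real synedrion or entropy-protocol types: `tokio::spawn` requires the spawned future and its output to be `Send + 'static`, and because the task's output mentions the generic error type, that requirement propagates to `Res` and to its associated error types.

```rust
use tokio::task::JoinHandle;

// Hypothetical stand-ins, not the real synedrion / entropy-protocol types.
trait ProtocolResult {
    type ProvableError;
    type CorrectnessProof;
}

#[allow(dead_code)]
enum GenericProtocolError<Res: ProtocolResult> {
    Provable(Res::ProvableError),
    Correctness(Res::CorrectnessProof),
}

// `tokio::spawn` needs the future and its output to be `Send + 'static`.
// Since the output type mentions `GenericProtocolError<Res>`, the bounds
// propagate to `Res` and to both associated error types.
fn spawn_protocol_step<Res>() -> JoinHandle<Result<(), GenericProtocolError<Res>>>
where
    Res: ProtocolResult + 'static,
    Res::ProvableError: Send + 'static,
    Res::CorrectnessProof: Send + 'static,
{
    tokio::spawn(async move { Ok(()) })
}

// A concrete instantiation, just to show the bounds are satisfiable.
struct Dummy;
impl ProtocolResult for Dummy {
    type ProvableError = String;
    type CorrectnessProof = Vec<u8>;
}

#[tokio::main]
async fn main() {
    assert!(spawn_protocol_step::<Dummy>().await.unwrap().is_ok());
}
```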

@@ -85,64 +90,96 @@ pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult>(
    loop {
        let mut accum = session.make_accumulator();

-        // Send out messages
-        let destinations = session.message_destinations();
-        // TODO (#641): this can happen in a spawned task
-        for destination in destinations.iter() {
-            let (message, artifact) = session.make_message(&mut OsRng, destination)?;
-            tx.send(ProtocolMessage::new(&my_id, destination, message))?;
-
-            // This will happen in a host task
-            accum.add_artifact(artifact)?;
+        let current_round = session.current_round();
+        let session_arc = Arc::new(session);
+
+        // Send outgoing messages
+        let destinations = session_arc.message_destinations();
+        let join_handles = destinations.iter().map(|destination| {
+            let session_arc = session_arc.clone();
+            let tx = tx.clone();
+            let my_id = my_id.clone();
+            let destination = destination.clone();
+            tokio::spawn(async move {
+                session_arc
+                    .make_message(&mut OsRng, &destination)
+                    .map(|(message, artifact)| {
+                        tx.send(ProtocolMessage::new(&my_id, &destination, message))
+                            .map(|_| artifact)
+                            .map_err(|err| {
+                                let err: GenericProtocolError<Res> = err.into();
+                                err
+                            })
+                    })
+                    .map_err(|err| {
+                        let err: GenericProtocolError<Res> = err.into();
+                        err
+                    })
+            })
+        });
+
+        for result in try_join_all(join_handles).await? {
+            accum.add_artifact(result??)?;
        }
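
As an aside, a compact sketch of the fan-out pattern introduced above, using hypothetical stand-ins (`Session`, `make_message`, an artifact that is just a `u64`) instead of the crate's real types: the shared state goes into an `Arc`, one task is spawned per destination, and the results are collected with `try_join_all`.

```rust
use futures::future::try_join_all;
use std::sync::Arc;

// Hypothetical stand-in for the protocol session.
struct Session;

impl Session {
    fn make_message(&self, destination: &str) -> Result<(String, u64), String> {
        Ok((format!("message for {destination}"), 42))
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let session = Arc::new(Session);
    let destinations = vec!["alice".to_string(), "bob".to_string()];

    // One spawned task per destination; each task owns a clone of the Arc,
    // so nothing borrowed from the loop escapes into the tasks.
    let join_handles = destinations.into_iter().map(|destination| {
        let session = session.clone();
        tokio::spawn(async move { session.make_message(&destination) })
    });

    // The outer layer (`expect`) covers task panics (JoinError), the inner
    // `?` covers the protocol-level error returned by `make_message`.
    let mut artifacts = Vec::new();
    for result in try_join_all(join_handles).await.expect("a task panicked") {
        let (_message, artifact) = result?;
        artifacts.push(artifact);
    }
    assert_eq!(artifacts.len(), 2);
    Ok(())
}
```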

-        for preprocessed in cached_messages {
-            // TODO (#641): this may happen in a spawned task.
-            let processed = session.process_message(&mut OsRng, preprocessed)?;
+        // Process cached messages
+        let join_handles = cached_messages.into_iter().map(|preprocessed| {
+            let session_arc = session_arc.clone();
+            tokio::spawn(async move { session_arc.process_message(&mut OsRng, preprocessed) })
+        });

-            // This will happen in a host task.
-            accum.add_processed_message(processed)??;
+        for result in try_join_all(join_handles).await? {
+            accum.add_processed_message(result?)??;
        }

Review thread (on the `try_join_all(join_handles)` call):
  Comment: This will abort all the [...]
  Reply: You mean if there is an error in one of the tasks, all will be aborted because of the [...]?
  Reply: Less because of the [...]
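
On the behaviour discussed in the thread above: `futures::future::try_join_all` resolves as soon as any of its futures yields an `Err`, dropping the futures that have not finished yet (note that dropping a `tokio::task::JoinHandle` detaches the task rather than aborting it). Below is a small, self-contained demonstration with plain futures, unrelated to the protocol types, assuming tokio with its timer enabled.

```rust
use futures::future::try_join_all;
use std::{future::Future, pin::Pin, time::Duration};

// Boxed future type so the Vec below has a single concrete item type.
type BoxedResult = Pin<Box<dyn Future<Output = Result<u32, &'static str>>>>;

fn delayed(delay_ms: u64, outcome: Result<u32, &'static str>) -> BoxedResult {
    Box::pin(async move {
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        outcome
    })
}

#[tokio::main]
async fn main() {
    let work = vec![
        delayed(10, Ok(1)),
        delayed(20, Err("boom")),
        delayed(10_000, Ok(3)), // never completes: dropped on the early error
    ];

    // Resolves after roughly 20ms with the error; the slow future is
    // cancelled by being dropped rather than awaited to completion.
    assert_eq!(try_join_all(work).await, Err("boom"));
}
```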

-        while !session.can_finalize(&accum)? {
-            let mut messages_for_later = VecDeque::new();
-            let (from, payload) = loop {
-                let message = rx.recv().await.ok_or_else(|| {
-                    GenericProtocolError::<Res>::IncomingStream(format!(
-                        "{:?}",
-                        session.current_round()
-                    ))
-                })?;
-
-                if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
-                    if payload.session_id() == &session_id {
-                        break (message.from, *payload);
-                    } else {
-                        tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
-                        messages_for_later.push_back(message);
-                    }
-                } else {
-                    tracing::warn!("Got verifying key during protocol - ignoring");
-                }
-            };
-            // Put messages which were not for this session back onto the incoming message channel
-            for message in messages_for_later.into_iter() {
-                tx.incoming_sender.send(message).await?;
-            }
-            // Perform quick checks before proceeding with the verification.
-            let preprocessed = session.preprocess_message(&mut accum, &from, payload)?;
-
-            if let Some(preprocessed) = preprocessed {
-                // TODO (#641): this may happen in a spawned task.
-                let result = session.process_message(&mut OsRng, preprocessed)?;
-
-                // This will happen in a host task.
-                accum.add_processed_message(result)??;
-            }
-        }
+        // Receive and process incoming messages
+        let (process_tx, mut process_rx) = mpsc::channel(1024);
+        while !session_arc.can_finalize(&accum)? {
+            tokio::select! {
+                // Incoming message from remote peer
+                maybe_message = rx.recv() => {
+                    let message = maybe_message.ok_or_else(|| {
+                        GenericProtocolError::IncomingStream(format!("{:?}", current_round))
+                    })?;
+
+                    if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
+                        if payload.session_id() == &session_id {
+                            // Perform quick checks before proceeding with the verification.
+                            let preprocessed =
+                                session_arc.preprocess_message(&mut accum, &message.from, *payload)?;
+
+                            if let Some(preprocessed) = preprocessed {
+                                let session_arc = session_arc.clone();
+                                let tx = process_tx.clone();
+                                tokio::spawn(async move {
+                                    let result = session_arc.process_message(&mut OsRng, preprocessed);
+                                    if tx.send(result).await.is_err() {
+                                        tracing::error!("Protocol finished before message processing result sent");
+                                    }
+                                });
+                            }
+                        } else {
+                            tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
+                            tx.incoming_sender.send(message).await?;
+                        }
+                    } else {
+                        tracing::warn!("Got verifying key during protocol - ignoring");
+                    }
+                }
+                // Result from processing a message
+                maybe_result = process_rx.recv() => {
+                    if let Some(result) = maybe_result {
+                        accum.add_processed_message(result?)??;
+                    }
+                }
+            }
+        }

Review thread (on the `tx.send(result)` error branch inside the spawned task):
  Comment: Hmm, isn't this a [...]
  Reply: I am pretty sure that it is the case. If the channel is full there will be no error, it will [...]. That only leaves that the receiver is closed due to either [...]
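
A reduced sketch of the event-loop shape used above, with all names and message types being illustrative rather than the crate's: one `tokio::select!` polls both the "network" receiver and an internal results channel, heavy work is pushed into spawned tasks, and the workers report back over a bounded `mpsc` channel. As the thread above notes, `Sender::send` on a bounded tokio channel waits while the buffer is full and only returns an error once the receiving half has been dropped.

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // `incoming` stands in for the channel of messages from remote peers,
    // `process_tx` / `process_rx` for the internal results channel.
    let (peer_tx, mut incoming) = mpsc::channel::<u32>(16);
    let (process_tx, mut process_rx) = mpsc::channel::<u32>(1024);

    // Pretend a peer sends three messages and then disconnects.
    tokio::spawn(async move {
        for msg in 0..3u32 {
            // Only fails if the receiver below has been dropped.
            let _ = peer_tx.send(msg).await;
        }
    });

    let mut processed = Vec::new();
    loop {
        tokio::select! {
            // An incoming message: hand the expensive part to a worker task.
            maybe_message = incoming.recv() => {
                let Some(message) = maybe_message else { break };
                let tx = process_tx.clone();
                tokio::spawn(async move {
                    // `send` errors only if `process_rx` was dropped, i.e.
                    // the loop below has already finished.
                    let _ = tx.send(message * 10).await;
                });
            }
            // A result reported back by one of the worker tasks.
            Some(result) = process_rx.recv() => {
                processed.push(result);
            }
        }
    }

    // Drain results from workers that were still running when the peer
    // channel closed.
    while processed.len() < 3 {
        if let Some(result) = process_rx.recv().await {
            processed.push(result);
        }
    }
    processed.sort();
    assert_eq!(processed, vec![0u32, 10, 20]);
}
```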

-        match session.finalize_round(&mut OsRng, accum)? {
+        // Get session back out of Arc
+        let session_inner =
+            Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?;
+
+        match session_inner.finalize_round(&mut OsRng, accum)? {
            FinalizeOutcome::Success(res) => break Ok((res, chans)),
            FinalizeOutcome::AnotherRound {
                session: new_session,

Review thread (on the `Arc::try_unwrap` call):
  Comment: Here we assume there are no other references to the session at this point, which is always the case in the happy path, as the protocol only finishes when all messages are processed. @fjarri I'm not sure if there are some synedrion error cases where [...]. The alternative would be to put it in a [...]
  Reply: Not in the current [...]. The way I envision this API to be used is that as soon as you get the definitive answer from [...]. Now for the waiting part, I know how I would organize it in Python, but I am not exactly sure for [...]
  Reply: What would the underlying cause for this error be though? Not actually having processed all the messages? Asking because I'm not sure how much help we're giving the caller by just saying that we failed to unwrap an [...]
  Reply: Agree we should put a better error message. But ideally I would only leave this in if I was pretty sure this was never going to happen. We only get to this point in the code when the loop has broken because [...]. Since [...]
  Reply: +1 for an error type describing the problem.
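
To make the discussion above concrete, a tiny, self-contained illustration of the `Arc::try_unwrap` behaviour (the `Session` struct here is a stand-in): it only succeeds when the caller holds the last strong reference, so a still-running task that kept a clone would make it fail, which is why a more descriptive error than a bare `ArcUnwrapError` is being suggested.

```rust
use std::sync::Arc;

// Stand-in for the protocol session type.
struct Session(u32);

fn main() {
    // Sole owner: try_unwrap succeeds and hands back the inner value.
    let sole_owner = Arc::new(Session(1));
    assert!(Arc::try_unwrap(sole_owner).is_ok());

    // A second strong reference (think: a spawned task that still holds a
    // clone) makes try_unwrap fail and return the Arc unchanged.
    let shared = Arc::new(Session(2));
    let _held_by_a_task = shared.clone();
    match Arc::try_unwrap(shared) {
        Ok(_) => unreachable!("another strong reference still exists"),
        // The caller decides how to report this; a descriptive protocol
        // error beats a generic "failed to unwrap an Arc".
        Err(still_shared) => assert_eq!(Arc::strong_count(&still_shared), 2),
    }
}
```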

Second changed file (protocol tests):

@@ -22,7 +22,7 @@ use futures::future;
 use rand_core::OsRng;
 use serial_test::serial;
 use sp_core::{sr25519, Pair};
-use std::time::Instant;
+use std::{cmp::min, time::Instant};
 use subxt::utils::AccountId32;
 use synedrion::{ecdsa::VerifyingKey, AuxInfo, KeyShare, ThresholdKeyShare};
 use tokio::{net::TcpListener, runtime::Runtime, sync::oneshot};

@@ -33,37 +33,40 @@ use helpers::{server, ProtocolOutput};

 use std::collections::BTreeSet;

+/// The maximum number of worker threads that tokio should use
+const MAX_THREADS: usize = 16;
+

Review thread (on the new MAX_THREADS constant):
  Comment: I was setting the number of tokio worker threads to [...]
  Reply: What kind of problems did you run into? And perhaps why not pick a multiple of whatever [...]?
  Reply: The problem was the test just hangs and only finishes when the test harness timeout runs out. I guess we could pick half of [...]
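
The `get_tokio_runtime` helper itself is not part of this diff; the sketch below is only a guess at its shape, shown to make the worker-thread discussion concrete. It builds a multi-threaded runtime with an explicit worker count, and the caller caps that count the same way the tests below now do.

```rust
use std::cmp::min;
use tokio::runtime::Runtime;

/// The maximum number of worker threads that tokio should use.
const MAX_THREADS: usize = 16;

// Illustrative only: one runtime with an explicit number of worker threads.
fn get_tokio_runtime(num_workers: usize) -> Runtime {
    tokio::runtime::Builder::new_multi_thread()
        .worker_threads(num_workers)
        .enable_all()
        .build()
        .expect("failed to build tokio runtime")
}

fn main() {
    // Mirror the call sites in the tests: never ask for more worker threads
    // than MAX_THREADS, however many CPUs the machine reports.
    let num_parties = min(num_cpus::get(), MAX_THREADS);
    get_tokio_runtime(num_parties).block_on(async {
        println!("running with {num_parties} worker threads");
    });
}
```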

 #[test]
 #[serial]
 fn sign_protocol_with_time_logged() {
-    let cpus = num_cpus::get();
-    get_tokio_runtime(cpus).block_on(async {
-        test_sign_with_parties(cpus).await;
+    let num_parties = min(num_cpus::get(), MAX_THREADS);
+    get_tokio_runtime(num_parties).block_on(async {
+        test_sign_with_parties(num_parties).await;
     })
 }

 #[test]
 #[serial]
 fn refresh_protocol_with_time_logged() {
-    let cpus = num_cpus::get();
-    get_tokio_runtime(cpus).block_on(async {
-        test_refresh_with_parties(cpus).await;
+    let num_parties = min(num_cpus::get(), MAX_THREADS);
+    get_tokio_runtime(num_parties).block_on(async {
+        test_refresh_with_parties(num_parties).await;
     })
 }

 #[test]
 #[serial]
 fn dkg_protocol_with_time_logged() {
-    let cpus = num_cpus::get();
-    get_tokio_runtime(cpus).block_on(async {
-        test_dkg_with_parties(cpus).await;
+    let num_parties = min(num_cpus::get(), MAX_THREADS);
+    get_tokio_runtime(num_parties).block_on(async {
+        test_dkg_with_parties(num_parties).await;
     })
 }

 #[test]
 #[serial]
 fn t_of_n_dkg_and_sign() {
-    let cpus = num_cpus::get();
+    let cpus = min(num_cpus::get(), MAX_THREADS);
     // For this test we need at least 3 parties
     let parties = 3;
     get_tokio_runtime(cpus).block_on(async {

Review thread (on the `'static` bound added to `Res`):
  Comment: Why do we need to go with a `'static` lifetime here instead of something shorter?
  Reply: This is from the error handling using the generic `GenericProtocolError<Res>` in a spawned task. It is probably possible to use some other error type in the spawned task and convert it to a generic error once it has been passed back out to the host task. Hopefully then we wouldn't need the static lifetime.
  Reply: I've tried to remove the need for this `'static` and can't come up with anything. Having read up about it, I am with @JesseAbram on this - I think we should not worry. The answer at https://stackoverflow.com/questions/66510485/when-is-a-static-lifetime-not-appropriate says that `'static` means the value can live forever, not that it will live forever. We can see from this code that the references in the spawned tasks are only going to be around as long as the messages are still being processed.
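
To make the "can live forever, not will live forever" point concrete, a minimal example, independent of the protocol code: an owned value moved into a spawned task satisfies the `'static` requirement even though it is dropped as soon as the task finishes, while borrowing a local would not compile.

```rust
#[tokio::main]
async fn main() {
    let owned = String::from("session state");

    // OK: the task owns `owned`, so the future is `'static` - it could live
    // forever, even though here it finishes almost immediately.
    let handle = tokio::spawn(async move { owned.len() });
    assert_eq!(handle.await.unwrap(), 13);

    // This would not compile: the future would borrow a local variable and
    // so would not be `'static`, because the task could outlive the stack
    // frame that owns the data.
    //
    // let borrowed = String::from("session state");
    // tokio::spawn(async { borrowed.len() });
}
```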