Concurrently process messages in protocol execution loop #1136

Merged: 12 commits, Nov 1, 2024
16 changes: 16 additions & 0 deletions crates/protocol/src/errors.rs
@@ -31,6 +31,10 @@ pub enum GenericProtocolError<Res: ProtocolResult> {
Broadcast(#[from] Box<tokio::sync::broadcast::error::SendError<ProtocolMessage>>),
#[error("Mpsc send error: {0}")]
Mpsc(#[from] tokio::sync::mpsc::error::SendError<ProtocolMessage>),
#[error("Could not get session out of Arc - session has finalized before message processing finished")]
ArcUnwrapError,
#[error("Message processing task panic or cancellation: {0}")]
JoinHandle(#[from] tokio::task::JoinError),
}

impl<Res: ProtocolResult> From<sessions::LocalError> for GenericProtocolError<Res> {
@@ -61,6 +65,8 @@ impl From<GenericProtocolError<InteractiveSigningResult<KeyParams, PartyId>>>
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
@@ -73,6 +79,8 @@ impl From<GenericProtocolError<KeyInitResult<KeyParams, PartyId>>> for ProtocolE
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
@@ -85,6 +93,8 @@ impl From<GenericProtocolError<KeyResharingResult<KeyParams, PartyId>>> for Prot
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
@@ -97,6 +107,8 @@ impl From<GenericProtocolError<AuxGenResult<KeyParams, PartyId>>> for ProtocolEx
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
@@ -136,6 +148,10 @@ pub enum ProtocolExecutionErr {
BadVerifyingKey(String),
#[error("Expected verifying key but got a protocol message")]
UnexpectedMessage,
#[error("Could not get session out of Arc")]
ArcUnwrapError,
#[error("Message processing task panic or cancellation: {0}")]
JoinHandle(#[from] tokio::task::JoinError),
}

#[derive(Debug, Error)]
135 changes: 86 additions & 49 deletions crates/protocol/src/execute_protocol.rs
@@ -15,10 +15,11 @@

//! A wrapper for the threshold signing library to handle sending and receiving messages.

use futures::future::try_join_all;
use num::bigint::BigUint;
use rand_core::{CryptoRngCore, OsRng};
use sp_core::{sr25519, Pair};
use std::collections::VecDeque;
use std::sync::Arc;
use subxt::utils::AccountId32;
use synedrion::{
ecdsa::VerifyingKey,
@@ -69,11 +70,15 @@ impl RandomizedPrehashSigner<sr25519::Signature> for PairWrapper {
}
}

pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult>(
pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult + 'static>(
Collaborator:
Why do we need to go with a 'static lifetime here instead of something shorter?

Contributor Author:
This comes from the error handling: the generic GenericProtocolError<Res> is used in a spawned task. It is probably possible to use some other error type in the spawned task and convert it to the generic error once it has been passed back out to the host task. Hopefully then we wouldn't need the 'static lifetime.

Contributor Author:
I've tried to remove the need for this 'static and can't come up with anything.

Having read up on it, I am with @JesseAbram on this - I think we should not worry.

The answer at https://stackoverflow.com/questions/66510485/when-is-a-static-lifetime-not-appropriate says that 'static means a value can live forever, not that it will. We can see from this code that the references in the spawned tasks are only going to be around for as long as messages are still being processed.
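
For context, a minimal illustrative sketch (not part of the PR) of why the bound appears: tokio::spawn requires the spawned future, and therefore everything it captures, to be Send + 'static, so a generic parameter whose values flow into the task needs the same bound.

```rust
use std::sync::Arc;

// tokio::spawn's signature is roughly:
//
//     pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
//     where
//         F: Future + Send + 'static,
//         F::Output: Send + 'static;
//
// so any type captured by the task must be 'static ("could live forever"),
// which is why `Res: ... + 'static` propagates onto execute_protocol_generic.
#[tokio::main]
async fn main() {
    let session = Arc::new(String::from("session state"));

    let for_task = session.clone();
    let handle = tokio::spawn(async move {
        // `for_task` is owned by the task, so the 'static bound is satisfied;
        // it is dropped as soon as the task finishes, not kept forever.
        for_task.len()
    });

    assert_eq!(handle.await.unwrap(), session.len());
}
```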

mut chans: Channels,
session: Session<Res, sr25519::Signature, PairWrapper, PartyId>,
session_id_hash: [u8; 32],
) -> Result<(Res::Success, Channels), GenericProtocolError<Res>> {
) -> Result<(Res::Success, Channels), GenericProtocolError<Res>>
where
<Res as synedrion::ProtocolResult>::ProvableError: std::marker::Send,
<Res as synedrion::ProtocolResult>::CorrectnessProof: std::marker::Send,
Comment on lines +79 to +80

Contributor:
Do you really need the full path here? Isn't Send enough?

Contributor Author:
Yes, you are right, we can just use Send.
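
For the record, Send here is just the prelude name for std::marker::Send, so the two spellings are interchangeable; applied to the signature above, the bounds would simply read Res::ProvableError: Send and Res::CorrectnessProof: Send. A trivial sketch, illustrative only:

```rust
// Both bounds name the same trait; the prelude brings `Send` into scope.
fn requires_send_short<T: Send>(_value: T) {}
fn requires_send_long<T: std::marker::Send>(_value: T) {}

fn main() {
    requires_send_short(42u32);
    requires_send_long(42u32);
}
```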

{
let session_id = synedrion::SessionId::from_seed(&session_id_hash);
let tx = &chans.0;
let rx = &mut chans.1;
@@ -85,64 +90,96 @@ pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult>(

loop {
let mut accum = session.make_accumulator();

// Send out messages
let destinations = session.message_destinations();
// TODO (#641): this can happen in a spawned task
for destination in destinations.iter() {
let (message, artifact) = session.make_message(&mut OsRng, destination)?;
tx.send(ProtocolMessage::new(&my_id, destination, message))?;

// This will happen in a host task
accum.add_artifact(artifact)?;
let current_round = session.current_round();
let session_arc = Arc::new(session);

// Send outgoing messages
let destinations = session_arc.message_destinations();
let join_handles = destinations.iter().map(|destination| {
let session_arc = session_arc.clone();
let tx = tx.clone();
let my_id = my_id.clone();
let destination = destination.clone();
tokio::spawn(async move {
session_arc
.make_message(&mut OsRng, &destination)
.map(|(message, artifact)| {
tx.send(ProtocolMessage::new(&my_id, &destination, message))
.map(|_| artifact)
.map_err(|err| {
let err: GenericProtocolError<Res> = err.into();
err
})
})
.map_err(|err| {
let err: GenericProtocolError<Res> = err.into();
err
})
})
});

for result in try_join_all(join_handles).await? {
accum.add_artifact(result??)?;
}

for preprocessed in cached_messages {
// TODO (#641): this may happen in a spawned task.
let processed = session.process_message(&mut OsRng, preprocessed)?;
// Process cached messages
let join_handles = cached_messages.into_iter().map(|preprocessed| {
let session_arc = session_arc.clone();
tokio::spawn(async move { session_arc.process_message(&mut OsRng, preprocessed) })
});

// This will happen in a host task.
accum.add_processed_message(processed)??;
for result in try_join_all(join_handles).await? {
Contributor:
This will abort all the process_message tasks and return immediately. Is this what we want here? It might be exactly what we want, but just checking it's what we expect.

Contributor Author:
You mean that if there is an error in one of the tasks, all will be aborted because of the ? on try_join_all? If so, yes, I think this is what we want: if we fail to process any message we assume we cannot recover and abort the protocol session.

Contributor:
Less because of the ? and more because of how try_join_all works: it aborts as soon as any of the tasks does.
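
To make the semantics being discussed concrete, here is a standalone sketch (not project code) of how try_join_all behaves over JoinHandles: the outer error that short-circuits it is a JoinError from a panicked or cancelled task, while application-level failures come back as the inner Results, matching the result?? pattern in the diff.

```rust
use futures::future::try_join_all;

#[tokio::main]
async fn main() {
    // Each JoinHandle resolves to Result<Result<u32, String>, JoinError>.
    let handles = (0u32..4).map(|i| {
        tokio::spawn(async move {
            if i == 2 {
                Err(format!("failed to process message {i}"))
            } else {
                Ok(i * 10)
            }
        })
    });

    match try_join_all(handles).await {
        // try_join_all only short-circuits on a JoinError (panic/cancellation);
        // it does not abort the other spawned tasks, it just stops waiting.
        Err(join_err) => eprintln!("a task panicked or was cancelled: {join_err}"),
        Ok(results) => {
            for result in results {
                match result {
                    Ok(artifact) => println!("processed: {artifact}"),
                    // An inner error is where the protocol run would bail out.
                    Err(err) => {
                        eprintln!("aborting session: {err}");
                        break;
                    }
                }
            }
        }
    }
}
```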

accum.add_processed_message(result?)??;
}

while !session.can_finalize(&accum)? {
let mut messages_for_later = VecDeque::new();
let (from, payload) = loop {
let message = rx.recv().await.ok_or_else(|| {
GenericProtocolError::<Res>::IncomingStream(format!(
"{:?}",
session.current_round()
))
})?;

if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
if payload.session_id() == &session_id {
break (message.from, *payload);
// Receive and process incoming messages
let (process_tx, mut process_rx) = mpsc::channel(1024);
while !session_arc.can_finalize(&accum)? {
tokio::select! {
// Incoming message from remote peer
maybe_message = rx.recv() => {
let message = maybe_message.ok_or_else(|| {
GenericProtocolError::IncomingStream(format!("{:?}", current_round))
})?;

if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
if payload.session_id() == &session_id {
// Perform quick checks before proceeding with the verification.
let preprocessed =
session_arc.preprocess_message(&mut accum, &message.from, *payload)?;

if let Some(preprocessed) = preprocessed {
let session_arc = session_arc.clone();
let tx = process_tx.clone();
tokio::spawn(async move {
let result = session_arc.process_message(&mut OsRng, preprocessed);
if tx.send(result).await.is_err() {
tracing::error!("Protocol finished before message processing result sent");
Contributor:
Hmm, isn't this a SendError, i.e. the rx side of the channel was closed, or the buffer is full, or some such shenanigans? The log message seems to assume the cause is that the protocol finished prematurely, but are we positive that is the case?

Contributor Author (@ameba23, Nov 4, 2024):
I am pretty sure that is the case.

If the channel is full there will be no error; the send will await until there is space for a message.

That only leaves the receiver being closed, due to either close() being called (which we never do) or the receiver being dropped.
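
A standalone sketch (not project code) confirming the behaviour described above for tokio's bounded mpsc channel: send only awaits when the buffer is full, and only returns SendError once the receiver has been closed or dropped.

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(1);

    // Fill the single buffer slot.
    tx.send(1).await.unwrap();

    // A send into a full channel does not error; it waits for capacity.
    let tx_clone = tx.clone();
    let pending_send = tokio::spawn(async move { tx_clone.send(2).await });
    assert_eq!(rx.recv().await, Some(1)); // frees a slot
    assert!(pending_send.await.unwrap().is_ok());

    // Only a closed/dropped receiver produces SendError.
    drop(rx);
    assert!(tx.send(3).await.is_err());
}
```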

}
});
}
} else {
tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
tx.incoming_sender.send(message).await?;
}
} else {
tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
messages_for_later.push_back(message);
tracing::warn!("Got verifying key during protocol - ignoring");
}
} else {
tracing::warn!("Got verifying key during protocol - ignoring");
}
};
// Put messages which were not for this session back onto the incoming message channel
for message in messages_for_later.into_iter() {
tx.incoming_sender.send(message).await?;
}
// Perform quick checks before proceeding with the verification.
let preprocessed = session.preprocess_message(&mut accum, &from, payload)?;

if let Some(preprocessed) = preprocessed {
// TODO (#641): this may happen in a spawned task.
let result = session.process_message(&mut OsRng, preprocessed)?;

// This will happen in a host task.
accum.add_processed_message(result)??;
// Result from processing a message
maybe_result = process_rx.recv() => {
if let Some(result) = maybe_result {
accum.add_processed_message(result?)??;
}
}
}
}

match session.finalize_round(&mut OsRng, accum)? {
// Get session back out of Arc
let session_inner =
Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?;
Contributor Author (@ameba23, Oct 30, 2024):
Here we assume there are no other references to the session at this point, which is always the case in the happy path, as the protocol only finishes when all messages are processed.

@fjarri I'm not sure whether there are some synedrion error cases where session.can_finalize() will return true before all received messages have been processed. If that is the case, instead of returning an error here we should wait until all other references to the session are dropped.

The alternative would be to put it in an Arc<RwLock<>> for the entire protocol (not just one round).

Member (@fjarri, Oct 30, 2024):
Not in the current synedrion. The current manul version can exit prematurely if so many invalid messages are already registered that it will be impossible to finalize (for the current protocols that means at least one invalid message). It will also be possible to finalize early in the future with support for nodes dropping out during the protocol (entropyxyz/manul#11).

The way I envision this API being used is that as soon as you get the definitive answer from can_finalize() (true now, or CanFinalize::Yes or Never in manul), you stop processing messages, wait for all the existing processing tasks to finish (there are no cancel points in them), and finalize the round (or terminate the session if the round can never be finalized, in manul).

Now for the waiting part: I know how I would organize it in Python, but I am not exactly sure for tokio - I haven't played with async much in Rust. I just assumed that if Python can do it, it would be possible somehow in Rust too; I hope I was correct :)

Collaborator:
What would the underlying cause of this error be, though? Not actually having processed all the messages?

Asking because I'm not sure how much help we're giving the caller by just saying that we failed to unwrap an Arc.

Contributor Author:
Agreed, we should add a better error message. But ideally I would only leave this in if I were pretty sure it was never going to happen. We only get to this point in the code when the loop has broken because can_finalize is true, which means the protocol round is finished. If there are still some references to the Session lying around in spawned tasks at that point, this will fail; but that would also mean some messages have not yet been processed, so the round should not be finished.

Since try_unwrap actually returns the Arc back in the error variant of the Result if it fails, it would be possible to loop around, waiting until those references get dropped as the tasks finish. But I think it's only worth doing that if it is actually possible to get into such a state.

Contributor:
+1 for an error type describing the problem, ArcUnwrapError.
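
If it ever turned out that the round could finish while processing tasks still hold clones, the "wait until the other references are dropped" option mentioned above could look roughly like this (a hypothetical helper, not part of the PR), relying on Arc::try_unwrap handing the Arc back in its Err variant:

```rust
use std::sync::Arc;
use std::time::Duration;

/// Hypothetical helper: retry `Arc::try_unwrap` until this is the last clone.
async fn unwrap_when_unique<T>(mut shared: Arc<T>) -> T {
    loop {
        match Arc::try_unwrap(shared) {
            // We held the only strong reference; take ownership of the value.
            Ok(inner) => return inner,
            // Another clone is still alive (e.g. a processing task that has
            // not finished yet); get the Arc back and retry shortly.
            Err(still_shared) => {
                shared = still_shared;
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let session = Arc::new(String::from("session"));
    let clone = session.clone();
    // Simulate a processing task that drops its clone after a short delay.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        drop(clone);
    });
    let owned: String = unwrap_when_unique(session).await;
    assert_eq!(owned, "session");
}
```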

match session_inner.finalize_round(&mut OsRng, accum)? {
FinalizeOutcome::Success(res) => break Ok((res, chans)),
FinalizeOutcome::AnotherRound {
session: new_session,
25 changes: 14 additions & 11 deletions crates/protocol/tests/protocol.rs
@@ -22,7 +22,7 @@ use futures::future;
use rand_core::OsRng;
use serial_test::serial;
use sp_core::{sr25519, Pair};
use std::time::Instant;
use std::{cmp::min, time::Instant};
use subxt::utils::AccountId32;
use synedrion::{ecdsa::VerifyingKey, AuxInfo, KeyShare, ThresholdKeyShare};
use tokio::{net::TcpListener, runtime::Runtime, sync::oneshot};
@@ -33,37 +33,40 @@ use helpers::{server, ProtocolOutput};

use std::collections::BTreeSet;

/// The maximum number of worker threads that tokio should use
const MAX_THREADS: usize = 16;
Contributor Author (@ameba23, Oct 30, 2024):
I was setting the number of tokio worker threads to num_cpus, but I had issues running with 32 worker threads on our 32-core box. The tokio documentation says "it is advised to keep this value on the smaller side".

Contributor:
What kind of problems did you run into? And why not pick a fraction of whatever num_cpus reports, like half or a quarter?

Contributor Author:
The problem was that the test just hangs and only finishes when the test harness timeout runs out. I guess we could pick half of num_cpus if num_cpus is greater than some limit. In CI I think we only have 2 CPUs, so there we definitely want both. Bear in mind this is only for tests and benchmarks; in production I think we use the tokio default, which, if I've understood correctly, is actually the number of CPUs: https://docs.rs/tokio/latest/tokio/attr.main.html#multi-threaded-runtime
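
For illustration, the kind of runtime construction the tests rely on could look like the sketch below. The real get_tokio_runtime lives in the test helpers, so its exact shape here is an assumption; the point is capping the worker-thread count with min(num_cpus::get(), MAX_THREADS) as the diff does.

```rust
use std::cmp::min;
use tokio::runtime::Runtime;

/// The maximum number of worker threads that tokio should use in tests.
const MAX_THREADS: usize = 16;

/// Build a multi-threaded runtime capped at MAX_THREADS workers
/// (hypothetical stand-in for the helpers' get_tokio_runtime).
fn get_capped_runtime() -> Runtime {
    let worker_threads = min(num_cpus::get(), MAX_THREADS);
    tokio::runtime::Builder::new_multi_thread()
        .worker_threads(worker_threads)
        .enable_all()
        .build()
        .expect("failed to build tokio runtime")
}

fn main() {
    let runtime = get_capped_runtime();
    runtime.block_on(async {
        println!("running with at most {} worker threads", MAX_THREADS);
    });
}
```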


#[test]
#[serial]
fn sign_protocol_with_time_logged() {
let cpus = num_cpus::get();
get_tokio_runtime(cpus).block_on(async {
test_sign_with_parties(cpus).await;
let num_parties = min(num_cpus::get(), MAX_THREADS);
get_tokio_runtime(num_parties).block_on(async {
test_sign_with_parties(num_parties).await;
})
}

#[test]
#[serial]
fn refresh_protocol_with_time_logged() {
let cpus = num_cpus::get();
get_tokio_runtime(cpus).block_on(async {
test_refresh_with_parties(cpus).await;
let num_parties = min(num_cpus::get(), MAX_THREADS);
get_tokio_runtime(num_parties).block_on(async {
test_refresh_with_parties(num_parties).await;
})
}

#[test]
#[serial]
fn dkg_protocol_with_time_logged() {
let cpus = num_cpus::get();
get_tokio_runtime(cpus).block_on(async {
test_dkg_with_parties(cpus).await;
let num_parties = min(num_cpus::get(), MAX_THREADS);
get_tokio_runtime(num_parties).block_on(async {
test_dkg_with_parties(num_parties).await;
})
}

#[test]
#[serial]
fn t_of_n_dkg_and_sign() {
let cpus = num_cpus::get();
let cpus = min(num_cpus::get(), MAX_THREADS);
// For this test we need at least 3 parties
let parties = 3;
get_tokio_runtime(cpus).block_on(async {