diff --git a/crates/protocol/src/errors.rs b/crates/protocol/src/errors.rs
index b76ec1d59..6555a05e4 100644
--- a/crates/protocol/src/errors.rs
+++ b/crates/protocol/src/errors.rs
@@ -31,6 +31,10 @@ pub enum GenericProtocolError<Res: ProtocolResult> {
     Broadcast(#[from] Box<tokio::sync::broadcast::error::SendError<ProtocolMessage>>),
     #[error("Mpsc send error: {0}")]
     Mpsc(#[from] tokio::sync::mpsc::error::SendError<ProtocolMessage>),
+    #[error("Could not get session out of Arc - session has finalized before message processing finished")]
+    ArcUnwrapError,
+    #[error("Message processing task panic or cancellation: {0}")]
+    JoinHandle(#[from] tokio::task::JoinError),
 }
 
 impl<Res: ProtocolResult> From<synedrion::sessions::Error<Res>> for GenericProtocolError<Res> {
@@ -61,6 +65,8 @@ impl From<GenericProtocolError<InteractiveSigningResult<KeyParams, PartyId>>>
             GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
             GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
             GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
+            GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
+            GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
         }
     }
 }
@@ -73,6 +79,8 @@ impl From<GenericProtocolError<KeyInitResult<KeyParams, PartyId>>> for ProtocolExecutionErr
             GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
             GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
             GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
+            GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
+            GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
         }
     }
 }
@@ -85,6 +93,8 @@ impl From<GenericProtocolError<KeyResharingResult<KeyParams, PartyId>>> for ProtocolExecutionErr
             GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
             GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
             GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
+            GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
+            GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
         }
     }
 }
@@ -97,6 +107,8 @@ impl From<GenericProtocolError<AuxGenResult<KeyParams, PartyId>>> for ProtocolExecutionErr
             GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
             GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
             GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
+            GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
+            GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
         }
     }
 }
@@ -136,6 +148,10 @@ pub enum ProtocolExecutionErr {
     BadVerifyingKey(String),
     #[error("Expected verifying key but got a protocol message")]
     UnexpectedMessage,
+    #[error("Could not get session out of Arc")]
+    ArcUnwrapError,
+    #[error("Message processing task panic or cancellation: {0}")]
+    JoinHandle(#[from] tokio::task::JoinError),
 }
 
 #[derive(Debug, Error)]
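Reviewer note: the two new variants close the gap between task failure and protocol failure. `ArcUnwrapError` reports a session that is still shared when a round tries to finalize, while `JoinHandle` uses thiserror's `#[from]` so a panicked or cancelled task can be propagated with `?`. A minimal sketch of that conversion, assuming only the `tokio` and `thiserror` crates (`TaskError` and `double` are illustrative names, not the crate's):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum TaskError {
    #[error("Message processing task panic or cancellation: {0}")]
    JoinHandle(#[from] tokio::task::JoinError),
}

// Awaiting a JoinHandle yields Result<T, JoinError>; thanks to #[from],
// `?` converts the JoinError into our error type automatically.
async fn double(x: u32) -> Result<u32, TaskError> {
    let doubled = tokio::spawn(async move { x * 2 }).await?;
    Ok(doubled)
}

#[tokio::main]
async fn main() {
    assert_eq!(double(21).await.unwrap(), 42);
}
```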
diff --git a/crates/protocol/src/execute_protocol.rs b/crates/protocol/src/execute_protocol.rs
index c7df58fe0..835f3c673 100644
--- a/crates/protocol/src/execute_protocol.rs
+++ b/crates/protocol/src/execute_protocol.rs
@@ -15,10 +15,11 @@
 //! A wrapper for the threshold signing library to handle sending and receiving messages.
 
+use futures::future::try_join_all;
 use num::bigint::BigUint;
 use rand_core::{CryptoRngCore, OsRng};
 use sp_core::{sr25519, Pair};
-use std::collections::VecDeque;
+use std::sync::Arc;
 use subxt::utils::AccountId32;
 use synedrion::{
     ecdsa::VerifyingKey,
@@ -69,11 +70,15 @@ impl RandomizedPrehashSigner<sr25519::Signature> for PairWrapper {
     }
 }
 
-pub async fn execute_protocol_generic<Res: ProtocolResult>(
+pub async fn execute_protocol_generic<Res: ProtocolResult + 'static>(
     mut chans: Channels,
     session: Session<Res, sr25519::Signature, PairWrapper, PartyId>,
     session_id_hash: [u8; 32],
-) -> Result<(Res::Success, Channels), GenericProtocolError<Res>> {
+) -> Result<(Res::Success, Channels), GenericProtocolError<Res>>
+where
+    <Res as ProtocolResult>::ProvableError: std::marker::Send,
+    <Res as ProtocolResult>::CorrectnessProof: std::marker::Send,
+{
     let session_id = synedrion::SessionId::from_seed(&session_id_hash);
     let tx = &chans.0;
     let rx = &mut chans.1;
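The signature change is what makes the spawning below legal: anything moved into `tokio::spawn` must be `Send + 'static`, and that requirement propagates to the protocol result's associated error and proof types, hence the new `where` bounds. A small sketch of the underlying rule, assuming nothing beyond tokio itself:

```rust
// A value can only be moved into a spawned task if it is Send (the task may
// run on another worker thread) and 'static (it may outlive the caller's
// stack frame).
async fn run_in_task<T: Send + 'static>(value: T) -> Result<T, tokio::task::JoinError> {
    tokio::spawn(async move { value }).await
}

#[tokio::main]
async fn main() {
    assert_eq!(run_in_task(7u32).await.unwrap(), 7);
}
```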
@@ -85,64 +90,96 @@ pub async fn execute_protocol_generic<Res: ProtocolResult>(
     loop {
         let mut accum = session.make_accumulator();
-
-        // Send out messages
-        let destinations = session.message_destinations();
-        // TODO (#641): this can happen in a spawned task
-        for destination in destinations.iter() {
-            let (message, artifact) = session.make_message(&mut OsRng, destination)?;
-            tx.send(ProtocolMessage::new(&my_id, destination, message))?;
-
-            // This will happen in a host task
-            accum.add_artifact(artifact)?;
+        let current_round = session.current_round();
+        let session_arc = Arc::new(session);
+
+        // Send outgoing messages
+        let destinations = session_arc.message_destinations();
+        let join_handles = destinations.iter().map(|destination| {
+            let session_arc = session_arc.clone();
+            let tx = tx.clone();
+            let my_id = my_id.clone();
+            let destination = destination.clone();
+            tokio::spawn(async move {
+                session_arc
+                    .make_message(&mut OsRng, &destination)
+                    .map(|(message, artifact)| {
+                        tx.send(ProtocolMessage::new(&my_id, &destination, message))
+                            .map(|_| artifact)
+                            .map_err(|err| {
+                                let err: GenericProtocolError<Res> = err.into();
+                                err
+                            })
+                    })
+                    .map_err(|err| {
+                        let err: GenericProtocolError<Res> = err.into();
+                        err
+                    })
+            })
+        });
+
+        for result in try_join_all(join_handles).await? {
+            accum.add_artifact(result??)?;
         }
 
-        for preprocessed in cached_messages {
-            // TODO (#641): this may happen in a spawned task.
-            let processed = session.process_message(&mut OsRng, preprocessed)?;
+        // Process cached messages
+        let join_handles = cached_messages.into_iter().map(|preprocessed| {
+            let session_arc = session_arc.clone();
+            tokio::spawn(async move { session_arc.process_message(&mut OsRng, preprocessed) })
+        });
 
-            // This will happen in a host task.
-            accum.add_processed_message(processed)??;
+        for result in try_join_all(join_handles).await? {
+            accum.add_processed_message(result?)??;
         }
 
-        while !session.can_finalize(&accum)? {
-            let mut messages_for_later = VecDeque::new();
-            let (from, payload) = loop {
-                let message = rx.recv().await.ok_or_else(|| {
-                    GenericProtocolError::<Res>::IncomingStream(format!(
-                        "{:?}",
-                        session.current_round()
-                    ))
-                })?;
-
-                if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
-                    if payload.session_id() == &session_id {
-                        break (message.from, *payload);
+        // Receive and process incoming messages
+        let (process_tx, mut process_rx) = mpsc::channel(1024);
+        while !session_arc.can_finalize(&accum)? {
+            tokio::select! {
+                // Incoming message from remote peer
+                maybe_message = rx.recv() => {
+                    let message = maybe_message.ok_or_else(|| {
+                        GenericProtocolError::IncomingStream(format!("{:?}", current_round))
+                    })?;
+
+                    if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
+                        if payload.session_id() == &session_id {
+                            // Perform quick checks before proceeding with the verification.
+                            let preprocessed =
+                                session_arc.preprocess_message(&mut accum, &message.from, *payload)?;
+
+                            if let Some(preprocessed) = preprocessed {
+                                let session_arc = session_arc.clone();
+                                let tx = process_tx.clone();
+                                tokio::spawn(async move {
+                                    let result = session_arc.process_message(&mut OsRng, preprocessed);
+                                    if tx.send(result).await.is_err() {
+                                        tracing::error!("Protocol finished before message processing result sent");
+                                    }
+                                });
+                            }
+                        } else {
+                            tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
+                            tx.incoming_sender.send(message).await?;
+                        }
                     } else {
-                        tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
-                        messages_for_later.push_back(message);
+                        tracing::warn!("Got verifying key during protocol - ignoring");
                     }
-                } else {
-                    tracing::warn!("Got verifying key during protocol - ignoring");
                 }
-            };
-            // Put messages which were not for this session back onto the incoming message channel
-            for message in messages_for_later.into_iter() {
-                tx.incoming_sender.send(message).await?;
-            }
-            // Perform quick checks before proceeding with the verification.
-            let preprocessed = session.preprocess_message(&mut accum, &from, payload)?;
-            if let Some(preprocessed) = preprocessed {
-                // TODO (#641): this may happen in a spawned task.
-                let result = session.process_message(&mut OsRng, preprocessed)?;
-
-                // This will happen in a host task.
-                accum.add_processed_message(result)??;
+                // Result from processing a message
+                maybe_result = process_rx.recv() => {
+                    if let Some(result) = maybe_result {
+                        accum.add_processed_message(result?)??;
+                    }
+                }
             }
         }
 
-        match session.finalize_round(&mut OsRng, accum)? {
+        // Get session back out of Arc
+        let session_inner =
+            Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?;
+        match session_inner.finalize_round(&mut OsRng, accum)? {
             FinalizeOutcome::Success(res) => break Ok((res, chans)),
             FinalizeOutcome::AnotherRound {
                 session: new_session,
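Taken together, each round now follows a fan-out/fan-in shape: wrap the session in an `Arc`, hand clones to spawned tasks, collect results with `try_join_all` (or via the `mpsc` channel inside the select loop), then reclaim exclusive ownership for `finalize_round`. A self-contained sketch of that shape, using only std, tokio and futures; the `Session` type here is a stand-in, not synedrion's API:

```rust
use futures::future::try_join_all;
use std::sync::Arc;

struct Session {
    round: u32,
}

impl Session {
    fn make_message(&self, destination: u32) -> String {
        format!("round {} -> {destination}", self.round)
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let session = Arc::new(Session { round: 1 });

    // Fan out: each task owns a clone of the Arc'd session.
    let handles = (0..4u32).map(|destination| {
        let session = session.clone();
        tokio::spawn(async move { session.make_message(destination) })
    });

    // Fan in: `?` surfaces a panicked or cancelled task as a JoinError.
    for message in try_join_all(handles).await? {
        println!("{message}");
    }

    // All clones are gone, so ownership can be reclaimed for finalization.
    let session = Arc::try_unwrap(session).map_err(|_| "session still shared")?;
    assert_eq!(session.round, 1);
    Ok(())
}
```

The final `Arc::try_unwrap` is also why the new `ArcUnwrapError` variant exists: it can only fail if some task still holds a clone of the session when the round tries to finalize.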
diff --git a/crates/protocol/tests/protocol.rs b/crates/protocol/tests/protocol.rs
index 73c19d814..19e04adaa 100644
--- a/crates/protocol/tests/protocol.rs
+++ b/crates/protocol/tests/protocol.rs
@@ -22,7 +22,7 @@ use futures::future;
 use rand_core::OsRng;
 use serial_test::serial;
 use sp_core::{sr25519, Pair};
-use std::time::Instant;
+use std::{cmp::min, time::Instant};
 use subxt::utils::AccountId32;
 use synedrion::{ecdsa::VerifyingKey, AuxInfo, KeyShare, ThresholdKeyShare};
 use tokio::{net::TcpListener, runtime::Runtime, sync::oneshot};
@@ -33,37 +33,40 @@ use helpers::{server, ProtocolOutput};
 
 use std::collections::BTreeSet;
 
+/// The maximum number of worker threads that tokio should use
+const MAX_THREADS: usize = 16;
+
 #[test]
 #[serial]
 fn sign_protocol_with_time_logged() {
-    let cpus = num_cpus::get();
-    get_tokio_runtime(cpus).block_on(async {
-        test_sign_with_parties(cpus).await;
+    let num_parties = min(num_cpus::get(), MAX_THREADS);
+    get_tokio_runtime(num_parties).block_on(async {
+        test_sign_with_parties(num_parties).await;
     })
 }
 
 #[test]
 #[serial]
 fn refresh_protocol_with_time_logged() {
-    let cpus = num_cpus::get();
-    get_tokio_runtime(cpus).block_on(async {
-        test_refresh_with_parties(cpus).await;
+    let num_parties = min(num_cpus::get(), MAX_THREADS);
+    get_tokio_runtime(num_parties).block_on(async {
+        test_refresh_with_parties(num_parties).await;
     })
 }
 
 #[test]
 #[serial]
 fn dkg_protocol_with_time_logged() {
-    let cpus = num_cpus::get();
-    get_tokio_runtime(cpus).block_on(async {
-        test_dkg_with_parties(cpus).await;
+    let num_parties = min(num_cpus::get(), MAX_THREADS);
+    get_tokio_runtime(num_parties).block_on(async {
+        test_dkg_with_parties(num_parties).await;
     })
 }
 
 #[test]
 #[serial]
 fn t_of_n_dkg_and_sign() {
-    let cpus = num_cpus::get();
+    let cpus = min(num_cpus::get(), MAX_THREADS);
     // For this test we need at least 3 parties
     let parties = 3;
     get_tokio_runtime(cpus).block_on(async {
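On the test changes: capping the party count keeps runs bounded on many-core CI machines while still scaling down on small ones. A sketch of the capping logic, assuming the `num_cpus` crate; `get_tokio_runtime` already exists in the test helpers and is only reconstructed here for illustration:

```rust
use std::cmp::min;

/// The maximum number of worker threads that tokio should use
const MAX_THREADS: usize = 16;

// Hypothetical reconstruction of the helper the tests call.
fn get_tokio_runtime(num_parties: usize) -> tokio::runtime::Runtime {
    tokio::runtime::Builder::new_multi_thread()
        .worker_threads(num_parties)
        .enable_all()
        .build()
        .expect("Failed to build tokio runtime")
}

fn main() {
    // One worker thread per party, but never more than MAX_THREADS.
    let num_parties = min(num_cpus::get(), MAX_THREADS);
    get_tokio_runtime(num_parties).block_on(async {
        println!("running with {num_parties} parties");
    });
}
```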