Skip to content

Commit

Permalink
Concurrently process messages in protocol execution loop (#1136)
Browse files Browse the repository at this point in the history
* Re-apply fjarris suggestions

* Use RwLock rather than mutex

* Dont use a mutex

* Error handling

* Error handling

* Tidy

* Also concurrently process outgoing messages

* Only send outgoing messages if there are messages to send

* Error handling

* Dont use more than 16 worker threads

* Use join handle for outgoing messages

* Tidy
  • Loading branch information
ameba23 authored Nov 1, 2024
1 parent f01b719 commit 4528bf8
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 60 deletions.
16 changes: 16 additions & 0 deletions crates/protocol/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pub enum GenericProtocolError<Res: ProtocolResult> {
Broadcast(#[from] Box<tokio::sync::broadcast::error::SendError<ProtocolMessage>>),
#[error("Mpsc send error: {0}")]
Mpsc(#[from] tokio::sync::mpsc::error::SendError<ProtocolMessage>),
#[error("Could not get session out of Arc - session has finalized before message processing finished")]
ArcUnwrapError,
#[error("Message processing task panic or cancellation: {0}")]
JoinHandle(#[from] tokio::task::JoinError),
}

impl<Res: ProtocolResult> From<sessions::LocalError> for GenericProtocolError<Res> {
Expand Down Expand Up @@ -61,6 +65,8 @@ impl From<GenericProtocolError<InteractiveSigningResult<KeyParams, PartyId>>>
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
Expand All @@ -73,6 +79,8 @@ impl From<GenericProtocolError<KeyInitResult<KeyParams, PartyId>>> for ProtocolE
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
Expand All @@ -85,6 +93,8 @@ impl From<GenericProtocolError<KeyResharingResult<KeyParams, PartyId>>> for Prot
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
Expand All @@ -97,6 +107,8 @@ impl From<GenericProtocolError<AuxGenResult<KeyParams, PartyId>>> for ProtocolEx
GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err),
GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err),
GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err),
GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError,
GenericProtocolError::JoinHandle(err) => ProtocolExecutionErr::JoinHandle(err),
}
}
}
Expand Down Expand Up @@ -136,6 +148,10 @@ pub enum ProtocolExecutionErr {
BadVerifyingKey(String),
#[error("Expected verifying key but got a protocol message")]
UnexpectedMessage,
#[error("Could not get session out of Arc")]
ArcUnwrapError,
#[error("Message processing task panic or cancellation: {0}")]
JoinHandle(#[from] tokio::task::JoinError),
}

#[derive(Debug, Error)]
Expand Down
135 changes: 86 additions & 49 deletions crates/protocol/src/execute_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

//! A wrapper for the threshold signing library to handle sending and receiving messages.
use futures::future::try_join_all;
use num::bigint::BigUint;
use rand_core::{CryptoRngCore, OsRng};
use sp_core::{sr25519, Pair};
use std::collections::VecDeque;
use std::sync::Arc;
use subxt::utils::AccountId32;
use synedrion::{
ecdsa::VerifyingKey,
Expand Down Expand Up @@ -69,11 +70,15 @@ impl RandomizedPrehashSigner<sr25519::Signature> for PairWrapper {
}
}

pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult>(
pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult + 'static>(
mut chans: Channels,
session: Session<Res, sr25519::Signature, PairWrapper, PartyId>,
session_id_hash: [u8; 32],
) -> Result<(Res::Success, Channels), GenericProtocolError<Res>> {
) -> Result<(Res::Success, Channels), GenericProtocolError<Res>>
where
<Res as synedrion::ProtocolResult>::ProvableError: std::marker::Send,
<Res as synedrion::ProtocolResult>::CorrectnessProof: std::marker::Send,
{
let session_id = synedrion::SessionId::from_seed(&session_id_hash);
let tx = &chans.0;
let rx = &mut chans.1;
Expand All @@ -85,64 +90,96 @@ pub async fn execute_protocol_generic<Res: synedrion::ProtocolResult>(

loop {
let mut accum = session.make_accumulator();

// Send out messages
let destinations = session.message_destinations();
// TODO (#641): this can happen in a spawned task
for destination in destinations.iter() {
let (message, artifact) = session.make_message(&mut OsRng, destination)?;
tx.send(ProtocolMessage::new(&my_id, destination, message))?;

// This will happen in a host task
accum.add_artifact(artifact)?;
let current_round = session.current_round();
let session_arc = Arc::new(session);

// Send outgoing messages
let destinations = session_arc.message_destinations();
let join_handles = destinations.iter().map(|destination| {
let session_arc = session_arc.clone();
let tx = tx.clone();
let my_id = my_id.clone();
let destination = destination.clone();
tokio::spawn(async move {
session_arc
.make_message(&mut OsRng, &destination)
.map(|(message, artifact)| {
tx.send(ProtocolMessage::new(&my_id, &destination, message))
.map(|_| artifact)
.map_err(|err| {
let err: GenericProtocolError<Res> = err.into();
err
})
})
.map_err(|err| {
let err: GenericProtocolError<Res> = err.into();
err
})
})
});

for result in try_join_all(join_handles).await? {
accum.add_artifact(result??)?;
}

for preprocessed in cached_messages {
// TODO (#641): this may happen in a spawned task.
let processed = session.process_message(&mut OsRng, preprocessed)?;
// Process cached messages
let join_handles = cached_messages.into_iter().map(|preprocessed| {
let session_arc = session_arc.clone();
tokio::spawn(async move { session_arc.process_message(&mut OsRng, preprocessed) })
});

// This will happen in a host task.
accum.add_processed_message(processed)??;
for result in try_join_all(join_handles).await? {
accum.add_processed_message(result?)??;
}

while !session.can_finalize(&accum)? {
let mut messages_for_later = VecDeque::new();
let (from, payload) = loop {
let message = rx.recv().await.ok_or_else(|| {
GenericProtocolError::<Res>::IncomingStream(format!(
"{:?}",
session.current_round()
))
})?;

if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
if payload.session_id() == &session_id {
break (message.from, *payload);
// Receive and process incoming messages
let (process_tx, mut process_rx) = mpsc::channel(1024);
while !session_arc.can_finalize(&accum)? {
tokio::select! {
// Incoming message from remote peer
maybe_message = rx.recv() => {
let message = maybe_message.ok_or_else(|| {
GenericProtocolError::IncomingStream(format!("{:?}", current_round))
})?;

if let ProtocolMessagePayload::MessageBundle(payload) = message.payload.clone() {
if payload.session_id() == &session_id {
// Perform quick checks before proceeding with the verification.
let preprocessed =
session_arc.preprocess_message(&mut accum, &message.from, *payload)?;

if let Some(preprocessed) = preprocessed {
let session_arc = session_arc.clone();
let tx = process_tx.clone();
tokio::spawn(async move {
let result = session_arc.process_message(&mut OsRng, preprocessed);
if tx.send(result).await.is_err() {
tracing::error!("Protocol finished before message processing result sent");
}
});
}
} else {
tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
tx.incoming_sender.send(message).await?;
}
} else {
tracing::warn!("Got protocol message with incorrect session ID - putting back in queue");
messages_for_later.push_back(message);
tracing::warn!("Got verifying key during protocol - ignoring");
}
} else {
tracing::warn!("Got verifying key during protocol - ignoring");
}
};
// Put messages which were not for this session back onto the incoming message channel
for message in messages_for_later.into_iter() {
tx.incoming_sender.send(message).await?;
}
// Perform quick checks before proceeding with the verification.
let preprocessed = session.preprocess_message(&mut accum, &from, payload)?;

if let Some(preprocessed) = preprocessed {
// TODO (#641): this may happen in a spawned task.
let result = session.process_message(&mut OsRng, preprocessed)?;

// This will happen in a host task.
accum.add_processed_message(result)??;
// Result from processing a message
maybe_result = process_rx.recv() => {
if let Some(result) = maybe_result {
accum.add_processed_message(result?)??;
}
}
}
}

match session.finalize_round(&mut OsRng, accum)? {
// Get session back out of Arc
let session_inner =
Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?;
match session_inner.finalize_round(&mut OsRng, accum)? {
FinalizeOutcome::Success(res) => break Ok((res, chans)),
FinalizeOutcome::AnotherRound {
session: new_session,
Expand Down
25 changes: 14 additions & 11 deletions crates/protocol/tests/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use futures::future;
use rand_core::OsRng;
use serial_test::serial;
use sp_core::{sr25519, Pair};
use std::time::Instant;
use std::{cmp::min, time::Instant};
use subxt::utils::AccountId32;
use synedrion::{ecdsa::VerifyingKey, AuxInfo, KeyShare, ThresholdKeyShare};
use tokio::{net::TcpListener, runtime::Runtime, sync::oneshot};
Expand All @@ -33,37 +33,40 @@ use helpers::{server, ProtocolOutput};

use std::collections::BTreeSet;

/// The maximum number of worker threads that tokio should use
const MAX_THREADS: usize = 16;

#[test]
#[serial]
fn sign_protocol_with_time_logged() {
let cpus = num_cpus::get();
get_tokio_runtime(cpus).block_on(async {
test_sign_with_parties(cpus).await;
let num_parties = min(num_cpus::get(), MAX_THREADS);
get_tokio_runtime(num_parties).block_on(async {
test_sign_with_parties(num_parties).await;
})
}

#[test]
#[serial]
fn refresh_protocol_with_time_logged() {
let cpus = num_cpus::get();
get_tokio_runtime(cpus).block_on(async {
test_refresh_with_parties(cpus).await;
let num_parties = min(num_cpus::get(), MAX_THREADS);
get_tokio_runtime(num_parties).block_on(async {
test_refresh_with_parties(num_parties).await;
})
}

#[test]
#[serial]
fn dkg_protocol_with_time_logged() {
let cpus = num_cpus::get();
get_tokio_runtime(cpus).block_on(async {
test_dkg_with_parties(cpus).await;
let num_parties = min(num_cpus::get(), MAX_THREADS);
get_tokio_runtime(num_parties).block_on(async {
test_dkg_with_parties(num_parties).await;
})
}

#[test]
#[serial]
fn t_of_n_dkg_and_sign() {
let cpus = num_cpus::get();
let cpus = min(num_cpus::get(), MAX_THREADS);
// For this test we need at least 3 parties
let parties = 3;
get_tokio_runtime(cpus).block_on(async {
Expand Down

0 comments on commit 4528bf8

Please sign in to comment.