diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 948280f..3fa7461 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,19 +91,10 @@ jobs: with: cache-on-failure: "true" - - uses: taiki-e/install-action@v2 - with: - tool: nextest - - name: Forge build run: forge update && forge build - uses: taiki-e/github-actions/free-device-space@main - - name: Download Tangle Manual Sealing - run: | - wget https://github.com/tangle-network/tangle/releases/download/v1.2.3/tangle-testnet-manual-seal-linux-amd64 - chmod +x tangle-testnet-manual-seal-linux-amd64 - - name: tests - run: TANGLE_NODE=$(pwd)/tangle-testnet-manual-seal-linux-amd64 cargo nextest run \ No newline at end of file + run: cargo test \ No newline at end of file diff --git a/.gitignore b/.gitignore index bd954d6..143bda5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Generated by Cargo /target/ +dependencies # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html diff --git a/Cargo.toml b/Cargo.toml index cd4375b..42c003f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ thiserror = "2.0.3" itertools = "0.13.0" rand = "0.8.5" parking_lot = { version = "0.12.3", features = ["serde"]} +p256k1 = "5.4" +frost-taproot = { git = "https://github.com/webb-tools/tangle.git", branch = "main", default-features = false} # MPC specific deps wsts = "3.0.0" diff --git a/build.rs b/build.rs index dcbb32b..acb638b 100644 --- a/build.rs +++ b/build.rs @@ -1,7 +1,5 @@ fn main() { - println!("cargo:rerun-if-changed=src/cli"); println!("cargo:rerun-if-changed=src/lib.rs"); println!("cargo:rerun-if-changed=src/main.rs"); - println!("cargo:rerun-if-changed=src/*"); blueprint_metadata::generate_json(); } diff --git a/config b/config deleted file mode 100644 index b335fac..0000000 --- a/config +++ /dev/null @@ -1,18 +0,0 @@ -[core] - 
repositoryformatversion = 0 - filemode = true - bare = false - logallrefupdates = true - ignorecase = true - precomposeunicode = true -[branch "main"] - remote = origin - merge = refs/heads/main -[user] - email = tbraun96@gmail.com - name = Thomas Braun -[credential] - username = tbraun96 -[remote "origin"] - url = https://github.com/tangle-network/wsts-blueprint.git - fetch = +refs/heads/*:refs/remotes/origin/* diff --git a/contracts/.gitignore b/contracts/.gitignore index deea2d0..83cc846 100644 --- a/contracts/.gitignore +++ b/contracts/.gitignore @@ -11,4 +11,5 @@ out/ docs/ # Dotenv file -.env \ No newline at end of file +.env +./lib \ No newline at end of file diff --git a/contracts/src/BlsBlueprint.sol b/contracts/src/BlsBlueprint.sol deleted file mode 100644 index f57ded9..0000000 --- a/contracts/src/BlsBlueprint.sol +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: UNLICENSE -pragma solidity >=0.8.13; - -import "contracts/lib/tnt-core/src/BlueprintServiceManagerBase.sol"; - -/** - * @title WstsBlueprint - * @dev This contract is an example of a service blueprint that provides a single service. - * @dev For all supported hooks, check the `BlueprintServiceManagerBase` contract. - */ -contract WstsBlueprint is BlueprintServiceManagerBase { - /** - * @dev Hook for service operator registration. Called when a service operator - * attempts to register with the blueprint. - * @param operator The operator's details. - * @param _registrationInputs Inputs required for registration. - */ - function onRegister(bytes calldata operator, bytes calldata _registrationInputs) - public - payable - override - onlyFromRootChain - { - // Do something with the operator's details - } - - /** - * @dev Hook for service instance requests. Called when a user requests a service - * instance from the blueprint. - * @param serviceId The ID of the requested service. - * @param operators The operators involved in the service. 
- * @param _requestInputs Inputs required for the service request. - */ - function onRequest(uint64 serviceId, bytes[] calldata operators, bytes calldata _requestInputs) - public - payable - override - onlyFromRootChain - { - // Do something with the service request - } - - /** - * @dev Hook for handling job call results. Called when operators send the result - * of a job execution. - * @param serviceId The ID of the service related to the job. - * @param job The job identifier. - * @param _jobCallId The unique ID for the job call. - * @param participant The participant (operator) sending the result. - * @param _inputs Inputs used for the job execution. - * @param _outputs Outputs resulting from the job execution. - */ - function onJobResult( - uint64 serviceId, - uint8 job, - uint64 _jobCallId, - bytes calldata participant, - bytes calldata _inputs, - bytes calldata _outputs - ) public payable virtual override onlyFromRootChain { - // Do something with the job call result - } - - /** - * @dev Converts a public key to an operator address. - * @param publicKey The public key to convert. - * @return operator address The operator address. - */ - function operatorAddressFromPublicKey(bytes calldata publicKey) internal pure returns (address operator) { - return address(uint160(uint256(keccak256(publicKey)))); - } -} diff --git a/contracts/src/WstsBlueprint.sol b/contracts/src/WstsBlueprint.sol new file mode 100644 index 0000000..4c1e185 --- /dev/null +++ b/contracts/src/WstsBlueprint.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: UNLICENSE +pragma solidity >=0.8.13; + +import "dependencies/tnt-core-0.1.0/src/BlueprintServiceManagerBase.sol"; + +/** + * @title WstsBlueprint + * @dev This contract is an example of a service blueprint that provides a single service. + * @dev For all supported hooks, check the `BlueprintServiceManagerBase` contract. 
+ */ +contract WstsBlueprint is BlueprintServiceManagerBase {} diff --git a/foundry.toml b/foundry.toml index fbaf7f1..dd00769 100644 --- a/foundry.toml +++ b/foundry.toml @@ -5,9 +5,11 @@ out = "contracts/out" script = "contracts/script" cache_path = "contracts/cache" broadcast = "contracts/broadcast" -libs = ["contracts/lib", "dependencies"] +libs = ["dependencies"] auto_detect_remappings = true [dependencies] +forge-std = "1.9.4" +tnt-core = "0.1.0" # See more config options https://github.com/foundry-rs/foundry/blob/master/crates/config/README.md#all-options diff --git a/keygen.rs b/keygen.rs deleted file mode 100644 index 782995a..0000000 --- a/keygen.rs +++ /dev/null @@ -1,401 +0,0 @@ -use crate::protocols::util::{ - generate_party_key_ids, validate_parameters, FrostMessage, FrostState, -}; -use frost_taproot::VerifyingKey; -use futures::{SinkExt, StreamExt}; -use gadget_common::client::ClientWithApi; -use gadget_common::config::Network; -use gadget_common::gadget::message::UserID; -use gadget_common::gadget::JobInitMetadata; -use gadget_common::keystore::KeystoreBackend; -use gadget_common::prelude::{DebugLogger, GadgetProtocolMessage, WorkManager}; -use gadget_common::prelude::{ECDSAKeyStore, JobError}; -use gadget_common::tangle_runtime::*; -use gadget_common::utils::recover_ecdsa_pub_key; -use gadget_common::{ - BuiltExecutableJobWrapper, JobBuilder, ProtocolWorkManager, WorkManagerInterface, -}; -use hashbrown::HashMap; -use itertools::Itertools; -use rand::{CryptoRng, RngCore}; -use sp_core::{ecdsa, keccak_256, ByteArray, Pair}; -use std::sync::Arc; -use tokio::sync::mpsc::UnboundedReceiver; -use tokio::sync::Mutex; -use wsts::common::PolyCommitment; -use wsts::v2::Party; -use wsts::Scalar; - -pub const K: u32 = 1; - -#[derive(Clone)] -pub struct WstsKeygenExtraParams { - job_id: u64, - n: u32, - i: u32, - k: u32, - t: u32, - user_id_mapping: Arc>, - my_id: ecdsa::Public, -} - -pub async fn create_next_job( - config: &crate::WstsKeygenProtocol, - job: 
JobInitMetadata, - _work_manager: &ProtocolWorkManager, -) -> Result { - if let jobs::JobType::DKGTSSPhaseOne(p1_job) = job.job_type { - let participants = job.participants_role_ids.clone(); - let user_id_to_account_id_mapping = Arc::new( - participants - .clone() - .into_iter() - .enumerate() - .map(|r| (r.0 as UserID, r.1)) - .collect(), - ); - - let i = p1_job - .participants - .0 - .iter() - .position(|p| p.0 == config.account_id.0) - .expect("Should exist") as u16; - - let t = p1_job.threshold; - let n = p1_job.participants.0.len() as u32; - - Ok(WstsKeygenExtraParams { - job_id: job.job_id, - n, - i: i as _, - k: n, // Each party will own exactly n keys for this protocol - t: t as _, - user_id_mapping: user_id_to_account_id_mapping, - my_id: config.key_store.pair().public(), - }) - } else { - Err(gadget_common::Error::ClientError { - err: "The supplied job is not a phase 1 job".to_string(), - }) - } -} - -pub async fn generate_protocol_from( - config: &crate::WstsKeygenProtocol, - associated_block_id: ::Clock, - associated_retry_id: ::RetryID, - associated_session_id: ::SessionID, - associated_task_id: ::TaskID, - protocol_message_channel: UnboundedReceiver, - additional_params: WstsKeygenExtraParams, -) -> Result { - let result = Arc::new(Mutex::new(None)); - let result_clone = result.clone(); - let logger = config.logger.clone(); - let logger_clone = logger.clone(); - let network = config.clone(); - let keystore = config.key_store.clone(); - let keystore_clone = keystore.clone(); - let client = config.pallet_tx.clone(); - let WstsKeygenExtraParams { - job_id, - n, - i, - k, - t, - user_id_mapping, - my_id, - } = additional_params; - - let participants = user_id_mapping - .keys() - .copied() - .map(|r| r as u8) - .collect::>(); - - Ok(JobBuilder::new() - .protocol(async move { - let (tx0, rx0, tx1, rx1) = - gadget_common::channels::create_job_manager_to_async_protocol_channel_split::< - _, - FrostMessage, - FrostMessage, - >( - protocol_message_channel, - 
associated_block_id, - associated_retry_id, - associated_session_id, - associated_task_id, - user_id_mapping, - my_id, - network, - logger.clone(), - ); - let output = protocol(n, i, k, t, tx0, rx0, tx1, rx1, &logger, &keystore_clone).await?; - result.lock().await.replace(output); - - Ok(()) - }) - .post(async move { - if let Some((state, signatures)) = result_clone.lock().await.take() { - keystore - .set_job_result(job_id, &state) - .await - .map_err(|err| JobError { - reason: err.to_string(), - })?; - - let job_result_for_pallet = - jobs::JobResult::DKGPhaseOne(DKGTSSKeySubmissionResult { - signature_scheme: DigitalSignatureScheme::SchnorrSecp256k1, - key: BoundedVec(state.public_key_frost_format), - participants: BoundedVec(vec![BoundedVec(participants)]), - signatures: BoundedVec(signatures), - threshold: t as _, - chain_code: None, - __ignore: Default::default(), - }); - - client - .submit_job_result( - RoleType::Tss(roles::tss::ThresholdSignatureRoleType::WstsV2), - job_id, - job_result_for_pallet, - ) - .await - .map_err(|err| JobError { - reason: err.to_string(), - })?; - } - - logger_clone.info("Finished AsyncProtocol - WSTS Keygen"); - Ok(()) - }) - .build()) -} - -/// `party_id`: Should be in the range [0, n). For the DKG, should be our index in the best -/// authorities starting from 0. -/// -/// Returns the state of the party after the protocol has finished. 
This should be saved to the keystore and -/// later used for signing -#[allow(clippy::too_many_arguments)] -pub async fn protocol( - n: u32, - party_id: u32, - k: u32, - t: u32, - tx_to_network: futures::channel::mpsc::UnboundedSender, - rx_from_network: futures::channel::mpsc::UnboundedReceiver>, - tx_to_network_broadcast: tokio::sync::mpsc::UnboundedSender, - mut rx_from_network_broadcast: UnboundedReceiver, - logger: &DebugLogger, - key_store: &ECDSAKeyStore, -) -> Result<(FrostState, Vec>), JobError> { - validate_parameters(n, k, t)?; - - let mut rng = rand::rngs::OsRng; - let key_ids = generate_party_key_ids(n, k); - let our_key_ids = key_ids.get(party_id as usize).ok_or_else(|| JobError { - reason: "Bad party_id".to_string(), - })?; - - let mut party = Party::new(party_id, our_key_ids, n, k, t, &mut rng); - let public_key = run_dkg( - &mut party, - &mut rng, - n as usize, - tx_to_network, - rx_from_network, - logger, - ) - .await?; - - let party = party.save(); - logger.debug(format!("Combined public key: {:?}", party.group_key)); - - // Convert the WSTS group key into a FROST-compatible format - let group_point = party.group_key; - let compressed_group_point = group_point.compress(); - let verifying_key = - VerifyingKey::deserialize(compressed_group_point.data).map_err(|e| JobError { - reason: format!("Failed to convert group key to VerifyingKey: {e}"), - })?; - let public_key_frost_format = verifying_key.serialize().as_ref().to_vec(); - - // Sign this public key using our ECDSA key - let hash_of_public_key = keccak_256(&public_key_frost_format); - let signature_of_public_key = key_store.pair().sign_prehashed(&hash_of_public_key); - - // Gossip the public key - let pkey_message = FrostMessage::PublicKeyBroadcast { - party_id, - combined_public_key: public_key_frost_format.clone(), - signature_of_public_key: signature_of_public_key.clone(), - }; - - // Gossip the public key - tx_to_network_broadcast - .send(pkey_message) - .map_err(|err| JobError { - reason: 
format!("Error sending FROST message: {err:?}"), - })?; - - let mut received = 0; - let mut received_signatures = HashMap::new(); - received_signatures.insert(party_id, signature_of_public_key.clone()); - - // We normally need t+1, however, we aren't verifying our own signature, thus collect t keys - while received < t { - let next_message = rx_from_network_broadcast - .recv() - .await - .ok_or_else(|| JobError { - reason: "broadcast stream died".to_string(), - })?; - match next_message { - FrostMessage::PublicKeyBroadcast { - party_id, - combined_public_key, - signature_of_public_key, - } => { - // Make sure their public key is equivalent to ours - if combined_public_key.as_slice() != public_key_frost_format.as_slice() { - return Err(JobError { reason: format!("The received public key from party {party_id} does not match our public key. Aborting") }); - } - - // Verify the public key signature - recover_ecdsa_pub_key(&public_key_frost_format, signature_of_public_key.as_slice()) - .map_err(|_| JobError { - reason: format!("Failed to verify signature from party {party_id}"), - })?; - - received_signatures.insert(party_id, signature_of_public_key); - received += 1; - } - - message => { - logger.warn(format!("Received improper message: {message:?}")); - } - } - } - - let signatures = received_signatures - .into_iter() - .sorted_by_key(|k| k.0) - .map(|r| BoundedVec(r.1 .0.to_vec())) - .collect_vec(); - - logger.info("Finished public key gossip"); - - let frost_state = FrostState { - public_key, - public_key_frost_format, - party: Arc::new(party), - }; - - Ok((frost_state, signatures)) -} - -pub async fn run_dkg( - signer: &mut Party, - rng: &mut RNG, - n_signers: usize, - mut tx_to_network: futures::channel::mpsc::UnboundedSender, - mut rx_from_network: futures::channel::mpsc::UnboundedReceiver>, - logger: &DebugLogger, -) -> Result, JobError> { - // Broadcast our party_id, shares, and key_ids to each other - let party_id = signer.party_id; - let shares: HashMap = 
signer.get_shares().into_iter().collect(); - let key_ids = signer.key_ids.clone(); - logger.info(format!( - "Our party ID: {party_id} | Our key IDS: {key_ids:?}" - )); - let poly_commitment = signer.get_poly_commitment(rng); - - let message = FrostMessage::Keygen { - party_id, - shares: shares.clone(), - key_ids: key_ids.clone(), - poly_commitment: poly_commitment.clone(), - }; - - // Send the message - tx_to_network.send(message).await.map_err(|err| JobError { - reason: format!("Error sending FROST message: {err:?}"), - })?; - - let mut received_shares = HashMap::new(); - let mut received_key_ids = HashMap::new(); - let mut received_poly_commitments = HashMap::new(); - // insert our own shared into the received map - received_shares.insert(party_id, shares); - received_key_ids.insert(party_id, key_ids); - received_poly_commitments.insert(party_id, poly_commitment); - - // Wait for n_signers to send their messages to us - while received_shares.len() < n_signers { - match rx_from_network.next().await { - Some(Ok(FrostMessage::Keygen { - party_id, - shares, - key_ids, - poly_commitment, - })) => { - if party_id != signer.party_id { - logger.trace(format!( - "Received shares from {party_id} with key ids: {key_ids:?}" - )); - received_shares.insert(party_id, shares); - received_key_ids.insert(party_id, key_ids); - received_poly_commitments.insert(party_id, poly_commitment); - } - } - - Some(evt) => logger.warn(format!("Received unexpected FROST event: {evt:?}")), - - None => { - return Err(JobError { - reason: "NetListen connection died".to_string(), - }) - } - } - } - - logger.trace(format!( - "Received shares: {:?}", - received_shares.keys().collect::>() - )); - // Generate the party_shares: for each key id we own, we take our received key share at that - // index - let party_shares = signer - .key_ids - .iter() - .copied() - .map(|key_id| { - let mut key_shares = HashMap::new(); - - for (id, shares) in &received_shares { - key_shares.insert(*id, shares[&key_id]); - 
} - - (key_id, key_shares.into_iter().collect()) - }) - .collect(); - let polys = received_poly_commitments - .iter() - .sorted_by(|a, b| a.0.cmp(b.0)) - .map(|r| r.1.clone()) - .collect_vec(); - - signer - .compute_secret(&party_shares, &polys) - .map_err(|err| JobError { - reason: err.to_string(), - })?; - - logger.info("Keygen finished computing secret"); - Ok(received_poly_commitments) -} diff --git a/remappings.txt b/remappings.txt index e69de29..61021c7 100644 --- a/remappings.txt +++ b/remappings.txt @@ -0,0 +1,2 @@ +forge-std-1.9.4/=dependencies/forge-std-1.9.4/src/ +tnt-core=dependencies/tnt-core-0.1.0/src/ diff --git a/soldeer.lock b/soldeer.lock deleted file mode 100644 index 4177f1a..0000000 --- a/soldeer.lock +++ /dev/null @@ -1 +0,0 @@ -dependencies = [] diff --git a/src/context.rs b/src/context.rs index 4639dae..77c6848 100644 --- a/src/context.rs +++ b/src/context.rs @@ -20,6 +20,8 @@ const NETWORK_PROTOCOL: &str = "/wsts/frost/1.0.0"; pub struct WstsContext { #[config] pub config: sdk::config::StdGadgetConfiguration, + #[call_id] + pub call_id: Option, pub network_backend: Arc, pub store: Arc>, pub identity: ecdsa::Pair, @@ -47,6 +49,7 @@ impl WstsContext { Ok(Self { store, + call_id: None, identity, config, network_backend: Arc::new(NetworkMultiplexer::new(gossip_handle)), diff --git a/src/keygen.rs b/src/keygen.rs index 3f99db8..a0f290b 100644 --- a/src/keygen.rs +++ b/src/keygen.rs @@ -10,7 +10,7 @@ use gadget_sdk::{ job, network::round_based_compat::NetworkDeliveryWrapper, tangle_subxt::tangle_testnet_runtime::api::services::events::JobCalled, - ByteBuf, Error as GadgetError, + Error as GadgetError, }; use sp_core::ecdsa::Public; use std::collections::BTreeMap; @@ -18,7 +18,7 @@ use wsts::v2::Party; #[job( id = 0, - params(n), + params(t), event_listener( listener = TangleEventListener, pre_processor = services_pre_processor, @@ -28,7 +28,7 @@ use wsts::v2::Party; /// Runs a distributed key generation (DKG) process using the WSTS protocol /// 
/// # Arguments -/// * `n` - Number of parties participating in the DKG +/// * `t` - The threshold for the DKG /// * `context` - The DFNS context containing network and storage configuration /// /// # Returns @@ -40,8 +40,7 @@ use wsts::v2::Party; /// - Failed to get party information /// - MPC protocol execution failed /// - Serialization of results failed -pub async fn keygen(n: u16, context: WstsContext) -> Result { - let t = n - 1; +pub async fn keygen(t: u16, context: WstsContext) -> Result, GadgetError> { // Get configuration and compute deterministic values let blueprint_id = context .blueprint_id() @@ -51,9 +50,6 @@ pub async fn keygen(n: u16, context: WstsContext) -> Result Result Result, pub n_signers: usize, pub party: Arc>>, + pub public_key_frost_format: Vec, } impl WstsState { @@ -129,6 +131,15 @@ where .map_err(|err| KeygenError::MpcError(err.to_string()))?; let party = signer.save(); + + // Convert the WSTS group key into a FROST-compatible format + let group_point = party.group_key; + let compressed_group_point = group_point.compress(); + let verifying_key = VerifyingKey::deserialize(compressed_group_point.data).map_err(|e| { + KeygenError::MpcError(format!("Failed to convert group key to VerifyingKey: {e}")) + })?; + let public_key_frost_format = verifying_key.serialize().as_ref().to_vec(); + state.public_key_frost_format = public_key_frost_format; state.party = Arc::new(parking_lot::Mutex::new(Some(party))); gadget_sdk::info!("Keygen finished computing secret"); diff --git a/src/lib.rs b/src/lib.rs index 5e2b943..0408be5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ pub mod context; pub mod keygen; pub(crate) mod keygen_state_machine; +pub mod signing; +pub(crate) mod signing_state_machine; pub(crate) mod utils; const META_SALT: &str = "wsts-protocol"; diff --git a/src/main.rs b/src/main.rs index 78a84cc..b639e14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,9 +18,11 @@ async fn main() { let tangle_config = TangleConfig::default(); 
let keygen = wsts_blueprint::keygen::KeygenEventHandler::new(&env, context.clone()).await?; + let signing = wsts_blueprint::signing::SignEventHandler::new(&env, context.clone()).await?; BlueprintRunner::new(tangle_config, env.clone()) .job(keygen) + .job(signing) .run() .await?; diff --git a/src/signing.rs b/src/signing.rs new file mode 100644 index 0000000..2c0068e --- /dev/null +++ b/src/signing.rs @@ -0,0 +1,143 @@ +use std::collections::BTreeMap; + +use crate::context::WstsContext; +use gadget_sdk::contexts::MPCContext; +use gadget_sdk::{ + event_listener::tangle::{ + jobs::{services_post_processor, services_pre_processor}, + TangleEventListener, + }, + job, + network::round_based_compat::NetworkDeliveryWrapper, + tangle_subxt::tangle_testnet_runtime::api::services::events::JobCalled, + Error as GadgetError, +}; +use sp_core::ecdsa::Public; + +/// Configuration constants for the WSTS signing process +const SIGNING_SALT: &str = "wsts-signing"; + +#[job( + id = 1, + params(keygen_call_id, message), + event_listener( + listener = TangleEventListener, + pre_processor = services_pre_processor, + post_processor = services_post_processor, + ), +)] +/// Signs a message using the WSTS protocol with a previously generated key +/// +/// # Arguments +/// * `message` - The message to sign as a byte vector +/// * `context` - The DFNS context containing network and storage configuration +/// +/// # Returns +/// Returns the signature as a byte vector on success +/// +/// # Errors +/// Returns an error if: +/// - Failed to retrieve blueprint ID or call ID +/// - Failed to retrieve the key entry +/// - Signing process failed +pub async fn sign( + keygen_call_id: u64, + message: Vec, + context: WstsContext, +) -> Result, GadgetError> { + // let message = message.into_bytes(); + // Get configuration and compute deterministic values + let blueprint_id = context + .blueprint_id() + .map_err(|e| SigningError::ContextError(e.to_string()))?; + + let call_id = context + 
.current_call_id() + .await + .map_err(|e| SigningError::ContextError(e.to_string()))?; + + // Setup party information + let (i, operators) = context + .get_party_index_and_operators() + .await + .map_err(|e| SigningError::ContextError(e.to_string()))?; + + let parties: BTreeMap = operators + .into_iter() + .enumerate() + .map(|(j, (_, ecdsa))| (j as u16, ecdsa)) + .collect(); + + let n = parties.len() as u16; + let i = i as u16; + + // Compute hash for key retrieval. Must use the call_id of the keygen job + let (meta_hash, deterministic_hash) = + crate::compute_deterministic_hashes(n, blueprint_id, keygen_call_id, SIGNING_SALT); + + // Retrieve the key entry + let store_key = hex::encode(meta_hash); + let state = context + .store + .get(&store_key) + .ok_or_else(|| SigningError::ContextError("Key entry not found".to_string()))?; + + gadget_sdk::info!( + "Starting WSTS Signing for party {i}, n={n}, eid={}", + hex::encode(deterministic_hash) + ); + + let network = NetworkDeliveryWrapper::new( + context.network_backend.clone(), + i, + deterministic_hash, + parties.clone(), + ); + + let mut rng = rand::rngs::OsRng; + + let network = round_based::party::MpcParty::connected(network); + + let output = + crate::signing_state_machine::wsts_signing_protocol(network, &state, message, &mut rng) + .await?; + + let signature_frost_format = output.signature_frost_format.clone(); + Ok(signature_frost_format) +} + +#[derive(Debug, thiserror::Error)] +pub enum SigningError { + #[error("Failed to serialize data: {0}")] + SerializationError(String), + + #[error("MPC protocol error: {0}")] + MpcError(String), + + #[error("Context error: {0}")] + ContextError(String), + + #[error("Delivery error: {0}")] + DeliveryError(String), + + #[error("Invalid public key")] + InvalidPublicKey, + + #[error("Invalid signature")] + InvalidSignature, + + #[error("Invalid FROST signature")] + InvalidFrostSignature, + + #[error("Invalid FROST verifying key")] + InvalidFrostVerifyingKey, + + 
#[error("Invalid FROST verification")] + InvalidFrostVerification, +} + +impl From for gadget_sdk::Error { + fn from(err: SigningError) -> Self { + gadget_sdk::Error::Other(err.to_string()) + } +} diff --git a/src/signing_state_machine.rs b/src/signing_state_machine.rs new file mode 100644 index 0000000..941bff5 --- /dev/null +++ b/src/signing_state_machine.rs @@ -0,0 +1,310 @@ +use rand::{CryptoRng, RngCore}; +use round_based::{ + rounds_router::{simple_store::RoundInput, RoundsRouter}, + Delivery, MessageDestination, Mpc, MpcParty, ProtocolMessage, +}; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::keygen_state_machine::{HasRecipient, WstsState}; +use crate::signing::SigningError; +use frost_taproot::{Ciphersuite, Secp256K1Taproot, VerifyingKey}; +use itertools::Itertools; +use p256k1::point::Point; +use p256k1::scalar::Scalar; +use round_based::SinkExt; +use serde::{Deserialize, Serialize}; +use wsts::common::Signature; +use wsts::v2::Party; +use wsts::{ + common::{PublicNonce, SignatureShare}, + v2::{PartyState, SignatureAggregator}, +}; + +#[derive(Default, Serialize, Deserialize, Clone)] +pub struct WstsSigningState { + pub party_id: u32, + pub party_key_ids: HashMap>, + pub party_nonces: HashMap, + pub signature_shares: HashMap, + pub n_signers: usize, + pub threshold: u32, + pub message: Vec, + pub public_key_frost_format: Vec, + pub party: Arc>>, + pub aggregated_signature: Option>, + pub signature_frost_format: Vec, +} + +#[derive(Serialize, Deserialize, Default, Clone)] +#[allow(non_snake_case)] +pub struct SerializeableSignature { + pub R: Point, + /// The sum of the party signatures + pub z: Scalar, +} + +impl From for SerializeableSignature { + fn from(sig: Signature) -> Self { + SerializeableSignature { R: sig.R, z: sig.z } + } +} + +impl From for Signature { + fn from(sig: SerializeableSignature) -> Self { + Signature { R: sig.R, z: sig.z } + } +} + +impl WstsSigningState { + pub fn new( + party_id: u32, + n_signers: usize, + 
threshold: u32, + message: Vec, + public_key_frost_format: Vec, + ) -> Self { + WstsSigningState { + party_id, + n_signers, + threshold, + message, + public_key_frost_format, + ..Default::default() + } + } +} + +#[derive(ProtocolMessage, Serialize, Deserialize, Clone)] +#[allow(clippy::large_enum_variant)] +pub enum Msg { + Round1(Round1Msg), + Round2(Round2Msg), +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct Round1Msg { + source: u32, + key_ids: Vec, + nonce: PublicNonce, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct Round2Msg { + source: u32, + signature_share: SignatureShare, +} + +pub async fn wsts_signing_protocol( + network: M, + keygen_state: &WstsState, + message: Vec, + rng: &mut R, +) -> Result +where + M: Mpc, +{ + let (mut signer, threshold) = { + let lock = keygen_state.party.lock(); + let state = lock + .as_ref() + .ok_or_else(|| SigningError::ContextError("Party not found".to_string()))?; + let threshold = state.threshold; + let signer = Party::load(state); + drop(lock); + (signer, threshold) + }; + + let n_signers = keygen_state.n_signers; + let MpcParty { delivery, .. 
} = network.into_party(); + let (incomings, mut outgoings) = delivery.split(); + let mut state = WstsSigningState::new( + signer.party_id, + n_signers, + threshold, + message.clone(), + keygen_state.public_key_frost_format.clone(), + ); + + let mut rounds = RoundsRouter::builder(); + let round1 = rounds.add_round(RoundInput::::broadcast( + state.party_id as _, + n_signers as _, + )); + let round2 = rounds.add_round(RoundInput::::broadcast( + state.party_id as _, + n_signers as _, + )); + let mut rounds = rounds.listen(incomings); + + // Round 1: Generate and broadcast nonce + let nonce = signer.gen_nonce(rng); + let key_ids = signer.key_ids.clone(); + + let my_round1 = Round1Msg { + source: signer.party_id, + key_ids: key_ids.clone(), + nonce: nonce.clone(), + }; + + let msg = Msg::Round1(my_round1.clone()); + send_message::(msg, &mut outgoings).await?; + + let round1_msgs = rounds + .complete(round1) + .await + .map_err(|err| SigningError::MpcError(err.to_string()))?; + + let round1_msgs: HashMap = round1_msgs + .into_iter_including_me(my_round1) + .map(|r| (r.source, r)) + .collect(); + + // Process round 1 messages + for (party_id, msg) in round1_msgs { + state.party_key_ids.insert(party_id, msg.key_ids); + state.party_nonces.insert(party_id, msg.nonce); + } + + // Sort and prepare for signing + let party_ids = state + .party_key_ids + .keys() + .copied() + .sorted_by(|a, b| a.cmp(b)) + .collect_vec(); + let party_key_ids = state + .party_key_ids + .clone() + .into_iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .flat_map(|r| r.1) + .collect_vec(); + let party_nonces = state + .party_nonces + .clone() + .into_iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .map(|r| r.1) + .collect_vec(); + + // Round 2: Generate and broadcast signature share + let signature_share = signer.sign(&message, &party_ids, &party_key_ids, &party_nonces); + + let my_round2 = Round2Msg { + source: signer.party_id, + signature_share: signature_share.clone(), + }; + + let msg = 
Msg::Round2(my_round2.clone()); + send_message::(msg, &mut outgoings).await?; + + let round2_msgs = rounds + .complete(round2) + .await + .map_err(|err| SigningError::MpcError(err.to_string()))?; + + let round2_msgs: HashMap = round2_msgs + .into_iter_including_me(my_round2) + .map(|r| (r.source, r)) + .collect(); + + // Process round 2 messages + for (party_id, msg) in round2_msgs { + state.signature_shares.insert(party_id, msg.signature_share); + } + + // Sort signature shares and aggregate + let signature_shares = state + .signature_shares + .clone() + .into_iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .map(|r| r.1) + .collect_vec(); + + let public_key_comm = keygen_state + .poly_commitments + .iter() + .sorted_by(|r1, r2| r1.0.cmp(r2.0)) + .map(|r| r.1.clone()) + .collect_vec(); + + // Create signature aggregator + let mut sig_agg = + SignatureAggregator::new(state.n_signers as u32, state.threshold, public_key_comm) + .map_err(|err| SigningError::MpcError(err.to_string()))?; + + // Generate final signature + let wsts_sig = sig_agg + .sign(&message, &party_nonces, &signature_shares, &party_key_ids) + .map_err(|err| SigningError::MpcError(err.to_string()))?; + + // Verify WSTS signature + let compressed_public_key = + p256k1::point::Compressed::try_from(state.public_key_frost_format.as_slice()) + .map_err(|_| SigningError::InvalidPublicKey)?; + + let wsts_public_key = p256k1::point::Point::try_from(&compressed_public_key) + .map_err(|_| SigningError::InvalidPublicKey)?; + + if !wsts_sig.verify(&wsts_public_key, &message) { + return Err(SigningError::InvalidSignature); + } + + // Convert to FROST format and verify + let mut signature_bytes = [0u8; 33 + 32]; + let r = wsts_sig.R.compress(); + signature_bytes[0..33].copy_from_slice(&r.data); + signature_bytes[33..].copy_from_slice(&wsts_sig.z.to_bytes()); + + state.signature_frost_format = signature_bytes.to_vec(); + + let frost_signature = frost_taproot::Signature::deserialize(signature_bytes) + .map_err(|_| 
SigningError::InvalidFrostSignature)?; + + let frost_verifying_key = + VerifyingKey::deserialize(state.public_key_frost_format.clone().try_into().unwrap()) + .map_err(|_| SigningError::InvalidFrostVerifyingKey)?; + + if !frost_signature.is_valid() { + return Err(SigningError::InvalidFrostSignature); + } + + frost_verifying_key + .verify(&message, &frost_signature) + .map_err(|_| SigningError::InvalidFrostVerification)?; + + Secp256K1Taproot::verify_signature(&message, &frost_signature, &frost_verifying_key) + .map_err(|_| SigningError::InvalidFrostVerification)?; + + state.party = Arc::new(parking_lot::Mutex::new(Some(signer.save()))); + state.aggregated_signature = Some(Arc::new(wsts_sig.into())); + + Ok(state) +} + +impl HasRecipient for Msg { + fn recipient(&self) -> MessageDestination { + match self { + Msg::Round1(_) | Msg::Round2(_) => MessageDestination::AllParties, + } + } +} + +pub async fn send_message( + msg: Msg, + tx: &mut <::Delivery as Delivery>::Send, +) -> Result<(), SigningError> +where + Msg: HasRecipient, + M: Mpc, +{ + let recipient = msg.recipient(); + let msg = round_based::Outgoing { recipient, msg }; + tx.send(msg) + .await + .map_err(|e| SigningError::DeliveryError(e.to_string()))?; + + Ok(()) +} diff --git a/src/utils.rs index e1b5820..8526042 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -2,7 +2,7 @@ use gadget_sdk::Error; pub fn validate_parameters(n: u32, k: u32, t: u32) -> Result<(), Error> { if k % n != 0 { return Err(Error::Job { - reason: "K % N != 0".to_string(), + reason: format!("K({k}) % N({n}) != 0"), }); } diff --git a/tests/wsts.rs index a408b40..01e6aec 100644 --- a/tests/wsts.rs +++ b/tests/wsts.rs @@ -1,30 +1,145 @@ #[cfg(test)] mod e2e { - use std::sync::atomic::AtomicU64; - + use ::blueprint_test_utils::test_ext::new_test_ext_blueprint_manager; + pub use ::blueprint_test_utils::{ + run_test_blueprint_manager, setup_log, submit_job, wait_for_completion_of_tangle_job, Job, + }; + use 
blueprint_test_utils::tangle::NodeConfig; use blueprint_test_utils::*; use wsts_blueprint::keygen::KEYGEN_JOB_ID; + use wsts_blueprint::signing::SIGN_JOB_ID; + const N: usize = 3; const T: usize = 2; - // The macro takes this variable as an argument, and will update it so that - // when we pass the signing arguments, we can pass the associated keygen call id - static KEYGEN_CALL_ID: AtomicU64 = AtomicU64::new(0); - - mpc_generate_keygen_and_signing_tests!( - "./", - N, - T, - KEYGEN_JOB_ID, - [InputValue::Uint16(N as _)], - [], - KEYGEN_JOB_ID, - [ - InputValue::Uint16(N as _), - InputValue::Uint64(KEYGEN_CALL_ID.load(std::sync::atomic::Ordering::SeqCst)), - InputValue::Bytes(BoundedVec(vec![1, 2, 3])) - ], - [], - KEYGEN_CALL_ID, - ); + #[tokio::test(flavor = "multi_thread")] + async fn test_blueprint() { + setup_log(); + + let tmp_dir = ::blueprint_test_utils::tempfile::TempDir::new().unwrap(); + let tmp_dir_path = format!("{}", tmp_dir.path().display()); + + new_test_ext_blueprint_manager::( + tmp_dir_path, + run_test_blueprint_manager, + NodeConfig::new(false), + ) + .await + .execute_with_async(|client, handles, blueprint, _| async move { + let keypair = handles[0].sr25519_id().clone(); + let service = &blueprint.services[KEYGEN_JOB_ID as usize]; + + let service_id = service.id; + gadget_sdk::info!( + "Submitting KEYGEN job {KEYGEN_JOB_ID} with service ID {service_id}", + ); + + let job_args = vec![(InputValue::Uint16(T as u16))]; + let call_id = get_next_call_id(client) + .await + .expect("Failed to get next job id") + .saturating_sub(1); + let job = submit_job( + client, + &keypair, + service_id, + Job::from(KEYGEN_JOB_ID), + job_args, + call_id, + ) + .await + .expect("Failed to submit job"); + + let keygen_call_id = job.call_id; + + gadget_sdk::info!( + "Submitted KEYGEN job {} with service ID {service_id} has call id {keygen_call_id}", + KEYGEN_JOB_ID + ); + + let job_results = wait_for_completion_of_tangle_job(client, service_id, keygen_call_id, T) + 
.await + .expect("Failed to wait for job completion"); + + assert_eq!(job_results.service_id, service_id); + assert_eq!(job_results.call_id, keygen_call_id); + + let expected_outputs = vec![]; + if !expected_outputs.is_empty() { + assert_eq!( + job_results.result.len(), + expected_outputs.len(), + "Number of keygen outputs doesn't match expected" + ); + + for (result, expected) in job_results + .result + .into_iter() + .zip(expected_outputs.into_iter()) + { + assert_eq!(result, expected); + } + } else { + gadget_sdk::info!("No expected outputs specified, skipping keygen verification"); + } + + gadget_sdk::info!("Keygen job completed successfully! Moving on to signing ..."); + + let service = &blueprint.services[0]; + let service_id = service.id; + gadget_sdk::info!( + "Submitting SIGNING job {} with service ID {service_id}", + SIGN_JOB_ID + ); + + let job_args = vec![ + InputValue::Uint64(keygen_call_id), + InputValue::List(BoundedVec(vec![ + InputValue::Uint8(1), + InputValue::Uint8(2), + InputValue::Uint8(3), + ])), + ]; + + let job = submit_job( + client, + &keypair, + service_id, + Job::from(SIGN_JOB_ID), + job_args, + call_id + 1, + ) + .await + .expect("Failed to submit job"); + + let signing_call_id = job.call_id; + gadget_sdk::info!( + "Submitted SIGNING job {SIGN_JOB_ID} with service ID {service_id} has call id {signing_call_id}", + ); + + let job_results = wait_for_completion_of_tangle_job(client, service_id, signing_call_id, T) + .await + .expect("Failed to wait for job completion"); + + let expected_outputs = vec![]; + if !expected_outputs.is_empty() { + assert_eq!( + job_results.result.len(), + expected_outputs.len(), + "Number of signing outputs doesn't match expected" + ); + + for (result, expected) in job_results + .result + .into_iter() + .zip(expected_outputs.into_iter()) + { + assert_eq!(result, expected); + } + } else { + gadget_sdk::info!("No expected outputs specified, skipping signing verification"); + } + }) + .await + } }