From d3807b27ad6c6259a848088bb9789535e2278814 Mon Sep 17 00:00:00 2001 From: Szymon Wojtulewicz Date: Mon, 11 Mar 2024 12:04:07 +0000 Subject: [PATCH] rewrite transaction waiter --- Cargo.toml | 1 - crates/account_sdk/Cargo.toml | 7 +- .../src/deploy_contract/pending.rs | 5 +- .../account_sdk/src/tests/webauthn/utils.rs | 1 + crates/account_sdk/src/transaction_waiter.rs | 180 +++++++----------- 5 files changed, 75 insertions(+), 119 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7afb42fc..aa6516f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ sha2 = "0.10" starknet = "0.8" starknet-crypto = "0.6" thiserror = "1" -tokio = { version = "1", features = ["macros"] } toml = "0.8" u256-literal = "1" url = "2" diff --git a/crates/account_sdk/Cargo.toml b/crates/account_sdk/Cargo.toml index 8f315ddd..01753ed1 100644 --- a/crates/account_sdk/Cargo.toml +++ b/crates/account_sdk/Cargo.toml @@ -25,9 +25,14 @@ sha2.workspace = true starknet.workspace = true starknet-crypto.workspace = true thiserror.workspace = true -tokio.workspace = true +tokio = { version = "1", features = ["macros", "time"]} +async-std = { version = "1.12.0"} toml.workspace = true u256-literal.workspace = true url.workspace = true wasm-bindgen.workspace = true webauthn-rs-proto.workspace = true + +[features] +tokio-runtime = [] +async-std-runtime = [] diff --git a/crates/account_sdk/src/deploy_contract/pending.rs b/crates/account_sdk/src/deploy_contract/pending.rs index f332247d..c66d2c19 100644 --- a/crates/account_sdk/src/deploy_contract/pending.rs +++ b/crates/account_sdk/src/deploy_contract/pending.rs @@ -9,7 +9,7 @@ use super::deployment::DeployResult; pub struct PendingTransaction<'a, P, T> where - &'a P: Provider + Send, + &'a P: Provider + Send + Sync, { transaction_result: T, transaction_hash: FieldElement, @@ -18,7 +18,7 @@ where impl<'a, P, T> PendingTransaction<'a, P, T> where - &'a P: Provider + Send, + &'a P: Provider + Send + Sync, { pub fn new(transaction_result: T, transaction_hash: FieldElement, client: &'a P) -> Self { PendingTransaction { @@ -29,6 +29,7 @@ where } pub async fn wait_for_completion(self) -> T { TransactionWaiter::new(self.transaction_hash, &self.client) + .wait() .await .unwrap(); self.transaction_result diff --git a/crates/account_sdk/src/tests/webauthn/utils.rs b/crates/account_sdk/src/tests/webauthn/utils.rs index e2b4cae4..de6a2180 100644 --- a/crates/account_sdk/src/tests/webauthn/utils.rs +++ b/crates/account_sdk/src/tests/webauthn/utils.rs @@ -98,6 +98,7 @@ where set_execution.send().await.unwrap(); TransactionWaiter::new(set_tx, self.runner.client()) + .wait() .await .unwrap(); } diff --git a/crates/account_sdk/src/transaction_waiter.rs b/crates/account_sdk/src/transaction_waiter.rs index fce853ef..6b0dc9ed 100644 --- a/crates/account_sdk/src/transaction_waiter.rs +++ b/crates/account_sdk/src/transaction_waiter.rs @@ -1,18 +1,12 @@ -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::{Duration, Instant}; +use std::time::Duration; -use futures::FutureExt; +use futures::{select, FutureExt}; use starknet::core::types::{ ExecutionResult, FieldElement, MaybePendingTransactionReceipt, PendingTransactionReceipt, StarknetError, TransactionFinalityStatus, TransactionReceipt, }; use starknet::providers::{Provider, ProviderError}; -type GetReceiptResult = Result; -type GetReceiptFuture<'a> = Pin + Send + 'a>>; - #[derive(Debug, thiserror::Error)] pub enum TransactionWaitingError { #[error("request timed out")] @@ -63,68 +57,56 @@ pub struct TransactionWaiter<'a, P: Provider> { must_succeed: bool, /// Poll the transaction every `interval` miliseconds. Miliseconds are used so that /// we can be more precise with the polling interval. Defaults to 250ms. - interval: Interval, + interval: Duration, /// The maximum amount of time to wait for the transaction to achieve the desired status. An /// error will be returned if it is unable to finish within the `timeout` duration. Defaults to /// 60 seconds. timeout: Duration, /// The provider to use for polling the transaction. provider: &'a P, - /// The future that gets the transaction receipt. - receipt_request_fut: Option>, - /// The time when the transaction waiter was first polled. - started_at: Option, } -struct Interval { - last: Instant, - interval: Duration, -} +enum Sleeper {} -impl Interval { - fn new(interval: Duration) -> Self { - Self { - last: Instant::now(), - interval, +impl Sleeper { + pub async fn sleep(interval: Duration) { + #[cfg(feature = "tokio-runtime")] + { + tokio::time::sleep(interval).await; } - } - - fn poll_tick(&mut self, _cx: &mut Context<'_>) -> Poll<()> { - if self.last.elapsed() > self.interval { - self.last = Instant::now(); - Poll::Ready(()) - } else { - std::thread::sleep(self.interval); - Poll::Pending + #[cfg(feature = "async-std-runtime")] + { + async_std::task::sleep(interval).await; + } + #[cfg(not(any(feature = "tokio-runtime", feature = "async-std-runtime")))] + { + compile_error!("At least one of the features 'tokio-runtime' or 'async-std-runtime' must be enabled"); } } - } #[allow(dead_code)] impl<'a, P> TransactionWaiter<'a, P> where - P: Provider + Send, + P: Provider + Send + Sync, { const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300); const DEFAULT_INTERVAL: Duration = Duration::from_millis(2500); pub fn new(tx: FieldElement, provider: &'a P) -> Self { Self { - provider, + provider: provider.into(), tx_hash: tx, - started_at: None, must_succeed: true, finality_status: None, - receipt_request_fut: None, timeout: Self::DEFAULT_TIMEOUT, - interval: Interval::new(Self::DEFAULT_INTERVAL), + interval: Self::DEFAULT_INTERVAL, } } pub fn with_interval(self, milisecond: u64) -> Self { Self { - interval: Interval::new(Duration::from_millis(milisecond)), + interval: Duration::from_millis(milisecond), ..self } } @@ -139,36 +121,44 @@ where pub fn with_timeout(self, timeout: Duration) -> Self { Self { timeout, ..self } } -} -impl<'a, P> Future for TransactionWaiter<'a, P> -where - P: Provider + Send, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); - - if this.started_at.is_none() { - this.started_at = Some(Instant::now()); + pub async fn wait(self) -> Result { + let timeout = self.timeout.clone(); + select! { + result = self.wait_without_timeout().fuse() => result, + _ = Sleeper::sleep(timeout).fuse() => Err(TransactionWaitingError::Timeout), } - + } + async fn wait_without_timeout( + self, + ) -> Result { loop { - if let Some(started_at) = this.started_at { - if started_at.elapsed() > this.timeout { - return Poll::Ready(Err(TransactionWaitingError::Timeout)); - } - } + let now = std::time::Instant::now(); + let transaction = self.provider.get_transaction_receipt(self.tx_hash).await; + match transaction { + Ok(receipt) => match &receipt { + MaybePendingTransactionReceipt::PendingReceipt(r) => { + if self.finality_status.is_none() { + if self.must_succeed { + return match execution_status_from_pending_receipt(r) { + ExecutionResult::Succeeded => Ok(receipt), + ExecutionResult::Reverted { reason } => { + Err(TransactionWaitingError::TransactionReverted( + reason.clone(), + )) + } + }; + } + return Ok(receipt); + } + } - if let Some(mut flush) = this.receipt_request_fut.take() { - match flush.poll_unpin(cx) { - Poll::Ready(res) => match res { - Ok(receipt) => match &receipt { - MaybePendingTransactionReceipt::PendingReceipt(r) => { - if this.finality_status.is_none() { - if this.must_succeed { - let res = match execution_status_from_pending_receipt(r) { + MaybePendingTransactionReceipt::Receipt(r) => { + if let Some(finality_status) = self.finality_status { + match finality_status_from_receipt(r) { + status if status == finality_status => { + if self.must_succeed { + return match execution_status_from_receipt(r) { ExecutionResult::Succeeded => Ok(receipt), ExecutionResult::Reverted { reason } => { Err(TransactionWaitingError::TransactionReverted( @@ -176,66 +166,26 @@ where )) } }; - return Poll::Ready(res); } - - return Poll::Ready(Ok(receipt)); + return Ok(receipt); } - } - MaybePendingTransactionReceipt::Receipt(r) => { - if let Some(finality_status) = this.finality_status { - match finality_status_from_receipt(r) { - status if status == finality_status => { - if this.must_succeed { - let res = match execution_status_from_receipt(r) { - ExecutionResult::Succeeded => Ok(receipt), - ExecutionResult::Reverted { reason } => { - Err(TransactionWaitingError::TransactionReverted( - reason.clone(), - )) - } - }; - return Poll::Ready(res); - } - - return Poll::Ready(Ok(receipt)); - } - - _ => {} - } - } else { - return Poll::Ready(Ok(receipt)); - } + _ => {} } - }, - - Err(ProviderError::StarknetError( - StarknetError::TransactionHashNotFound, - )) => {} - - Err(e) => { - return Poll::Ready(Err(TransactionWaitingError::Provider(e))); + } else { + return Ok(receipt); } - }, - - Poll::Pending => { - this.receipt_request_fut = Some(flush); - return Poll::Pending; } - } - } + }, + + Err(ProviderError::StarknetError(StarknetError::TransactionHashNotFound)) => {} - if this.interval.poll_tick(cx).is_ready() { - this.receipt_request_fut = Some(Box::pin( - this.provider.get_transaction_receipt(this.tx_hash), - )); - } else { - break; + Err(e) => { + return Err(TransactionWaitingError::Provider(e)); + } } + Sleeper::sleep(self.interval.checked_sub(now.elapsed()).unwrap_or_default()).await; } - - Poll::Pending } }