Skip to content

Commit

Permalink
rewrite transaction waiter
Browse files Browse the repository at this point in the history
  • Loading branch information
piniom committed Mar 11, 2024
1 parent a9f7adf commit d3807b2
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 119 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ sha2 = "0.10"
starknet = "0.8"
starknet-crypto = "0.6"
thiserror = "1"
tokio = { version = "1", features = ["macros"] }
toml = "0.8"
u256-literal = "1"
url = "2"
Expand Down
7 changes: 6 additions & 1 deletion crates/account_sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ sha2.workspace = true
starknet.workspace = true
starknet-crypto.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio = { version = "1", features = ["macros", "time"]}
async-std = { version = "1.12.0"}
toml.workspace = true
u256-literal.workspace = true
url.workspace = true
wasm-bindgen.workspace = true
webauthn-rs-proto.workspace = true

[features]
tokio-runtime = []
async-std-runtime = []
5 changes: 3 additions & 2 deletions crates/account_sdk/src/deploy_contract/pending.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::deployment::DeployResult;

pub struct PendingTransaction<'a, P, T>
where
&'a P: Provider + Send,
&'a P: Provider + Send + Sync,
{
transaction_result: T,
transaction_hash: FieldElement,
Expand All @@ -18,7 +18,7 @@ where

impl<'a, P, T> PendingTransaction<'a, P, T>
where
&'a P: Provider + Send,
&'a P: Provider + Send + Sync,
{
pub fn new(transaction_result: T, transaction_hash: FieldElement, client: &'a P) -> Self {
PendingTransaction {
Expand All @@ -29,6 +29,7 @@ where
}
pub async fn wait_for_completion(self) -> T {
TransactionWaiter::new(self.transaction_hash, &self.client)
.wait()
.await
.unwrap();
self.transaction_result
Expand Down
1 change: 1 addition & 0 deletions crates/account_sdk/src/tests/webauthn/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ where
set_execution.send().await.unwrap();

TransactionWaiter::new(set_tx, self.runner.client())
.wait()
.await
.unwrap();
}
Expand Down
180 changes: 65 additions & 115 deletions crates/account_sdk/src/transaction_waiter.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use std::time::Duration;

use futures::FutureExt;
use futures::{select, FutureExt};
use starknet::core::types::{
ExecutionResult, FieldElement, MaybePendingTransactionReceipt, PendingTransactionReceipt,
StarknetError, TransactionFinalityStatus, TransactionReceipt,
};
use starknet::providers::{Provider, ProviderError};

type GetReceiptResult = Result<MaybePendingTransactionReceipt, ProviderError>;
type GetReceiptFuture<'a> = Pin<Box<dyn Future<Output = GetReceiptResult> + Send + 'a>>;

#[derive(Debug, thiserror::Error)]
pub enum TransactionWaitingError {
#[error("request timed out")]
Expand Down Expand Up @@ -63,68 +57,56 @@ pub struct TransactionWaiter<'a, P: Provider> {
must_succeed: bool,
/// Poll the transaction every `interval` miliseconds. Miliseconds are used so that
/// we can be more precise with the polling interval. Defaults to 250ms.
interval: Interval,
interval: Duration,
/// The maximum amount of time to wait for the transaction to achieve the desired status. An
/// error will be returned if it is unable to finish within the `timeout` duration. Defaults to
/// 60 seconds.
timeout: Duration,
/// The provider to use for polling the transaction.
provider: &'a P,
/// The future that gets the transaction receipt.
receipt_request_fut: Option<GetReceiptFuture<'a>>,
/// The time when the transaction waiter was first polled.
started_at: Option<Instant>,
}

struct Interval {
last: Instant,
interval: Duration,
}
enum Sleeper {}

impl Interval {
fn new(interval: Duration) -> Self {
Self {
last: Instant::now(),
interval,
impl Sleeper {
pub async fn sleep(interval: Duration) {
#[cfg(feature = "tokio-runtime")]
{
tokio::time::sleep(interval).await;
}
}

fn poll_tick(&mut self, _cx: &mut Context<'_>) -> Poll<()> {
if self.last.elapsed() > self.interval {
self.last = Instant::now();
Poll::Ready(())
} else {
std::thread::sleep(self.interval);
Poll::Pending
#[cfg(feature = "async-std-runtime")]
{
async_std::task::sleep(interval).await;
}
#[cfg(not(any(feature = "tokio-runtime", feature = "async-std-runtime")))]
{
compile_error!("At least one of the features 'tokio-runtime' or 'async-std-runtime' must be enabled");
}
}

}

#[allow(dead_code)]
impl<'a, P> TransactionWaiter<'a, P>
where
P: Provider + Send,
P: Provider + Send + Sync,
{
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
const DEFAULT_INTERVAL: Duration = Duration::from_millis(2500);

pub fn new(tx: FieldElement, provider: &'a P) -> Self {
Self {
provider,
provider: provider.into(),
tx_hash: tx,
started_at: None,
must_succeed: true,
finality_status: None,
receipt_request_fut: None,
timeout: Self::DEFAULT_TIMEOUT,
interval: Interval::new(Self::DEFAULT_INTERVAL),
interval: Self::DEFAULT_INTERVAL,
}
}

pub fn with_interval(self, milisecond: u64) -> Self {
Self {
interval: Interval::new(Duration::from_millis(milisecond)),
interval: Duration::from_millis(milisecond),
..self
}
}
Expand All @@ -139,103 +121,71 @@ where
pub fn with_timeout(self, timeout: Duration) -> Self {
Self { timeout, ..self }
}
}

impl<'a, P> Future for TransactionWaiter<'a, P>
where
P: Provider + Send,
{
type Output = Result<MaybePendingTransactionReceipt, TransactionWaitingError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

if this.started_at.is_none() {
this.started_at = Some(Instant::now());
pub async fn wait(self) -> Result<MaybePendingTransactionReceipt, TransactionWaitingError> {
let timeout = self.timeout.clone();
select! {
result = self.wait_without_timeout().fuse() => result,
_ = Sleeper::sleep(timeout).fuse() => Err(TransactionWaitingError::Timeout),
}

}
async fn wait_without_timeout(
self,
) -> Result<MaybePendingTransactionReceipt, TransactionWaitingError> {
loop {
if let Some(started_at) = this.started_at {
if started_at.elapsed() > this.timeout {
return Poll::Ready(Err(TransactionWaitingError::Timeout));
}
}
let now = std::time::Instant::now();
let transaction = self.provider.get_transaction_receipt(self.tx_hash).await;
match transaction {
Ok(receipt) => match &receipt {
MaybePendingTransactionReceipt::PendingReceipt(r) => {
if self.finality_status.is_none() {
if self.must_succeed {
return match execution_status_from_pending_receipt(r) {
ExecutionResult::Succeeded => Ok(receipt),
ExecutionResult::Reverted { reason } => {
Err(TransactionWaitingError::TransactionReverted(
reason.clone(),
))
}
};
}
return Ok(receipt);
}
}

if let Some(mut flush) = this.receipt_request_fut.take() {
match flush.poll_unpin(cx) {
Poll::Ready(res) => match res {
Ok(receipt) => match &receipt {
MaybePendingTransactionReceipt::PendingReceipt(r) => {
if this.finality_status.is_none() {
if this.must_succeed {
let res = match execution_status_from_pending_receipt(r) {
MaybePendingTransactionReceipt::Receipt(r) => {
if let Some(finality_status) = self.finality_status {
match finality_status_from_receipt(r) {
status if status == finality_status => {
if self.must_succeed {
return match execution_status_from_receipt(r) {
ExecutionResult::Succeeded => Ok(receipt),
ExecutionResult::Reverted { reason } => {
Err(TransactionWaitingError::TransactionReverted(
reason.clone(),
))
}
};
return Poll::Ready(res);
}

return Poll::Ready(Ok(receipt));
return Ok(receipt);
}
}

MaybePendingTransactionReceipt::Receipt(r) => {
if let Some(finality_status) = this.finality_status {
match finality_status_from_receipt(r) {
status if status == finality_status => {
if this.must_succeed {
let res = match execution_status_from_receipt(r) {
ExecutionResult::Succeeded => Ok(receipt),
ExecutionResult::Reverted { reason } => {
Err(TransactionWaitingError::TransactionReverted(
reason.clone(),
))
}
};
return Poll::Ready(res);
}

return Poll::Ready(Ok(receipt));
}

_ => {}
}
} else {
return Poll::Ready(Ok(receipt));
}
_ => {}
}
},

Err(ProviderError::StarknetError(
StarknetError::TransactionHashNotFound,
)) => {}

Err(e) => {
return Poll::Ready(Err(TransactionWaitingError::Provider(e)));
} else {
return Ok(receipt);
}
},

Poll::Pending => {
this.receipt_request_fut = Some(flush);
return Poll::Pending;
}
}
}
},

Err(ProviderError::StarknetError(StarknetError::TransactionHashNotFound)) => {}

if this.interval.poll_tick(cx).is_ready() {
this.receipt_request_fut = Some(Box::pin(
this.provider.get_transaction_receipt(this.tx_hash),
));
} else {
break;
Err(e) => {
return Err(TransactionWaitingError::Provider(e));
}
}
Sleeper::sleep(self.interval.checked_sub(now.elapsed()).unwrap_or_default()).await;
}

Poll::Pending
}
}

Expand Down

0 comments on commit d3807b2

Please sign in to comment.