Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ? and return in the decorated function #4

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion retry-if-macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ exclude = [".idea", ".gitignore", ".github"]
[dependencies]
proc-macro2 = "1.0.81"
quote = "1.0.36"
syn = { version = "2.0.60", features = ["full"] }
syn = { version = "2.0.60", features = ["full", "visit-mut"] }
tracing = { version = "0.1.40", optional = true }

[features]
Expand Down
59 changes: 43 additions & 16 deletions retry-if-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,33 @@ use proc_macro2::Ident;
use quote::quote;
use syn::parse::Parser;
use syn::punctuated::Punctuated;
use syn::ItemFn;
use syn::{ItemFn, parse_quote, Expr};
use syn::visit_mut;
use syn::visit_mut::VisitMut;

struct BlockModifier;

impl VisitMut for BlockModifier {
fn visit_expr_mut(&mut self, i: &mut Expr) {
if let Expr::Try(expr_try) = i {
let expr = &expr_try.expr;
*i = Expr::Match(parse_quote! {
match #expr {
Ok(val) => val,
Err(err) => break 'block Err(err),
}
});
} else if let Expr::Return(expr_return) = i {
let return_value = &expr_return.expr;
*i = Expr::Break(parse_quote! {
break 'block #return_value
});
}

// Important: continue visiting to find nested expressions
visit_mut::visit_expr_mut(self, i);
}
}

/// Decorate a function with a given retry configuration.
///
Expand Down Expand Up @@ -67,25 +93,31 @@ pub fn retry(
///
/// This takes the underlying function as [ItemFn], the backoff configuration (defined in parent
/// crate) as an `&Ident`, and the `&Ident` for the retry function
fn decorate_fn(impl_fn: ItemFn, config: &Ident, retry_if: &Ident) -> proc_macro::TokenStream {
fn decorate_fn(mut impl_fn: ItemFn, config: &Ident, retry_if: &Ident) -> proc_macro::TokenStream {
let attrs = &impl_fn.attrs;
let vis = &impl_fn.vis;
let sig = &impl_fn.sig;

(BlockModifier {}).visit_block_mut(&mut impl_fn.block);
let block = &impl_fn.block;

(quote! {
#(#attrs)*
#vis #sig {
let start = tokio::time::Instant::now();
let backoff_max = #config.backoff_max.unwrap_or(std::time::Duration::MAX);
let max_tries = #config.max_retries;
let mut attempt = 0;

let mut result = #block;
loop {
let result = 'block: {
#block
};

for attempt in 0..max_tries {
if !#retry_if(&result) {
break;
// Return result if retry isn't required, or if we ran out of attempts
if !#retry_if(&result) || attempt >= #config.max_retries {
return result;
}
attempt += 1;

let retry_wait = #config.t_wait
.mul_f64(#config.backoff.powi(attempt))
Expand All @@ -94,24 +126,19 @@ fn decorate_fn(impl_fn: ItemFn, config: &Ident, retry_if: &Ident) -> proc_macro:
if let Some(max_wait) = #config.t_wait_max {
let now = tokio::time::Instant::now();
let since_start = now - start;
let will_exceed_time = since_start + retry_wait > max_wait;

if will_exceed_time {
break;
// Return if our overall duration is going to exceed `max_wait`
if since_start + retry_wait > max_wait {
return result;
}
}

if cfg!(feature = "tracing") {
tracing::info!("Sleeping {retry_wait:?} on attempt {attempt}");
}

tokio::time::sleep(retry_wait).await;

result = #block;
}

result
}
})
.into()
.into()
}
45 changes: 45 additions & 0 deletions tests/retry_method_with_try.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//! This example tests a backoff configuration using a function that returns false for if it should
//! retry, thus resulting in no retries at all.
use std::num::ParseIntError;
use std::str::FromStr;
use retry_if::{retry, ExponentialBackoffConfig};
use std::time::Duration;
use tokio::time::Instant;

const BACKOFF_CONFIG: ExponentialBackoffConfig = ExponentialBackoffConfig {
max_retries: 5,
t_wait: Duration::from_secs(1),
backoff: 2.0,
t_wait_max: None,
backoff_max: None,
};

#[tokio::test]
async fn test_retry_with_try_operator_on_result() {
fn retry_if(_: &Result<i32, ParseIntError>) -> bool {
false
}

#[retry(BACKOFF_CONFIG, retry_if)]
async fn method(int: &str) -> Result<i32, ParseIntError> {
return Ok(i32::from_str(int)?);
}

let result = method("3").await;
assert_eq!(Ok(3), result);
}

#[tokio::test]
async fn test_retry_with_try_operator_on_option() {
fn retry_if(_: &Result<i32, ParseIntError>) -> bool {
false
}

#[retry(BACKOFF_CONFIG, retry_if)]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails because the usage of ? on an Option gets modified by the block modifier, looking like below:

let result = 'block: {
    {
        break 'block Some(
            match i32::from_str(int).ok() {  // produces Option
                Ok(val) => val,                         // creates matching code for Result
                Err(err) => break 'block Err(err),
            },
        );
    }
};

async fn method(int: &str) -> Option<i32> {
return Some(i32::from_str(int).ok()?);
}

let result = method("3").await;
assert_eq!(Some(3), result);
}
Loading