From 60652e593d212d97df91e6089dc820062830960b Mon Sep 17 00:00:00 2001 From: Cody Jones Date: Thu, 11 Mar 2021 15:33:37 -0800 Subject: [PATCH] Add RetryWithContext() and respect cancellation while sleeping This is a breaking change for developers using custom strategies. However, there shouldn't be any impact on code using the strategies included in this package. Because the time.Sleep() call is now abstracted, strategies are tested without actually sleeping, and the strategies don't need to be aware of contexts. Context is passed through to the action in case the action is defined separately from the retry.RetryWithContext() call, is reused at multiple points, etc. --- README.md | 14 +++-- retry.go | 50 ++++++++++++++-- retry_test.go | 121 ++++++++++++++++++++++++++++++++++---- strategy/strategy.go | 18 +++--- strategy/strategy_test.go | 107 ++++++++++++++------------------- 5 files changed, 216 insertions(+), 94 deletions(-) diff --git a/README.md b/README.md index bccf4de..fd4ddc5 100644 --- a/README.md +++ b/README.md @@ -55,12 +55,13 @@ logFile.Chdir() // Do something with the file ### HTTP request with strategies and backoff ```go -var response *http.Response +action := func(ctx context.Context, attempt uint) error { + var response *http.Response -action := func(attempt uint) error { - var err error - - response, err = http.Get("https://api.github.com/repos/Rican7/retry") + req, err := NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/Rican7/retry", nil) + if err == nil { + response, err = c.Do(req) + } if nil == err && nil != response && response.StatusCode > 200 { err = fmt.Errorf("failed to fetch (attempt #%d) with status code: %d", attempt, response.StatusCode) @@ -69,7 +70,8 @@ action := func(attempt uint) error { return err } -err := retry.Retry( +err := retry.RetryWithContext( + context.TODO(), action, strategy.Limit(5), strategy.Backoff(backoff.Fibonacci(10*time.Millisecond)), diff --git a/retry.go b/retry.go index 15015db..b7fcd44 100644 --- a/retry.go +++ b/retry.go @@ -4,20 +4,48 @@ // Copyright © 2016 Trevor N. Suarez (Rican7) package retry -import "github.com/Rican7/retry/strategy" +import ( + "context" + "time" + + "github.com/Rican7/retry/strategy" +) // Action defines a callable function that package retry can handle. type Action func(attempt uint) error +// ActionWithContext defines a callable function that package retry can handle. +type ActionWithContext func(ctx context.Context, attempt uint) error + // Retry takes an action and performs it, repetitively, until successful. // // Optionally, strategies may be passed that assess whether or not an attempt // should be made. func Retry(action Action, strategies ...strategy.Strategy) error { + return RetryWithContext(context.Background(), func(ctx context.Context, attempt uint) error { return action(attempt) }, strategies...) +} + +// RetryWithContext takes an action and performs it, repetitively, until successful +// or the context is done. +// +// Optionally, strategies may be passed that assess whether or not an attempt +// should be made. +// +// Context errors take precedence over action errors so this commonplace test: +// +// err := retry.RetryWithContext(...) +// if err != nil { return err } +// +// will pass cancellation errors up the call chain. +func RetryWithContext(ctx context.Context, action ActionWithContext, strategies ...strategy.Strategy) error { var err error - for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, strategies...); attempt++ { - err = action(attempt) + for attempt := uint(0); (0 == attempt || nil != err) && shouldAttempt(attempt, sleepFunc(ctx), strategies...) && nil == ctx.Err(); attempt++ { + err = action(ctx, attempt) + } + + if ctx.Err() != nil { + return ctx.Err() } return err @@ -25,12 +53,24 @@ func Retry(action Action, strategies ...strategy.Strategy) error { // shouldAttempt evaluates the provided strategies with the given attempt to // determine if the Retry loop should make another attempt. -func shouldAttempt(attempt uint, strategies ...strategy.Strategy) bool { +func shouldAttempt(attempt uint, sleep func(time.Duration), strategies ...strategy.Strategy) bool { shouldAttempt := true for i := 0; shouldAttempt && i < len(strategies); i++ { - shouldAttempt = shouldAttempt && strategies[i](attempt) + shouldAttempt = shouldAttempt && strategies[i](attempt, sleep) } return shouldAttempt } + +// sleepFunc returns a function with the same signature as time.Sleep() +// that blocks for the given duration, but will return sooner if the context is +// cancelled or its deadline passes. +func sleepFunc(ctx context.Context) func(time.Duration) { + return func(d time.Duration) { + select { + case <-ctx.Done(): + case <-time.After(d): + } + } +} diff --git a/retry_test.go b/retry_test.go index 8340a15..97b96b7 100644 --- a/retry_test.go +++ b/retry_test.go @@ -1,10 +1,16 @@ package retry import ( + "context" "errors" "testing" + "time" ) +// timeMarginOfError represents the acceptable amount of time that may pass for +// a time-based (sleep) unit before considering invalid. +const timeMarginOfError = time.Millisecond + func TestRetry(t *testing.T) { action := func(attempt uint) error { return nil @@ -47,8 +53,99 @@ func TestRetryRetriesUntilNoErrorReturned(t *testing.T) { } } +func TestRetryWithContextChecksContextAfterLastAttempt(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + strategy := func(attempt uint, sleep func(time.Duration)) bool { + if attempt == 0 { + return true + } + + cancel() + return false + } + + action := func(ctx context.Context, attempt uint) error { + return errors.New("erroring") + } + + err := RetryWithContext(ctx, action, strategy) + + if context.Canceled != err { + t.Error("expected a context error") + } +} + +func TestRetryWithContextCancelStopsAttempts(t *testing.T) { + var numCalls int + + ctx, cancel := context.WithCancel(context.Background()) + + action := func(ctx context.Context, attempt uint) error { + numCalls++ + + if numCalls == 1 { + cancel() + return ctx.Err() + } + + return nil + } + + err := RetryWithContext(ctx, action) + + if 1 != numCalls { + t.Errorf("expected the action to be tried once, not %d times", numCalls) + } + + if context.Canceled != err { + t.Error("expected a context error") + } +} + +func TestRetryWithContextSleepIsInterrupted(t *testing.T) { + const sleepDuration = 100 * timeMarginOfError + fullySleptBy := time.Now().Add(sleepDuration) + + strategy := func(attempt uint, sleep func(time.Duration)) bool { + if attempt > 0 { + sleep(sleepDuration) + } + return attempt <= 1 + } + + var numCalls int + + action := func(ctx context.Context, attempt uint) error { + numCalls++ + return errors.New("erroring") + } + + stopAfter := 10 * timeMarginOfError + deadline := time.Now().Add(stopAfter) + ctx, _ := context.WithDeadline(context.Background(), deadline) + + err := RetryWithContext(ctx, action, strategy) + + if time.Now().Before(deadline) { + t.Errorf("expected to stop after %s", stopAfter) + } + + if time.Now().After(fullySleptBy) { + t.Errorf("expected to stop before %s", sleepDuration) + } + + if 1 != numCalls { + t.Errorf("expected the action to be tried once, not %d times", numCalls) + } + + if context.DeadlineExceeded != err { + t.Error("expected a context error") + } +} + func TestShouldAttempt(t *testing.T) { - shouldAttempt := shouldAttempt(1) + shouldAttempt := shouldAttempt(1, time.Sleep) if !shouldAttempt { t.Error("expected to return true") @@ -58,23 +155,23 @@ func TestShouldAttempt(t *testing.T) { func TestShouldAttemptWithStrategy(t *testing.T) { const attemptNumberShouldReturnFalse = 7 - strategy := func(attempt uint) bool { + strategy := func(attempt uint, sleep func(time.Duration)) bool { return (attemptNumberShouldReturnFalse != attempt) } - should := shouldAttempt(1, strategy) + should := shouldAttempt(1, time.Sleep, strategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(1+attemptNumberShouldReturnFalse, strategy) + should = shouldAttempt(1+attemptNumberShouldReturnFalse, time.Sleep, strategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(attemptNumberShouldReturnFalse, strategy) + should = shouldAttempt(attemptNumberShouldReturnFalse, time.Sleep, strategy) if should { t.Error("expected to return false") @@ -82,39 +179,39 @@ func TestShouldAttemptWithStrategy(t *testing.T) { } func TestShouldAttemptWithMultipleStrategies(t *testing.T) { - trueStrategy := func(attempt uint) bool { + trueStrategy := func(attempt uint, sleep func(time.Duration)) bool { return true } - falseStrategy := func(attempt uint) bool { + falseStrategy := func(attempt uint, sleep func(time.Duration)) bool { return false } - should := shouldAttempt(1, trueStrategy) + should := shouldAttempt(1, time.Sleep, trueStrategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(1, falseStrategy) + should = shouldAttempt(1, time.Sleep, falseStrategy) if should { t.Error("expected to return false") } - should = shouldAttempt(1, trueStrategy, trueStrategy, trueStrategy) + should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, trueStrategy) if !should { t.Error("expected to return true") } - should = shouldAttempt(1, falseStrategy, falseStrategy, falseStrategy) + should = shouldAttempt(1, time.Sleep, falseStrategy, falseStrategy, falseStrategy) if should { t.Error("expected to return false") } - should = shouldAttempt(1, trueStrategy, trueStrategy, falseStrategy) + should = shouldAttempt(1, time.Sleep, trueStrategy, trueStrategy, falseStrategy) if should { t.Error("expected to return false") diff --git a/strategy/strategy.go b/strategy/strategy.go index a315fa0..5d4b2b6 100644 --- a/strategy/strategy.go +++ b/strategy/strategy.go @@ -18,22 +18,22 @@ import ( // The strategy will be passed an "attempt" number on each successive retry // iteration, starting with a `0` value before the first attempt is actually // made. This allows for a pre-action delay, etc. -type Strategy func(attempt uint) bool +type Strategy func(attempt uint, sleep func(time.Duration)) bool // Limit creates a Strategy that limits the number of attempts that Retry will // make. func Limit(attemptLimit uint) Strategy { - return func(attempt uint) bool { - return (attempt <= attemptLimit) + return func(attempt uint, sleep func(time.Duration)) bool { + return attempt <= attemptLimit } } // Delay creates a Strategy that waits the given duration before the first // attempt is made. func Delay(duration time.Duration) Strategy { - return func(attempt uint) bool { + return func(attempt uint, sleep func(time.Duration)) bool { if 0 == attempt { - time.Sleep(duration) + sleep(duration) } return true @@ -44,7 +44,7 @@ func Delay(duration time.Duration) Strategy { // the first. If the number of attempts is greater than the number of durations // provided, then the strategy uses the last duration provided. func Wait(durations ...time.Duration) Strategy { - return func(attempt uint) bool { + return func(attempt uint, sleep func(time.Duration)) bool { if 0 < attempt && 0 < len(durations) { durationIndex := int(attempt - 1) @@ -52,7 +52,7 @@ func Wait(durations ...time.Duration) Strategy { durationIndex = len(durations) - 1 } - time.Sleep(durations[durationIndex]) + sleep(durations[durationIndex]) } return true @@ -68,9 +68,9 @@ func Backoff(algorithm backoff.Algorithm) Strategy { // BackoffWithJitter creates a Strategy that waits before each attempt, with a // duration as defined by the given backoff.Algorithm and jitter.Transformation. func BackoffWithJitter(algorithm backoff.Algorithm, transformation jitter.Transformation) Strategy { - return func(attempt uint) bool { + return func(attempt uint, sleep func(time.Duration)) bool { if 0 < attempt { - time.Sleep(transformation(algorithm(attempt))) + sleep(transformation(algorithm(attempt))) } return true diff --git a/strategy/strategy_test.go b/strategy/strategy_test.go index 17488f5..6b9c644 100644 --- a/strategy/strategy_test.go +++ b/strategy/strategy_test.go @@ -5,45 +5,38 @@ import ( "time" ) -// timeMarginOfError represents the acceptable amount of time that may pass for -// a time-based (sleep) unit before considering invalid. -const timeMarginOfError = time.Millisecond - func TestLimit(t *testing.T) { const attemptLimit = 3 strategy := Limit(attemptLimit) - if !strategy(1) { + if !strategy(1, time.Sleep) { t.Error("strategy expected to return true") } - if !strategy(2) { + if !strategy(2, time.Sleep) { t.Error("strategy expected to return true") } - if !strategy(3) { + if !strategy(3, time.Sleep) { t.Error("strategy expected to return true") } - if strategy(4) { + if strategy(4, time.Sleep) { t.Error("strategy expected to return false") } } func TestDelay(t *testing.T) { - const delayDuration = time.Duration(10 * timeMarginOfError) + const delayDuration = time.Duration(10) strategy := Delay(delayDuration) - if now := time.Now(); !strategy(0) || delayDuration > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - time.Duration(delayDuration), - ) + if spy, actual := sleepSpy(); !strategy(0, spy) || delayDuration != *actual { + t.Errorf("strategy expected to return true in %s", delayDuration) } - if now := time.Now(); !strategy(5) || (delayDuration/10) < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(5, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } } @@ -51,71 +44,59 @@ func TestDelay(t *testing.T) { func TestWait(t *testing.T) { strategy := Wait() - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } - if now := time.Now(); !strategy(999) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(999, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } } func TestWaitWithDuration(t *testing.T) { - const waitDuration = time.Duration(10 * timeMarginOfError) + const waitDuration = time.Duration(10) strategy := Wait(waitDuration) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } - if now := time.Now(); !strategy(1) || waitDuration > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - time.Duration(waitDuration), - ) + if spy, actual := sleepSpy(); !strategy(1, spy) || waitDuration != *actual { + t.Errorf("strategy expected to return true in %s", waitDuration) } } func TestWaitWithMultipleDurations(t *testing.T) { waitDurations := []time.Duration{ - time.Duration(10 * timeMarginOfError), - time.Duration(20 * timeMarginOfError), - time.Duration(30 * timeMarginOfError), - time.Duration(40 * timeMarginOfError), + time.Duration(10), + time.Duration(20), + time.Duration(30), + time.Duration(40), } strategy := Wait(waitDurations...) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } - if now := time.Now(); !strategy(1) || waitDurations[0] > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - time.Duration(waitDurations[0]), - ) + if spy, actual := sleepSpy(); !strategy(1, spy) || waitDurations[0] != *actual { + t.Errorf("strategy expected to return true in %s", waitDurations[0]) } - if now := time.Now(); !strategy(3) || waitDurations[2] > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - waitDurations[2], - ) + if spy, actual := sleepSpy(); !strategy(3, spy) || waitDurations[2] != *actual { + t.Errorf("strategy expected to return true in %s", waitDurations[2]) } - if now := time.Now(); !strategy(999) || waitDurations[len(waitDurations)-1] > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - waitDurations[len(waitDurations)-1], - ) + if spy, actual := sleepSpy(); !strategy(999, spy) || waitDurations[len(waitDurations)-1] != *actual { + t.Errorf("strategy expected to return true in %s", waitDurations[len(waitDurations)-1]) } } func TestBackoff(t *testing.T) { - const backoffDuration = time.Duration(10 * timeMarginOfError) - const algorithmDurationBase = timeMarginOfError + const backoffDuration = time.Duration(10) + const algorithmDurationBase = time.Duration(1) algorithm := func(attempt uint) time.Duration { return backoffDuration - (algorithmDurationBase * time.Duration(attempt)) @@ -123,48 +104,42 @@ func TestBackoff(t *testing.T) { strategy := Backoff(algorithm) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } for i := uint(1); i < 10; i++ { expectedResult := algorithm(i) - if now := time.Now(); !strategy(i) || expectedResult > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - expectedResult, - ) + if spy, actual := sleepSpy(); !strategy(i, spy) || expectedResult != *actual { + t.Errorf("strategy expected to return true in %s", expectedResult) } } } func TestBackoffWithJitter(t *testing.T) { - const backoffDuration = time.Duration(10 * timeMarginOfError) - const algorithmDurationBase = timeMarginOfError + const backoffDuration = time.Duration(20) + const algorithmDurationBase = time.Duration(1) algorithm := func(attempt uint) time.Duration { return backoffDuration - (algorithmDurationBase * time.Duration(attempt)) } transformation := func(duration time.Duration) time.Duration { - return duration - time.Duration(10*timeMarginOfError) + return duration - time.Duration(10) } strategy := BackoffWithJitter(algorithm, transformation) - if now := time.Now(); !strategy(0) || timeMarginOfError < time.Since(now) { + if spy, actual := sleepSpy(); !strategy(0, spy) || 0 != *actual { t.Error("strategy expected to return true in ~0 time") } for i := uint(1); i < 10; i++ { expectedResult := transformation(algorithm(i)) - if now := time.Now(); !strategy(i) || expectedResult > time.Since(now) { - t.Errorf( - "strategy expected to return true in %s", - expectedResult, - ) + if spy, actual := sleepSpy(); !strategy(i, spy) || expectedResult != *actual { + t.Errorf("strategy expected to return true in %s", expectedResult) } } } @@ -173,7 +148,7 @@ func TestNoJitter(t *testing.T) { transformation := noJitter() for i := uint(0); i < 10; i++ { - duration := time.Duration(i) * timeMarginOfError + duration := time.Duration(i) result := transformation(duration) expected := duration @@ -182,3 +157,11 @@ func TestNoJitter(t *testing.T) { } } } + +// sleepSpy returns a spy for the time.Sleep function that sums the +// durations passed to it. +func sleepSpy() (func(time.Duration), *time.Duration) { + var actual time.Duration + + return func(d time.Duration) { actual += d }, &actual +}