diff --git a/README.md b/README.md index 590bf95..1bb064f 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ You can configure `CircuitBreaker` by the struct `Settings`: ```go type Settings struct { Name string - MaxRequests uint32 + ReadyToClose func(counts Counts) (bool, bool) Interval time.Duration Timeout time.Duration ReadyToTrip func(counts Counts) bool @@ -38,9 +38,11 @@ type Settings struct { - `Name` is the name of the `CircuitBreaker`. -- `MaxRequests` is the maximum number of requests allowed to pass through - when the `CircuitBreaker` is half-open. - If `MaxRequests` is 0, `CircuitBreaker` allows only 1 request. +- `ReadyToClose` is called with a copy of `Counts` for each request in the half-open state. + If `ReadyToClose` returns true, the `CircuitBreaker` will be placed into the close state. + If `ReadyToClose` returns false, the `CircuitBreaker` will be placed into the open state if second returned value is true. + If `ReadyToClose` is nil, default `ReadyToClose` is used. + Default `ReadyToClose` returns true when the number of consecutive successes is more than 1. - `Interval` is the cyclic period of the closed state for `CircuitBreaker` to clear the internal `Counts`, described later in this section. diff --git a/gobreaker.go b/gobreaker.go index 7503a27..0ea10b2 100644 --- a/gobreaker.go +++ b/gobreaker.go @@ -20,8 +20,6 @@ const ( ) var ( - // ErrTooManyRequests is returned when the CB state is half open and the requests count is over the cb maxRequests - ErrTooManyRequests = errors.New("too many requests") // ErrOpenState is returned when the CB state is open ErrOpenState = errors.New("circuit breaker is open") ) @@ -80,9 +78,11 @@ func (c *Counts) clear() { // // Name is the name of the CircuitBreaker. // -// MaxRequests is the maximum number of requests allowed to pass through -// when the CircuitBreaker is half-open. -// If MaxRequests is 0, the CircuitBreaker allows only 1 request. +// ReadyToClose is called with a copy of Counts for each request in the half-open state. +// If ReadyToClose returns true, the CircuitBreaker will be placed into the close state. +// If ReadyToClose returns false, the CircuitBreaker will be placed into the open state if second returned value is true. +// If ReadyToClose is nil, default ReadyToClose is used. +// Default ReadyToClose returns true when the number of consecutive successes is more than 1. // // Interval is the cyclic period of the closed state // for the CircuitBreaker to clear the internal Counts. @@ -105,7 +105,7 @@ func (c *Counts) clear() { // If IsSuccessful is nil, default IsSuccessful is used, which returns false for all non-nil errors. type Settings struct { Name string - MaxRequests uint32 + ReadyToClose func(counts Counts) (bool, bool) Interval time.Duration Timeout time.Duration ReadyToTrip func(counts Counts) bool @@ -116,7 +116,7 @@ type Settings struct { // CircuitBreaker is a state machine to prevent sending requests that are likely to fail. type CircuitBreaker struct { name string - maxRequests uint32 + readyToClose func(counts Counts) (bool, bool) interval time.Duration timeout time.Duration readyToTrip func(counts Counts) bool @@ -144,10 +144,10 @@ func NewCircuitBreaker(st Settings) *CircuitBreaker { cb.name = st.Name cb.onStateChange = st.OnStateChange - if st.MaxRequests == 0 { - cb.maxRequests = 1 + if st.ReadyToClose == nil { + cb.readyToClose = defaultReadyToClose } else { - cb.maxRequests = st.MaxRequests + cb.readyToClose = st.ReadyToClose } if st.Interval <= 0 { @@ -193,6 +193,10 @@ func defaultReadyToTrip(counts Counts) bool { return counts.ConsecutiveFailures > 5 } +func defaultReadyToClose(counts Counts) (bool, bool) { + return counts.ConsecutiveSuccesses >= 1, true +} + func defaultIsSuccessful(err error) bool { return err == nil } @@ -282,8 +286,6 @@ func (cb *CircuitBreaker) beforeRequest() (uint64, error) { if state == StateOpen { return generation, ErrOpenState - } else if state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests { - return generation, ErrTooManyRequests } cb.counts.onRequest() @@ -313,7 +315,7 @@ func (cb *CircuitBreaker) onSuccess(state State, now time.Time) { cb.counts.onSuccess() case StateHalfOpen: cb.counts.onSuccess() - if cb.counts.ConsecutiveSuccesses >= cb.maxRequests { + if ok, _ := cb.readyToClose(cb.counts); ok { cb.setState(StateClosed, now) } } @@ -326,8 +328,12 @@ func (cb *CircuitBreaker) onFailure(state State, now time.Time) { if cb.readyToTrip(cb.counts) { cb.setState(StateOpen, now) } + case StateHalfOpen: - cb.setState(StateOpen, now) + cb.counts.onFailure() + if ok, nowOpen := cb.readyToClose(cb.counts); !ok && nowOpen { + cb.setState(StateOpen, now) + } } } @@ -337,11 +343,13 @@ func (cb *CircuitBreaker) currentState(now time.Time) (State, uint64) { if !cb.expiry.IsZero() && cb.expiry.Before(now) { cb.toNewGeneration(now) } + case StateOpen: if cb.expiry.Before(now) { cb.setState(StateHalfOpen, now) } } + return cb.state, cb.generation } @@ -372,8 +380,10 @@ func (cb *CircuitBreaker) toNewGeneration(now time.Time) { } else { cb.expiry = now.Add(cb.interval) } + case StateOpen: cb.expiry = now.Add(cb.timeout) + default: // StateHalfOpen cb.expiry = zero } diff --git a/gobreaker_test.go b/gobreaker_test.go index c20169d..6ba792e 100644 --- a/gobreaker_test.go +++ b/gobreaker_test.go @@ -80,7 +80,21 @@ func causePanic(cb *CircuitBreaker) error { func newCustom() *CircuitBreaker { var customSt Settings customSt.Name = "cb" - customSt.MaxRequests = 3 + customSt.ReadyToClose = func(counts Counts) (bool, bool) { + if counts.ConsecutiveSuccesses >= 3 { + return true, true + } + + numReqs := counts.Requests + failureRatio := float64(counts.TotalFailures) / float64(numReqs) + + var nowOpen bool + if numReqs >= 3 && failureRatio >= 0.6 { + nowOpen = true + } + + return false, nowOpen + } customSt.Interval = time.Duration(30) * time.Second customSt.Timeout = time.Duration(90) * time.Second customSt.ReadyToTrip = func(counts Counts) bool { @@ -126,7 +140,7 @@ func TestStateConstants(t *testing.T) { func TestNewCircuitBreaker(t *testing.T) { defaultCB := NewCircuitBreaker(Settings{}) assert.Equal(t, "", defaultCB.name) - assert.Equal(t, uint32(1), defaultCB.maxRequests) + assert.NotNil(t, defaultCB.readyToClose) assert.Equal(t, time.Duration(0), defaultCB.interval) assert.Equal(t, time.Duration(60)*time.Second, defaultCB.timeout) assert.NotNil(t, defaultCB.readyToTrip) @@ -137,7 +151,7 @@ func TestNewCircuitBreaker(t *testing.T) { customCB := newCustom() assert.Equal(t, "cb", customCB.name) - assert.Equal(t, uint32(3), customCB.maxRequests) + assert.NotNil(t, customCB.readyToClose) assert.Equal(t, time.Duration(30)*time.Second, customCB.interval) assert.Equal(t, time.Duration(90)*time.Second, customCB.timeout) assert.NotNil(t, customCB.readyToTrip) @@ -148,7 +162,7 @@ func TestNewCircuitBreaker(t *testing.T) { negativeDurationCB := newNegativeDurationCB() assert.Equal(t, "ncb", negativeDurationCB.name) - assert.Equal(t, uint32(1), negativeDurationCB.maxRequests) + assert.NotNil(t, negativeDurationCB.readyToClose) assert.Equal(t, time.Duration(0)*time.Second, negativeDurationCB.interval) assert.Equal(t, time.Duration(60)*time.Second, negativeDurationCB.timeout) assert.NotNil(t, negativeDurationCB.readyToTrip) @@ -256,7 +270,7 @@ func TestCustomCircuitBreaker(t *testing.T) { ch := succeedLater(customCB, time.Duration(100)*time.Millisecond) // 3 consecutive successes time.Sleep(time.Duration(50) * time.Millisecond) assert.Equal(t, Counts{3, 2, 0, 2, 0}, customCB.counts) - assert.Error(t, succeed(customCB)) // over MaxRequests + assert.Nil(t, succeed(customCB)) assert.Nil(t, <-ch) assert.Equal(t, StateClosed, customCB.State()) assert.Equal(t, Counts{0, 0, 0, 0, 0}, customCB.counts)