diff --git a/v2/distributed_gobreaker.go b/v2/distributed_gobreaker.go new file mode 100644 index 0000000..2565d90 --- /dev/null +++ b/v2/distributed_gobreaker.go @@ -0,0 +1,213 @@ +package gobreaker + +import ( + "context" + "fmt" + "time" +) + +// SharedState represents the shared state of DistributedCircuitBreaker. +type SharedState struct { + State State `json:"state"` + Generation uint64 `json:"generation"` + Counts Counts `json:"counts"` + Expiry time.Time `json:"expiry"` +} + +type SharedStateStore interface { + GetState(ctx context.Context) (SharedState, error) + SetState(ctx context.Context, state SharedState) error +} + +// DistributedCircuitBreaker extends CircuitBreaker with distributed state storage +type DistributedCircuitBreaker[T any] struct { + *CircuitBreaker[T] + store SharedStateStore +} + +// NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings +func NewDistributedCircuitBreaker[T any](store SharedStateStore, settings Settings) *DistributedCircuitBreaker[T] { + cb := NewCircuitBreaker[T](settings) + return &DistributedCircuitBreaker[T]{ + CircuitBreaker: cb, + store: store, + } +} + +func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { + if dcb.store == nil { + return dcb.CircuitBreaker.State() + } + + state, err := dcb.store.GetState(ctx) + if err != nil { + // Fallback to in-memory state if Storage fails + return dcb.CircuitBreaker.State() + } + + now := time.Now() + currentState, _ := dcb.currentState(state, now) + + // Update the state in Storage if it has changed + if currentState != state.State { + state.State = currentState + if err := dcb.store.SetState(ctx, state); err != nil { + // Log the error, but continue with the current state + fmt.Printf("Failed to update state in storage: %v\n", err) + } + } + + return state.State +} + +// Execute runs the given request if the DistributedCircuitBreaker accepts it +func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { + if dcb.store == nil { + return dcb.CircuitBreaker.Execute(req) + } + generation, err := dcb.beforeRequest(ctx) + if err != nil { + var zero T + return zero, err + } + + defer func() { + e := recover() + if e != nil { + dcb.afterRequest(ctx, generation, false) + panic(e) + } + }() + + result, err := req() + dcb.afterRequest(ctx, generation, dcb.isSuccessful(err)) + + return result, err +} + +func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { + state, err := dcb.store.GetState(ctx) + if err != nil { + return 0, err + } + now := time.Now() + currentState, generation := dcb.currentState(state, now) + + if currentState != state.State { + dcb.setState(&state, currentState, now) + err = dcb.store.SetState(ctx, state) + if err != nil { + return 0, err + } + } + + if currentState == StateOpen { + return generation, ErrOpenState + } else if currentState == StateHalfOpen && state.Counts.Requests >= dcb.maxRequests { + return generation, ErrTooManyRequests + } + + state.Counts.onRequest() + err = dcb.store.SetState(ctx, state) + if err != nil { + return 0, err + } + + return generation, nil +} + +func (dcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) { + state, err := dcb.store.GetState(ctx) + if err != nil { + return + } + now := time.Now() + currentState, generation := dcb.currentState(state, now) + if generation != before { + return + } + + if success { + dcb.onSuccess(&state, currentState, now) + } else { + dcb.onFailure(&state, currentState, now) + } + + dcb.store.SetState(ctx, state) +} + +func (dcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { + if state.State == StateOpen { + state.State = currentState + } + + switch currentState { + case StateClosed: + state.Counts.onSuccess() + case StateHalfOpen: + state.Counts.onSuccess() + if state.Counts.ConsecutiveSuccesses >= dcb.maxRequests { + dcb.setState(state, StateClosed, now) + } + } +} + +func (dcb *DistributedCircuitBreaker[T]) onFailure(state *SharedState, currentState State, now time.Time) { + switch currentState { + case StateClosed: + state.Counts.onFailure() + if dcb.readyToTrip(state.Counts) { + dcb.setState(state, StateOpen, now) + } + case StateHalfOpen: + dcb.setState(state, StateOpen, now) + } +} + +func (dcb *DistributedCircuitBreaker[T]) currentState(state SharedState, now time.Time) (State, uint64) { + switch state.State { + case StateClosed: + if !state.Expiry.IsZero() && state.Expiry.Before(now) { + dcb.toNewGeneration(&state, now) + } + case StateOpen: + if state.Expiry.Before(now) { + dcb.setState(&state, StateHalfOpen, now) + } + } + return state.State, state.Generation +} + +func (dcb *DistributedCircuitBreaker[T]) setState(state *SharedState, newState State, now time.Time) { + if state.State == newState { + return + } + + prev := state.State + state.State = newState + + dcb.toNewGeneration(state, now) + + if dcb.onStateChange != nil { + dcb.onStateChange(dcb.name, prev, newState) + } +} + +func (dcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now time.Time) { + state.Generation++ + state.Counts.clear() + + var zero time.Time + switch state.State { + case StateClosed: + if dcb.interval == 0 { + state.Expiry = zero + } else { + state.Expiry = now.Add(dcb.interval) + } + case StateOpen: + state.Expiry = now.Add(dcb.timeout) + default: // StateHalfOpen + state.Expiry = zero + } +} diff --git a/v2/distributed_gobreaker_test.go b/v2/distributed_gobreaker_test.go new file mode 100644 index 0000000..c9edf6f --- /dev/null +++ b/v2/distributed_gobreaker_test.go @@ -0,0 +1,350 @@ +package gobreaker + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var defaultDCB *DistributedCircuitBreaker[any] +var customDCB *DistributedCircuitBreaker[any] + +type storageAdapter struct { + client *redis.Client +} + +func (r *storageAdapter) GetState(ctx context.Context) (SharedState, error) { + var state SharedState + data, err := r.client.Get(ctx, "gobreaker").Bytes() + if len(data) == 0 { + // Key doesn't exist, return default state + return SharedState{State: StateClosed}, nil + } else if err != nil { + return state, err + } + + err = json.Unmarshal(data, &state) + return state, err +} + +func (r *storageAdapter) SetState(ctx context.Context, state SharedState) error { + data, err := json.Marshal(state) + if err != nil { + return err + } + + return r.client.Set(ctx, "gobreaker", data, 0).Err() +} + +func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Miniredis, *redis.Client) { + mr, err := miniredis.Run() + if err != nil { + panic(err) + } + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + storageClient := &storageAdapter{client: client} + + return NewDistributedCircuitBreaker[any](storageClient, Settings{ + Name: "TestBreaker", + MaxRequests: 3, + Interval: time.Second, + Timeout: time.Second * 2, + ReadyToTrip: func(counts Counts) bool { + return counts.ConsecutiveFailures > 5 + }, + }), mr, client +} + +func pseudoSleepStorage(ctx context.Context, dcb *DistributedCircuitBreaker[any], period time.Duration) { + state, _ := dcb.store.GetState(ctx) + + state.Expiry = state.Expiry.Add(-period) + // Reset counts if the interval has passed + if time.Now().After(state.Expiry) { + state.Counts = Counts{} + } + dcb.store.SetState(ctx, state) +} + +func successRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error { + _, err := dcb.Execute(ctx, func() (interface{}, error) { return nil, nil }) + return err +} + +func failRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error { + _, err := dcb.Execute(ctx, func() (interface{}, error) { return nil, errors.New("fail") }) + if err != nil && err.Error() == "fail" { + return nil + } + return err +} + +func TestDistributedCircuitBreakerInitialization(t *testing.T) { + dcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + assert.Equal(t, "TestBreaker", dcb.Name()) + assert.Equal(t, uint32(3), dcb.maxRequests) + assert.Equal(t, time.Second, dcb.interval) + assert.Equal(t, time.Second*2, dcb.timeout) + assert.NotNil(t, dcb.readyToTrip) + + state := dcb.State(ctx) + assert.Equal(t, StateClosed, state) +} + +func TestDistributedCircuitBreakerStateTransitions(t *testing.T) { + dcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + // Check if initial state is closed + assert.Equal(t, StateClosed, dcb.State(ctx)) + + // StateClosed to StateOpen + for i := 0; i < 6; i++ { + assert.NoError(t, failRequest(ctx, dcb)) + } + + assert.Equal(t, StateOpen, dcb.State(ctx)) + + // Ensure requests fail when circuit is open + err := failRequest(ctx, dcb) + assert.Error(t, err) + assert.Equal(t, ErrOpenState, err) + + // Wait for timeout to transition to half-open + pseudoSleepStorage(ctx, dcb, dcb.timeout) + assert.Equal(t, StateHalfOpen, dcb.State(ctx)) + + // StateHalfOpen to StateClosed + for i := 0; i < int(dcb.maxRequests); i++ { + assert.NoError(t, successRequest(ctx, dcb)) + } + assert.Equal(t, StateClosed, dcb.State(ctx)) + + // StateClosed to StateOpen (again) + for i := 0; i < 6; i++ { + assert.NoError(t, failRequest(ctx, dcb)) + } + assert.Equal(t, StateOpen, dcb.State(ctx)) +} + +func TestDistributedCircuitBreakerExecution(t *testing.T) { + dcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + // Test successful execution + result, err := dcb.Execute(ctx, func() (interface{}, error) { + return "success", nil + }) + assert.NoError(t, err) + assert.Equal(t, "success", result) + + // Test failed execution + _, err = dcb.Execute(ctx, func() (interface{}, error) { + return nil, errors.New("test error") + }) + assert.Error(t, err) + assert.Equal(t, "test error", err.Error()) +} + +func TestDistributedCircuitBreakerCounts(t *testing.T) { + dcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + for i := 0; i < 5; i++ { + assert.Nil(t, successRequest(ctx, dcb)) + } + + state, _ := dcb.store.GetState(ctx) + assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) + + assert.Nil(t, failRequest(ctx, dcb)) + state, _ = dcb.store.GetState(ctx) + assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) +} + +func TestDistributedCircuitBreakerFallback(t *testing.T) { + dcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + // Test when Storage is unavailable + mr.Close() // Simulate Storage being unavailable + + dcb.store = nil + + state := dcb.State(ctx) + assert.Equal(t, StateClosed, state, "Should fallback to in-memory state when Storage is unavailable") + + // Ensure operations still work without Storage + assert.Nil(t, successRequest(ctx, dcb)) + assert.Nil(t, failRequest(ctx, dcb)) +} + +func TestCustomDistributedCircuitBreaker(t *testing.T) { + mr, err := miniredis.Run() + if err != nil { + panic(err) + } + defer mr.Close() + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + storageClient := &storageAdapter{client: client} + + customDCB = NewDistributedCircuitBreaker[any](storageClient, Settings{ + Name: "CustomBreaker", + MaxRequests: 3, + Interval: time.Second * 30, + Timeout: time.Second * 90, + ReadyToTrip: func(counts Counts) bool { + numReqs := counts.Requests + failureRatio := float64(counts.TotalFailures) / float64(numReqs) + return numReqs >= 3 && failureRatio >= 0.6 + }, + }) + + ctx := context.Background() + + t.Run("Initialization", func(t *testing.T) { + assert.Equal(t, "CustomBreaker", customDCB.Name()) + assert.Equal(t, StateClosed, customDCB.State(ctx)) + }) + + t.Run("Counts and State Transitions", func(t *testing.T) { + // Perform 5 successful and 5 failed requests + for i := 0; i < 5; i++ { + assert.NoError(t, successRequest(ctx, customDCB)) + assert.NoError(t, failRequest(ctx, customDCB)) + } + + state, err := customDCB.store.GetState(ctx) + assert.NoError(t, err) + assert.Equal(t, StateClosed, state.State) + assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) + + // Perform one more successful request + assert.NoError(t, successRequest(ctx, customDCB)) + state, err = customDCB.store.GetState(ctx) + assert.NoError(t, err) + assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) + + // Simulate time passing to reset counts + pseudoSleepStorage(ctx, customDCB, time.Second*30) + + // Perform requests to trigger StateOpen + assert.NoError(t, successRequest(ctx, customDCB)) + assert.NoError(t, failRequest(ctx, customDCB)) + assert.NoError(t, failRequest(ctx, customDCB)) + + // Check if the circuit breaker is now open + assert.Equal(t, StateOpen, customDCB.State(ctx)) + + state, err = customDCB.store.GetState(ctx) + assert.NoError(t, err) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) + }) + + t.Run("Timeout and Half-Open State", func(t *testing.T) { + // Simulate timeout to transition to half-open state + pseudoSleepStorage(ctx, customDCB, time.Second*90) + assert.Equal(t, StateHalfOpen, customDCB.State(ctx)) + + // Successful requests in half-open state should close the circuit + for i := 0; i < 3; i++ { + assert.NoError(t, successRequest(ctx, customDCB)) + } + assert.Equal(t, StateClosed, customDCB.State(ctx)) + }) +} + +func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { + // Setup + var stateChange StateChange + customSt := Settings{ + Name: "cb", + MaxRequests: 3, + Interval: 5 * time.Second, + Timeout: 5 * time.Second, + ReadyToTrip: func(counts Counts) bool { + return counts.ConsecutiveFailures >= 2 + }, + OnStateChange: func(name string, from State, to State) { + stateChange = StateChange{name, from, to} + }, + } + + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("Failed to start miniredis: %v", err) + } + defer mr.Close() + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + storageClient := &storageAdapter{client: client} + + dcb := NewDistributedCircuitBreaker[any](storageClient, customSt) + + ctx := context.Background() + + // Test case + t.Run("Circuit Breaker State Transitions", func(t *testing.T) { + // Initial state should be Closed + assert.Equal(t, StateClosed, dcb.State(ctx)) + + // Cause two consecutive failures to trip the circuit + for i := 0; i < 2; i++ { + err := failRequest(ctx, dcb) + assert.NoError(t, err, "Fail request should not return an error") + } + + // Circuit should now be Open + assert.Equal(t, StateOpen, dcb.State(ctx)) + assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) + + // Requests should fail immediately when circuit is Open + err := successRequest(ctx, dcb) + assert.Error(t, err) + assert.Equal(t, ErrOpenState, err) + + // Simulate timeout to transition to Half-Open + pseudoSleepStorage(ctx, dcb, 6*time.Second) + assert.Equal(t, StateHalfOpen, dcb.State(ctx)) + assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) + + // Successful requests in Half-Open state should close the circuit + for i := 0; i < int(dcb.maxRequests); i++ { + err := successRequest(ctx, dcb) + assert.NoError(t, err) + } + + // Circuit should now be Closed + assert.Equal(t, StateClosed, dcb.State(ctx)) + assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) + }) +} diff --git a/v2/go.mod b/v2/go.mod index 9e1537a..eb1204f 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -5,7 +5,16 @@ go 1.21 require github.com/stretchr/testify v1.8.4 require ( + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect +) + +require ( + github.com/alicebob/miniredis/v2 v2.33.0 github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/redis/go-redis/v9 v9.7.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/v2/go.sum b/v2/go.sum index fa4b6e6..f36dd60 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -1,9 +1,21 @@ +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= +github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=