From fc36f2f99cbb54b7676aef30a44ab474c955e1ef Mon Sep 17 00:00:00 2001 From: kp Date: Fri, 15 Nov 2024 05:26:03 +0530 Subject: [PATCH 1/6] Title: Implement Distributed Circuit Breaker (#70) * feature/redis-circuit-breaker * feature/redis-circuit-breaker * Refactor * save state * Saving half-open state also * Saving half-open state also * Added test case * Saving state transition * Pass context * Moved redis circuit breaker to v2 * Revert go.mod and go.sum * Acked review comments * Refactor * Refactor --------- Co-authored-by: Kalpit Pant --- v2/distributed_circuit_breaker.go | 247 ++++++++++++++++++ v2/distributed_circuit_breaker_test.go | 338 +++++++++++++++++++++++++ v2/go.mod | 9 + v2/go.sum | 12 + 4 files changed, 606 insertions(+) create mode 100644 v2/distributed_circuit_breaker.go create mode 100644 v2/distributed_circuit_breaker_test.go diff --git a/v2/distributed_circuit_breaker.go b/v2/distributed_circuit_breaker.go new file mode 100644 index 0000000..f3b6472 --- /dev/null +++ b/v2/distributed_circuit_breaker.go @@ -0,0 +1,247 @@ +package gobreaker + +import ( + "context" + "encoding/json" + "fmt" + "time" +) + +type CacheClient interface { + GetState(ctx context.Context, key string) ([]byte, error) + SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error +} + +// DistributedCircuitBreaker extends CircuitBreaker with distributed state storage +type DistributedCircuitBreaker[T any] struct { + *CircuitBreaker[T] + cacheClient CacheClient +} + +// StorageSettings extends Settings +type StorageSettings struct { + Settings +} + +// NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings +func NewDistributedCircuitBreaker[T any](storageClient CacheClient, settings StorageSettings) *DistributedCircuitBreaker[T] { + cb := NewCircuitBreaker[T](settings.Settings) + return &DistributedCircuitBreaker[T]{ + CircuitBreaker: cb, + cacheClient: storageClient, + } +} + +// StoredState represents the CircuitBreaker state stored in Distributed Storage +type StoredState struct { + State State `json:"state"` + Generation uint64 `json:"generation"` + Counts Counts `json:"counts"` + Expiry time.Time `json:"expiry"` +} + +func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { + if rcb.cacheClient == nil { + return rcb.CircuitBreaker.State() + } + + state, err := rcb.getStoredState(ctx) + if err != nil { + // Fallback to in-memory state if Storage fails + return rcb.CircuitBreaker.State() + } + + now := time.Now() + currentState, _ := rcb.currentState(state, now) + + // Update the state in Storage if it has changed + if currentState != state.State { + state.State = currentState + if err := rcb.setStoredState(ctx, state); err != nil { + // Log the error, but continue with the current state + fmt.Printf("Failed to update state in storage: %v\n", err) + } + } + + return state.State +} + +// Execute runs the given request if the DistributedCircuitBreaker accepts it +func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { + if rcb.cacheClient == nil { + return rcb.CircuitBreaker.Execute(req) + } + generation, err := rcb.beforeRequest(ctx) + if err != nil { + var zero T + return zero, err + } + + defer func() { + e := recover() + if e != nil { + rcb.afterRequest(ctx, generation, false) + panic(e) + } + }() + + result, err := req() + rcb.afterRequest(ctx, generation, rcb.isSuccessful(err)) + + return result, err +} + +func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { + state, err := rcb.getStoredState(ctx) + if err != nil { + return 0, err + } + now := time.Now() + currentState, generation := rcb.currentState(state, now) + + if currentState != state.State { + rcb.setState(&state, currentState, now) + err = rcb.setStoredState(ctx, state) + if err != nil { + return 0, err + } + } + + if currentState == StateOpen { + return generation, ErrOpenState + } else if currentState == StateHalfOpen && state.Counts.Requests >= rcb.maxRequests { + return generation, ErrTooManyRequests + } + + state.Counts.onRequest() + err = rcb.setStoredState(ctx, state) + if err != nil { + return 0, err + } + + return generation, nil +} + +func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) { + state, err := rcb.getStoredState(ctx) + if err != nil { + return + } + now := time.Now() + currentState, generation := rcb.currentState(state, now) + if generation != before { + return + } + + if success { + rcb.onSuccess(&state, currentState, now) + } else { + rcb.onFailure(&state, currentState, now) + } + + rcb.setStoredState(ctx, state) +} + +func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *StoredState, currentState State, now time.Time) { + if state.State == StateOpen { + state.State = currentState + } + + switch currentState { + case StateClosed: + state.Counts.onSuccess() + case StateHalfOpen: + state.Counts.onSuccess() + if state.Counts.ConsecutiveSuccesses >= rcb.maxRequests { + rcb.setState(state, StateClosed, now) + } + } +} + +func (rcb *DistributedCircuitBreaker[T]) onFailure(state *StoredState, currentState State, now time.Time) { + switch currentState { + case StateClosed: + state.Counts.onFailure() + if rcb.readyToTrip(state.Counts) { + rcb.setState(state, StateOpen, now) + } + case StateHalfOpen: + rcb.setState(state, StateOpen, now) + } +} + +func (rcb *DistributedCircuitBreaker[T]) currentState(state StoredState, now time.Time) (State, uint64) { + switch state.State { + case StateClosed: + if !state.Expiry.IsZero() && state.Expiry.Before(now) { + rcb.toNewGeneration(&state, now) + } + case StateOpen: + if state.Expiry.Before(now) { + rcb.setState(&state, StateHalfOpen, now) + } + } + return state.State, state.Generation +} + +func (rcb *DistributedCircuitBreaker[T]) setState(state *StoredState, newState State, now time.Time) { + if state.State == newState { + return + } + + prev := state.State + state.State = newState + + rcb.toNewGeneration(state, now) + + if rcb.onStateChange != nil { + rcb.onStateChange(rcb.name, prev, newState) + } +} + +func (rcb *DistributedCircuitBreaker[T]) toNewGeneration(state *StoredState, now time.Time) { + + state.Generation++ + state.Counts.clear() + + var zero time.Time + switch state.State { + case StateClosed: + if rcb.interval == 0 { + state.Expiry = zero + } else { + state.Expiry = now.Add(rcb.interval) + } + case StateOpen: + state.Expiry = now.Add(rcb.timeout) + default: // StateHalfOpen + state.Expiry = zero + } +} + +func (rcb *DistributedCircuitBreaker[T]) getStorageKey() string { + return "cb:" + rcb.name +} + +func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (StoredState, error) { + var state StoredState + data, err := rcb.cacheClient.GetState(ctx, rcb.getStorageKey()) + if len(data) == 0 { + // Key doesn't exist, return default state + return StoredState{State: StateClosed}, nil + } else if err != nil { + return state, err + } + + err = json.Unmarshal(data, &state) + return state, err +} + +func (rcb *DistributedCircuitBreaker[T]) setStoredState(ctx context.Context, state StoredState) error { + data, err := json.Marshal(state) + if err != nil { + return err + } + + return rcb.cacheClient.SetState(ctx, rcb.getStorageKey(), data, 0) +} diff --git a/v2/distributed_circuit_breaker_test.go b/v2/distributed_circuit_breaker_test.go new file mode 100644 index 0000000..6371d3d --- /dev/null +++ b/v2/distributed_circuit_breaker_test.go @@ -0,0 +1,338 @@ +package gobreaker + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var defaultRCB *DistributedCircuitBreaker[any] +var customRCB *DistributedCircuitBreaker[any] + +type storageAdapter struct { + client *redis.Client +} + +func (r *storageAdapter) GetState(ctx context.Context, key string) ([]byte, error) { + return r.client.Get(ctx, key).Bytes() +} + +func (r *storageAdapter) SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error { + return r.client.Set(ctx, key, value, expiration).Err() +} + +func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Miniredis, *redis.Client) { + mr, err := miniredis.Run() + if err != nil { + panic(err) + } + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + storageClient := &storageAdapter{client: client} + + return NewDistributedCircuitBreaker[any](storageClient, StorageSettings{ + Settings: Settings{ + Name: "TestBreaker", + MaxRequests: 3, + Interval: time.Second, + Timeout: time.Second * 2, + ReadyToTrip: func(counts Counts) bool { + return counts.ConsecutiveFailures > 5 + }, + }, + }), mr, client +} + +func pseudoSleepStorage(ctx context.Context, rcb *DistributedCircuitBreaker[any], period time.Duration) { + state, _ := rcb.getStoredState(ctx) + + state.Expiry = state.Expiry.Add(-period) + // Reset counts if the interval has passed + if time.Now().After(state.Expiry) { + state.Counts = Counts{} + } + rcb.setStoredState(ctx, state) +} + +func successRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error { + _, err := rcb.Execute(ctx, func() (interface{}, error) { return nil, nil }) + return err +} + +func failRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error { + _, err := rcb.Execute(ctx, func() (interface{}, error) { return nil, errors.New("fail") }) + if err != nil && err.Error() == "fail" { + return nil + } + return err +} + +func TestDistributedCircuitBreakerInitialization(t *testing.T) { + rcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + assert.Equal(t, "TestBreaker", rcb.Name()) + assert.Equal(t, uint32(3), rcb.maxRequests) + assert.Equal(t, time.Second, rcb.interval) + assert.Equal(t, time.Second*2, rcb.timeout) + assert.NotNil(t, rcb.readyToTrip) + + state := rcb.State(ctx) + assert.Equal(t, StateClosed, state) +} + +func TestDistributedCircuitBreakerStateTransitions(t *testing.T) { + rcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + // Check if initial state is closed + assert.Equal(t, StateClosed, rcb.State(ctx)) + + // StateClosed to StateOpen + for i := 0; i < 6; i++ { + assert.NoError(t, failRequest(ctx, rcb)) + } + + assert.Equal(t, StateOpen, rcb.State(ctx)) + + // Ensure requests fail when circuit is open + err := failRequest(ctx, rcb) + assert.Error(t, err) + assert.Equal(t, ErrOpenState, err) + + // Wait for timeout to transition to half-open + pseudoSleepStorage(ctx, rcb, rcb.timeout) + assert.Equal(t, StateHalfOpen, rcb.State(ctx)) + + // StateHalfOpen to StateClosed + for i := 0; i < int(rcb.maxRequests); i++ { + assert.NoError(t, successRequest(ctx, rcb)) + } + assert.Equal(t, StateClosed, rcb.State(ctx)) + + // StateClosed to StateOpen (again) + for i := 0; i < 6; i++ { + assert.NoError(t, failRequest(ctx, rcb)) + } + assert.Equal(t, StateOpen, rcb.State(ctx)) +} + +func TestDistributedCircuitBreakerExecution(t *testing.T) { + rcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + // Test successful execution + result, err := rcb.Execute(ctx, func() (interface{}, error) { + return "success", nil + }) + assert.NoError(t, err) + assert.Equal(t, "success", result) + + // Test failed execution + _, err = rcb.Execute(ctx, func() (interface{}, error) { + return nil, errors.New("test error") + }) + assert.Error(t, err) + assert.Equal(t, "test error", err.Error()) +} + +func TestDistributedCircuitBreakerCounts(t *testing.T) { + rcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + for i := 0; i < 5; i++ { + assert.Nil(t, successRequest(ctx, rcb)) + } + + state, _ := rcb.getStoredState(ctx) + assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) + + assert.Nil(t, failRequest(ctx, rcb)) + state, _ = rcb.getStoredState(ctx) + assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) +} + +func TestDistributedCircuitBreakerFallback(t *testing.T) { + rcb, mr, _ := setupTestWithMiniredis() + defer mr.Close() + + ctx := context.Background() + + // Test when Storage is unavailable + mr.Close() // Simulate Storage being unavailable + + rcb.cacheClient = nil + + state := rcb.State(ctx) + assert.Equal(t, StateClosed, state, "Should fallback to in-memory state when Storage is unavailable") + + // Ensure operations still work without Storage + assert.Nil(t, successRequest(ctx, rcb)) + assert.Nil(t, failRequest(ctx, rcb)) +} + +func TestCustomDistributedCircuitBreaker(t *testing.T) { + mr, err := miniredis.Run() + if err != nil { + panic(err) + } + defer mr.Close() + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + storageClient := &storageAdapter{client: client} + + customRCB = NewDistributedCircuitBreaker[any](storageClient, StorageSettings{ + Settings: Settings{ + Name: "CustomBreaker", + MaxRequests: 3, + Interval: time.Second * 30, + Timeout: time.Second * 90, + ReadyToTrip: func(counts Counts) bool { + numReqs := counts.Requests + failureRatio := float64(counts.TotalFailures) / float64(numReqs) + return numReqs >= 3 && failureRatio >= 0.6 + }, + }, + }) + + ctx := context.Background() + + t.Run("Initialization", func(t *testing.T) { + assert.Equal(t, "CustomBreaker", customRCB.Name()) + assert.Equal(t, StateClosed, customRCB.State(ctx)) + }) + + t.Run("Counts and State Transitions", func(t *testing.T) { + // Perform 5 successful and 5 failed requests + for i := 0; i < 5; i++ { + assert.NoError(t, successRequest(ctx, customRCB)) + assert.NoError(t, failRequest(ctx, customRCB)) + } + + state, err := customRCB.getStoredState(ctx) + assert.NoError(t, err) + assert.Equal(t, StateClosed, state.State) + assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) + + // Perform one more successful request + assert.NoError(t, successRequest(ctx, customRCB)) + state, err = customRCB.getStoredState(ctx) + assert.NoError(t, err) + assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) + + // Simulate time passing to reset counts + pseudoSleepStorage(ctx, customRCB, time.Second*30) + + // Perform requests to trigger StateOpen + assert.NoError(t, successRequest(ctx, customRCB)) + assert.NoError(t, failRequest(ctx, customRCB)) + assert.NoError(t, failRequest(ctx, customRCB)) + + // Check if the circuit breaker is now open + assert.Equal(t, StateOpen, customRCB.State(ctx)) + + state, err = customRCB.getStoredState(ctx) + assert.NoError(t, err) + assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) + }) + + t.Run("Timeout and Half-Open State", func(t *testing.T) { + // Simulate timeout to transition to half-open state + pseudoSleepStorage(ctx, customRCB, time.Second*90) + assert.Equal(t, StateHalfOpen, customRCB.State(ctx)) + + // Successful requests in half-open state should close the circuit + for i := 0; i < 3; i++ { + assert.NoError(t, successRequest(ctx, customRCB)) + } + assert.Equal(t, StateClosed, customRCB.State(ctx)) + }) +} + +func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { + // Setup + var stateChange StateChange + customSt := Settings{ + Name: "cb", + MaxRequests: 3, + Interval: 5 * time.Second, + Timeout: 5 * time.Second, + ReadyToTrip: func(counts Counts) bool { + return counts.ConsecutiveFailures >= 2 + }, + OnStateChange: func(name string, from State, to State) { + stateChange = StateChange{name, from, to} + }, + } + + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("Failed to start miniredis: %v", err) + } + defer mr.Close() + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + storageClient := &storageAdapter{client: client} + + cb := NewDistributedCircuitBreaker[any](storageClient, StorageSettings{Settings: customSt}) + + ctx := context.Background() + + // Test case + t.Run("Circuit Breaker State Transitions", func(t *testing.T) { + // Initial state should be Closed + assert.Equal(t, StateClosed, cb.State(ctx)) + + // Cause two consecutive failures to trip the circuit + for i := 0; i < 2; i++ { + err := failRequest(ctx, cb) + assert.NoError(t, err, "Fail request should not return an error") + } + + // Circuit should now be Open + assert.Equal(t, StateOpen, cb.State(ctx)) + assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) + + // Requests should fail immediately when circuit is Open + err := successRequest(ctx, cb) + assert.Error(t, err) + assert.Equal(t, ErrOpenState, err) + + // Simulate timeout to transition to Half-Open + pseudoSleepStorage(ctx, cb, 6*time.Second) + assert.Equal(t, StateHalfOpen, cb.State(ctx)) + assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) + + // Successful requests in Half-Open state should close the circuit + for i := 0; i < int(cb.maxRequests); i++ { + err := successRequest(ctx, cb) + assert.NoError(t, err) + } + + // Circuit should now be Closed + assert.Equal(t, StateClosed, cb.State(ctx)) + assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) + }) +} diff --git a/v2/go.mod b/v2/go.mod index 9e1537a..eb1204f 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -5,7 +5,16 @@ go 1.21 require github.com/stretchr/testify v1.8.4 require ( + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect +) + +require ( + github.com/alicebob/miniredis/v2 v2.33.0 github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/redis/go-redis/v9 v9.7.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/v2/go.sum b/v2/go.sum index fa4b6e6..f36dd60 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -1,9 +1,21 @@ +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= +github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From 7cc80993b8990af207d72a2fb67b3c33c13820d5 Mon Sep 17 00:00:00 2001 From: Yoshiyuki Mineo Date: Sat, 16 Nov 2024 19:04:13 +0900 Subject: [PATCH 2/6] Rename --- v2/distributed_circuit_breaker.go | 35 ++++++++++-------------- v2/distributed_circuit_breaker_test.go | 38 ++++++++++++-------------- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/v2/distributed_circuit_breaker.go b/v2/distributed_circuit_breaker.go index f3b6472..34cf051 100644 --- a/v2/distributed_circuit_breaker.go +++ b/v2/distributed_circuit_breaker.go @@ -7,7 +7,7 @@ import ( "time" ) -type CacheClient interface { +type SharedStateStore interface { GetState(ctx context.Context, key string) ([]byte, error) SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error } @@ -15,25 +15,20 @@ type CacheClient interface { // DistributedCircuitBreaker extends CircuitBreaker with distributed state storage type DistributedCircuitBreaker[T any] struct { *CircuitBreaker[T] - cacheClient CacheClient -} - -// StorageSettings extends Settings -type StorageSettings struct { - Settings + cacheClient SharedStateStore } // NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings -func NewDistributedCircuitBreaker[T any](storageClient CacheClient, settings StorageSettings) *DistributedCircuitBreaker[T] { - cb := NewCircuitBreaker[T](settings.Settings) +func NewDistributedCircuitBreaker[T any](storageClient SharedStateStore, settings Settings) *DistributedCircuitBreaker[T] { + cb := NewCircuitBreaker[T](settings) return &DistributedCircuitBreaker[T]{ CircuitBreaker: cb, cacheClient: storageClient, } } -// StoredState represents the CircuitBreaker state stored in Distributed Storage -type StoredState struct { +// SharedState represents the CircuitBreaker state stored in Distributed Storage +type SharedState struct { State State `json:"state"` Generation uint64 `json:"generation"` Counts Counts `json:"counts"` @@ -142,7 +137,7 @@ func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, befor rcb.setStoredState(ctx, state) } -func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *StoredState, currentState State, now time.Time) { +func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { if state.State == StateOpen { state.State = currentState } @@ -158,7 +153,7 @@ func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *StoredState, currentSt } } -func (rcb *DistributedCircuitBreaker[T]) onFailure(state *StoredState, currentState State, now time.Time) { +func (rcb *DistributedCircuitBreaker[T]) onFailure(state *SharedState, currentState State, now time.Time) { switch currentState { case StateClosed: state.Counts.onFailure() @@ -170,7 +165,7 @@ func (rcb *DistributedCircuitBreaker[T]) onFailure(state *StoredState, currentSt } } -func (rcb *DistributedCircuitBreaker[T]) currentState(state StoredState, now time.Time) (State, uint64) { +func (rcb *DistributedCircuitBreaker[T]) currentState(state SharedState, now time.Time) (State, uint64) { switch state.State { case StateClosed: if !state.Expiry.IsZero() && state.Expiry.Before(now) { @@ -184,7 +179,7 @@ func (rcb *DistributedCircuitBreaker[T]) currentState(state StoredState, now tim return state.State, state.Generation } -func (rcb *DistributedCircuitBreaker[T]) setState(state *StoredState, newState State, now time.Time) { +func (rcb *DistributedCircuitBreaker[T]) setState(state *SharedState, newState State, now time.Time) { if state.State == newState { return } @@ -199,7 +194,7 @@ func (rcb *DistributedCircuitBreaker[T]) setState(state *StoredState, newState S } } -func (rcb *DistributedCircuitBreaker[T]) toNewGeneration(state *StoredState, now time.Time) { +func (rcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now time.Time) { state.Generation++ state.Counts.clear() @@ -223,12 +218,12 @@ func (rcb *DistributedCircuitBreaker[T]) getStorageKey() string { return "cb:" + rcb.name } -func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (StoredState, error) { - var state StoredState +func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (SharedState, error) { + var state SharedState data, err := rcb.cacheClient.GetState(ctx, rcb.getStorageKey()) if len(data) == 0 { // Key doesn't exist, return default state - return StoredState{State: StateClosed}, nil + return SharedState{State: StateClosed}, nil } else if err != nil { return state, err } @@ -237,7 +232,7 @@ func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (St return state, err } -func (rcb *DistributedCircuitBreaker[T]) setStoredState(ctx context.Context, state StoredState) error { +func (rcb *DistributedCircuitBreaker[T]) setStoredState(ctx context.Context, state SharedState) error { data, err := json.Marshal(state) if err != nil { return err diff --git a/v2/distributed_circuit_breaker_test.go b/v2/distributed_circuit_breaker_test.go index 6371d3d..5d123ee 100644 --- a/v2/distributed_circuit_breaker_test.go +++ b/v2/distributed_circuit_breaker_test.go @@ -38,15 +38,13 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir storageClient := &storageAdapter{client: client} - return NewDistributedCircuitBreaker[any](storageClient, StorageSettings{ - Settings: Settings{ - Name: "TestBreaker", - MaxRequests: 3, - Interval: time.Second, - Timeout: time.Second * 2, - ReadyToTrip: func(counts Counts) bool { - return counts.ConsecutiveFailures > 5 - }, + return NewDistributedCircuitBreaker[any](storageClient, Settings{ + Name: "TestBreaker", + MaxRequests: 3, + Interval: time.Second, + Timeout: time.Second * 2, + ReadyToTrip: func(counts Counts) bool { + return counts.ConsecutiveFailures > 5 }, }), mr, client } @@ -200,17 +198,15 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { storageClient := &storageAdapter{client: client} - customRCB = NewDistributedCircuitBreaker[any](storageClient, StorageSettings{ - Settings: Settings{ - Name: "CustomBreaker", - MaxRequests: 3, - Interval: time.Second * 30, - Timeout: time.Second * 90, - ReadyToTrip: func(counts Counts) bool { - numReqs := counts.Requests - failureRatio := float64(counts.TotalFailures) / float64(numReqs) - return numReqs >= 3 && failureRatio >= 0.6 - }, + customRCB = NewDistributedCircuitBreaker[any](storageClient, Settings{ + Name: "CustomBreaker", + MaxRequests: 3, + Interval: time.Second * 30, + Timeout: time.Second * 90, + ReadyToTrip: func(counts Counts) bool { + numReqs := counts.Requests + failureRatio := float64(counts.TotalFailures) / float64(numReqs) + return numReqs >= 3 && failureRatio >= 0.6 }, }) @@ -296,7 +292,7 @@ func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { storageClient := &storageAdapter{client: client} - cb := NewDistributedCircuitBreaker[any](storageClient, StorageSettings{Settings: customSt}) + cb := NewDistributedCircuitBreaker[any](storageClient, customSt) ctx := context.Background() From 11d03b2475f5fb598396e87c4193b465de371f64 Mon Sep 17 00:00:00 2001 From: Yoshiyuki Mineo Date: Sat, 16 Nov 2024 19:06:36 +0900 Subject: [PATCH 3/6] Rename --- v2/{distributed_circuit_breaker.go => distributed_gobreaker.go} | 0 ...uted_circuit_breaker_test.go => distributed_gobreaker_test.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename v2/{distributed_circuit_breaker.go => distributed_gobreaker.go} (100%) rename v2/{distributed_circuit_breaker_test.go => distributed_gobreaker_test.go} (100%) diff --git a/v2/distributed_circuit_breaker.go b/v2/distributed_gobreaker.go similarity index 100% rename from v2/distributed_circuit_breaker.go rename to v2/distributed_gobreaker.go diff --git a/v2/distributed_circuit_breaker_test.go b/v2/distributed_gobreaker_test.go similarity index 100% rename from v2/distributed_circuit_breaker_test.go rename to v2/distributed_gobreaker_test.go From d6880cfae02d0f5c2251d8910ac4df12cbde405d Mon Sep 17 00:00:00 2001 From: Yoshiyuki Mineo Date: Sat, 16 Nov 2024 20:54:56 +0900 Subject: [PATCH 4/6] Refactor --- v2/distributed_gobreaker.go | 62 +++++++++----------------------- v2/distributed_gobreaker_test.go | 38 ++++++++++++++------ 2 files changed, 44 insertions(+), 56 deletions(-) diff --git a/v2/distributed_gobreaker.go b/v2/distributed_gobreaker.go index 34cf051..ee8a86f 100644 --- a/v2/distributed_gobreaker.go +++ b/v2/distributed_gobreaker.go @@ -2,14 +2,21 @@ package gobreaker import ( "context" - "encoding/json" "fmt" "time" ) +// SharedState represents the CircuitBreaker state stored in Distributed Storage +type SharedState struct { + State State `json:"state"` + Generation uint64 `json:"generation"` + Counts Counts `json:"counts"` + Expiry time.Time `json:"expiry"` +} + type SharedStateStore interface { - GetState(ctx context.Context, key string) ([]byte, error) - SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error + GetState(ctx context.Context) (SharedState, error) + SetState(ctx context.Context, state SharedState) error } // DistributedCircuitBreaker extends CircuitBreaker with distributed state storage @@ -27,20 +34,12 @@ func NewDistributedCircuitBreaker[T any](storageClient SharedStateStore, setting } } -// SharedState represents the CircuitBreaker state stored in Distributed Storage -type SharedState struct { - State State `json:"state"` - Generation uint64 `json:"generation"` - Counts Counts `json:"counts"` - Expiry time.Time `json:"expiry"` -} - func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { if rcb.cacheClient == nil { return rcb.CircuitBreaker.State() } - state, err := rcb.getStoredState(ctx) + state, err := rcb.cacheClient.GetState(ctx) if err != nil { // Fallback to in-memory state if Storage fails return rcb.CircuitBreaker.State() @@ -52,7 +51,7 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { // Update the state in Storage if it has changed if currentState != state.State { state.State = currentState - if err := rcb.setStoredState(ctx, state); err != nil { + if err := rcb.cacheClient.SetState(ctx, state); err != nil { // Log the error, but continue with the current state fmt.Printf("Failed to update state in storage: %v\n", err) } @@ -87,7 +86,7 @@ func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() } func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { - state, err := rcb.getStoredState(ctx) + state, err := rcb.cacheClient.GetState(ctx) if err != nil { return 0, err } @@ -96,7 +95,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin if currentState != state.State { rcb.setState(&state, currentState, now) - err = rcb.setStoredState(ctx, state) + err = rcb.cacheClient.SetState(ctx, state) if err != nil { return 0, err } @@ -109,7 +108,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin } state.Counts.onRequest() - err = rcb.setStoredState(ctx, state) + err = rcb.cacheClient.SetState(ctx, state) if err != nil { return 0, err } @@ -118,7 +117,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin } func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) { - state, err := rcb.getStoredState(ctx) + state, err := rcb.cacheClient.GetState(ctx) if err != nil { return } @@ -134,7 +133,7 @@ func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, befor rcb.onFailure(&state, currentState, now) } - rcb.setStoredState(ctx, state) + rcb.cacheClient.SetState(ctx, state) } func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { @@ -213,30 +212,3 @@ func (rcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now state.Expiry = zero } } - -func (rcb *DistributedCircuitBreaker[T]) getStorageKey() string { - return "cb:" + rcb.name -} - -func (rcb *DistributedCircuitBreaker[T]) getStoredState(ctx context.Context) (SharedState, error) { - var state SharedState - data, err := rcb.cacheClient.GetState(ctx, rcb.getStorageKey()) - if len(data) == 0 { - // Key doesn't exist, return default state - return SharedState{State: StateClosed}, nil - } else if err != nil { - return state, err - } - - err = json.Unmarshal(data, &state) - return state, err -} - -func (rcb *DistributedCircuitBreaker[T]) setStoredState(ctx context.Context, state SharedState) error { - data, err := json.Marshal(state) - if err != nil { - return err - } - - return rcb.cacheClient.SetState(ctx, rcb.getStorageKey(), data, 0) -} diff --git a/v2/distributed_gobreaker_test.go b/v2/distributed_gobreaker_test.go index 5d123ee..d9a8828 100644 --- a/v2/distributed_gobreaker_test.go +++ b/v2/distributed_gobreaker_test.go @@ -2,6 +2,7 @@ package gobreaker import ( "context" + "encoding/json" "errors" "testing" "time" @@ -18,12 +19,27 @@ type storageAdapter struct { client *redis.Client } -func (r *storageAdapter) GetState(ctx context.Context, key string) ([]byte, error) { - return r.client.Get(ctx, key).Bytes() +func (r *storageAdapter) GetState(ctx context.Context) (SharedState, error) { + var state SharedState + data, err := r.client.Get(ctx, "gobreaker").Bytes() + if len(data) == 0 { + // Key doesn't exist, return default state + return SharedState{State: StateClosed}, nil + } else if err != nil { + return state, err + } + + err = json.Unmarshal(data, &state) + return state, err } -func (r *storageAdapter) SetState(ctx context.Context, key string, value interface{}, expiration time.Duration) error { - return r.client.Set(ctx, key, value, expiration).Err() +func (r *storageAdapter) SetState(ctx context.Context, state SharedState) error { + data, err := json.Marshal(state) + if err != nil { + return err + } + + return r.client.Set(ctx, "gobreaker", data, 0).Err() } func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Miniredis, *redis.Client) { @@ -50,14 +66,14 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir } func pseudoSleepStorage(ctx context.Context, rcb *DistributedCircuitBreaker[any], period time.Duration) { - state, _ := rcb.getStoredState(ctx) + state, _ := rcb.cacheClient.GetState(ctx) state.Expiry = state.Expiry.Add(-period) // Reset counts if the interval has passed if time.Now().After(state.Expiry) { state.Counts = Counts{} } - rcb.setStoredState(ctx, state) + rcb.cacheClient.SetState(ctx, state) } func successRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error { @@ -158,11 +174,11 @@ func TestDistributedCircuitBreakerCounts(t *testing.T) { assert.Nil(t, successRequest(ctx, rcb)) } - state, _ := rcb.getStoredState(ctx) + state, _ := rcb.cacheClient.GetState(ctx) assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) assert.Nil(t, failRequest(ctx, rcb)) - state, _ = rcb.getStoredState(ctx) + state, _ = rcb.cacheClient.GetState(ctx) assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) } @@ -224,14 +240,14 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { assert.NoError(t, failRequest(ctx, customRCB)) } - state, err := customRCB.getStoredState(ctx) + state, err := customRCB.cacheClient.GetState(ctx) assert.NoError(t, err) assert.Equal(t, StateClosed, state.State) assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) // Perform one more successful request assert.NoError(t, successRequest(ctx, customRCB)) - state, err = customRCB.getStoredState(ctx) + state, err = customRCB.cacheClient.GetState(ctx) assert.NoError(t, err) assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) @@ -246,7 +262,7 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { // Check if the circuit breaker is now open assert.Equal(t, StateOpen, customRCB.State(ctx)) - state, err = customRCB.getStoredState(ctx) + state, err = customRCB.cacheClient.GetState(ctx) assert.NoError(t, err) assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) }) From c1a3eca83d205b434eeacb96e99f72e03408388e Mon Sep 17 00:00:00 2001 From: Yoshiyuki Mineo Date: Sat, 16 Nov 2024 20:58:56 +0900 Subject: [PATCH 5/6] Rename --- v2/distributed_gobreaker.go | 22 +++++++++++----------- v2/distributed_gobreaker_test.go | 16 ++++++++-------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/v2/distributed_gobreaker.go b/v2/distributed_gobreaker.go index ee8a86f..bc7fbd6 100644 --- a/v2/distributed_gobreaker.go +++ b/v2/distributed_gobreaker.go @@ -22,7 +22,7 @@ type SharedStateStore interface { // DistributedCircuitBreaker extends CircuitBreaker with distributed state storage type DistributedCircuitBreaker[T any] struct { *CircuitBreaker[T] - cacheClient SharedStateStore + store SharedStateStore } // NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings @@ -30,16 +30,16 @@ func NewDistributedCircuitBreaker[T any](storageClient SharedStateStore, setting cb := NewCircuitBreaker[T](settings) return &DistributedCircuitBreaker[T]{ CircuitBreaker: cb, - cacheClient: storageClient, + store: storageClient, } } func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { - if rcb.cacheClient == nil { + if rcb.store == nil { return rcb.CircuitBreaker.State() } - state, err := rcb.cacheClient.GetState(ctx) + state, err := rcb.store.GetState(ctx) if err != nil { // Fallback to in-memory state if Storage fails return rcb.CircuitBreaker.State() @@ -51,7 +51,7 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { // Update the state in Storage if it has changed if currentState != state.State { state.State = currentState - if err := rcb.cacheClient.SetState(ctx, state); err != nil { + if err := rcb.store.SetState(ctx, state); err != nil { // Log the error, but continue with the current state fmt.Printf("Failed to update state in storage: %v\n", err) } @@ -62,7 +62,7 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { // Execute runs the given request if the DistributedCircuitBreaker accepts it func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { - if rcb.cacheClient == nil { + if rcb.store == nil { return rcb.CircuitBreaker.Execute(req) } generation, err := rcb.beforeRequest(ctx) @@ -86,7 +86,7 @@ func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() } func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { - state, err := rcb.cacheClient.GetState(ctx) + state, err := rcb.store.GetState(ctx) if err != nil { return 0, err } @@ -95,7 +95,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin if currentState != state.State { rcb.setState(&state, currentState, now) - err = rcb.cacheClient.SetState(ctx, state) + err = rcb.store.SetState(ctx, state) if err != nil { return 0, err } @@ -108,7 +108,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin } state.Counts.onRequest() - err = rcb.cacheClient.SetState(ctx, state) + err = rcb.store.SetState(ctx, state) if err != nil { return 0, err } @@ -117,7 +117,7 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin } func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) { - state, err := rcb.cacheClient.GetState(ctx) + state, err := rcb.store.GetState(ctx) if err != nil { return } @@ -133,7 +133,7 @@ func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, befor rcb.onFailure(&state, currentState, now) } - rcb.cacheClient.SetState(ctx, state) + rcb.store.SetState(ctx, state) } func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { diff --git a/v2/distributed_gobreaker_test.go b/v2/distributed_gobreaker_test.go index d9a8828..7ffbf6e 100644 --- a/v2/distributed_gobreaker_test.go +++ b/v2/distributed_gobreaker_test.go @@ -66,14 +66,14 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir } func pseudoSleepStorage(ctx context.Context, rcb *DistributedCircuitBreaker[any], period time.Duration) { - state, _ := rcb.cacheClient.GetState(ctx) + state, _ := rcb.store.GetState(ctx) state.Expiry = state.Expiry.Add(-period) // Reset counts if the interval has passed if time.Now().After(state.Expiry) { state.Counts = Counts{} } - rcb.cacheClient.SetState(ctx, state) + rcb.store.SetState(ctx, state) } func successRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error { @@ -174,11 +174,11 @@ func TestDistributedCircuitBreakerCounts(t *testing.T) { assert.Nil(t, successRequest(ctx, rcb)) } - state, _ := rcb.cacheClient.GetState(ctx) + state, _ := rcb.store.GetState(ctx) assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) assert.Nil(t, failRequest(ctx, rcb)) - state, _ = rcb.cacheClient.GetState(ctx) + state, _ = rcb.store.GetState(ctx) assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) } @@ -191,7 +191,7 @@ func TestDistributedCircuitBreakerFallback(t *testing.T) { // Test when Storage is unavailable mr.Close() // Simulate Storage being unavailable - rcb.cacheClient = nil + rcb.store = nil state := rcb.State(ctx) assert.Equal(t, StateClosed, state, "Should fallback to in-memory state when Storage is unavailable") @@ -240,14 +240,14 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { assert.NoError(t, failRequest(ctx, customRCB)) } - state, err := customRCB.cacheClient.GetState(ctx) + state, err := customRCB.store.GetState(ctx) assert.NoError(t, err) assert.Equal(t, StateClosed, state.State) assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) // Perform one more successful request assert.NoError(t, successRequest(ctx, customRCB)) - state, err = customRCB.cacheClient.GetState(ctx) + state, err = customRCB.store.GetState(ctx) assert.NoError(t, err) assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) @@ -262,7 +262,7 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { // Check if the circuit breaker is now open assert.Equal(t, StateOpen, customRCB.State(ctx)) - state, err = customRCB.cacheClient.GetState(ctx) + state, err = customRCB.store.GetState(ctx) assert.NoError(t, err) assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) }) From c301deab31675a5558c717decdb398bea55e8c63 Mon Sep 17 00:00:00 2001 From: Yoshiyuki Mineo Date: Sat, 16 Nov 2024 21:48:32 +0900 Subject: [PATCH 6/6] Rename --- v2/distributed_gobreaker.go | 95 +++++++++++---------- v2/distributed_gobreaker_test.go | 138 +++++++++++++++---------------- 2 files changed, 116 insertions(+), 117 deletions(-) diff --git a/v2/distributed_gobreaker.go b/v2/distributed_gobreaker.go index bc7fbd6..2565d90 100644 --- a/v2/distributed_gobreaker.go +++ b/v2/distributed_gobreaker.go @@ -6,7 +6,7 @@ import ( "time" ) -// SharedState represents the CircuitBreaker state stored in Distributed Storage +// SharedState represents the shared state of DistributedCircuitBreaker. type SharedState struct { State State `json:"state"` Generation uint64 `json:"generation"` @@ -26,32 +26,32 @@ type DistributedCircuitBreaker[T any] struct { } // NewDistributedCircuitBreaker returns a new DistributedCircuitBreaker configured with the given StorageSettings -func NewDistributedCircuitBreaker[T any](storageClient SharedStateStore, settings Settings) *DistributedCircuitBreaker[T] { +func NewDistributedCircuitBreaker[T any](store SharedStateStore, settings Settings) *DistributedCircuitBreaker[T] { cb := NewCircuitBreaker[T](settings) return &DistributedCircuitBreaker[T]{ CircuitBreaker: cb, - store: storageClient, + store: store, } } -func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { - if rcb.store == nil { - return rcb.CircuitBreaker.State() +func (dcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { + if dcb.store == nil { + return dcb.CircuitBreaker.State() } - state, err := rcb.store.GetState(ctx) + state, err := dcb.store.GetState(ctx) if err != nil { // Fallback to in-memory state if Storage fails - return rcb.CircuitBreaker.State() + return dcb.CircuitBreaker.State() } now := time.Now() - currentState, _ := rcb.currentState(state, now) + currentState, _ := dcb.currentState(state, now) // Update the state in Storage if it has changed if currentState != state.State { state.State = currentState - if err := rcb.store.SetState(ctx, state); err != nil { + if err := dcb.store.SetState(ctx, state); err != nil { // Log the error, but continue with the current state fmt.Printf("Failed to update state in storage: %v\n", err) } @@ -61,11 +61,11 @@ func (rcb *DistributedCircuitBreaker[T]) State(ctx context.Context) State { } // Execute runs the given request if the DistributedCircuitBreaker accepts it -func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { - if rcb.store == nil { - return rcb.CircuitBreaker.Execute(req) +func (dcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() (T, error)) (T, error) { + if dcb.store == nil { + return dcb.CircuitBreaker.Execute(req) } - generation, err := rcb.beforeRequest(ctx) + generation, err := dcb.beforeRequest(ctx) if err != nil { var zero T return zero, err @@ -74,28 +74,28 @@ func (rcb *DistributedCircuitBreaker[T]) Execute(ctx context.Context, req func() defer func() { e := recover() if e != nil { - rcb.afterRequest(ctx, generation, false) + dcb.afterRequest(ctx, generation, false) panic(e) } }() result, err := req() - rcb.afterRequest(ctx, generation, rcb.isSuccessful(err)) + dcb.afterRequest(ctx, generation, dcb.isSuccessful(err)) return result, err } -func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { - state, err := rcb.store.GetState(ctx) +func (dcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uint64, error) { + state, err := dcb.store.GetState(ctx) if err != nil { return 0, err } now := time.Now() - currentState, generation := rcb.currentState(state, now) + currentState, generation := dcb.currentState(state, now) if currentState != state.State { - rcb.setState(&state, currentState, now) - err = rcb.store.SetState(ctx, state) + dcb.setState(&state, currentState, now) + err = dcb.store.SetState(ctx, state) if err != nil { return 0, err } @@ -103,12 +103,12 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin if currentState == StateOpen { return generation, ErrOpenState - } else if currentState == StateHalfOpen && state.Counts.Requests >= rcb.maxRequests { + } else if currentState == StateHalfOpen && state.Counts.Requests >= dcb.maxRequests { return generation, ErrTooManyRequests } state.Counts.onRequest() - err = rcb.store.SetState(ctx, state) + err = dcb.store.SetState(ctx, state) if err != nil { return 0, err } @@ -116,27 +116,27 @@ func (rcb *DistributedCircuitBreaker[T]) beforeRequest(ctx context.Context) (uin return generation, nil } -func (rcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) { - state, err := rcb.store.GetState(ctx) +func (dcb *DistributedCircuitBreaker[T]) afterRequest(ctx context.Context, before uint64, success bool) { + state, err := dcb.store.GetState(ctx) if err != nil { return } now := time.Now() - currentState, generation := rcb.currentState(state, now) + currentState, generation := dcb.currentState(state, now) if generation != before { return } if success { - rcb.onSuccess(&state, currentState, now) + dcb.onSuccess(&state, currentState, now) } else { - rcb.onFailure(&state, currentState, now) + dcb.onFailure(&state, currentState, now) } - rcb.store.SetState(ctx, state) + dcb.store.SetState(ctx, state) } -func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { +func (dcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentState State, now time.Time) { if state.State == StateOpen { state.State = currentState } @@ -146,39 +146,39 @@ func (rcb *DistributedCircuitBreaker[T]) onSuccess(state *SharedState, currentSt state.Counts.onSuccess() case StateHalfOpen: state.Counts.onSuccess() - if state.Counts.ConsecutiveSuccesses >= rcb.maxRequests { - rcb.setState(state, StateClosed, now) + if state.Counts.ConsecutiveSuccesses >= dcb.maxRequests { + dcb.setState(state, StateClosed, now) } } } -func (rcb *DistributedCircuitBreaker[T]) onFailure(state *SharedState, currentState State, now time.Time) { +func (dcb *DistributedCircuitBreaker[T]) onFailure(state *SharedState, currentState State, now time.Time) { switch currentState { case StateClosed: state.Counts.onFailure() - if rcb.readyToTrip(state.Counts) { - rcb.setState(state, StateOpen, now) + if dcb.readyToTrip(state.Counts) { + dcb.setState(state, StateOpen, now) } case StateHalfOpen: - rcb.setState(state, StateOpen, now) + dcb.setState(state, StateOpen, now) } } -func (rcb *DistributedCircuitBreaker[T]) currentState(state SharedState, now time.Time) (State, uint64) { +func (dcb *DistributedCircuitBreaker[T]) currentState(state SharedState, now time.Time) (State, uint64) { switch state.State { case StateClosed: if !state.Expiry.IsZero() && state.Expiry.Before(now) { - rcb.toNewGeneration(&state, now) + dcb.toNewGeneration(&state, now) } case StateOpen: if state.Expiry.Before(now) { - rcb.setState(&state, StateHalfOpen, now) + dcb.setState(&state, StateHalfOpen, now) } } return state.State, state.Generation } -func (rcb *DistributedCircuitBreaker[T]) setState(state *SharedState, newState State, now time.Time) { +func (dcb *DistributedCircuitBreaker[T]) setState(state *SharedState, newState State, now time.Time) { if state.State == newState { return } @@ -186,28 +186,27 @@ func (rcb *DistributedCircuitBreaker[T]) setState(state *SharedState, newState S prev := state.State state.State = newState - rcb.toNewGeneration(state, now) + dcb.toNewGeneration(state, now) - if rcb.onStateChange != nil { - rcb.onStateChange(rcb.name, prev, newState) + if dcb.onStateChange != nil { + dcb.onStateChange(dcb.name, prev, newState) } } -func (rcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now time.Time) { - +func (dcb *DistributedCircuitBreaker[T]) toNewGeneration(state *SharedState, now time.Time) { state.Generation++ state.Counts.clear() var zero time.Time switch state.State { case StateClosed: - if rcb.interval == 0 { + if dcb.interval == 0 { state.Expiry = zero } else { - state.Expiry = now.Add(rcb.interval) + state.Expiry = now.Add(dcb.interval) } case StateOpen: - state.Expiry = now.Add(rcb.timeout) + state.Expiry = now.Add(dcb.timeout) default: // StateHalfOpen state.Expiry = zero } diff --git a/v2/distributed_gobreaker_test.go b/v2/distributed_gobreaker_test.go index 7ffbf6e..c9edf6f 100644 --- a/v2/distributed_gobreaker_test.go +++ b/v2/distributed_gobreaker_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" ) -var defaultRCB *DistributedCircuitBreaker[any] -var customRCB *DistributedCircuitBreaker[any] +var defaultDCB *DistributedCircuitBreaker[any] +var customDCB *DistributedCircuitBreaker[any] type storageAdapter struct { client *redis.Client @@ -65,24 +65,24 @@ func setupTestWithMiniredis() (*DistributedCircuitBreaker[any], *miniredis.Minir }), mr, client } -func pseudoSleepStorage(ctx context.Context, rcb *DistributedCircuitBreaker[any], period time.Duration) { - state, _ := rcb.store.GetState(ctx) +func pseudoSleepStorage(ctx context.Context, dcb *DistributedCircuitBreaker[any], period time.Duration) { + state, _ := dcb.store.GetState(ctx) state.Expiry = state.Expiry.Add(-period) // Reset counts if the interval has passed if time.Now().After(state.Expiry) { state.Counts = Counts{} } - rcb.store.SetState(ctx, state) + dcb.store.SetState(ctx, state) } -func successRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error { - _, err := rcb.Execute(ctx, func() (interface{}, error) { return nil, nil }) +func successRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error { + _, err := dcb.Execute(ctx, func() (interface{}, error) { return nil, nil }) return err } -func failRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error { - _, err := rcb.Execute(ctx, func() (interface{}, error) { return nil, errors.New("fail") }) +func failRequest(ctx context.Context, dcb *DistributedCircuitBreaker[any]) error { + _, err := dcb.Execute(ctx, func() (interface{}, error) { return nil, errors.New("fail") }) if err != nil && err.Error() == "fail" { return nil } @@ -90,74 +90,74 @@ func failRequest(ctx context.Context, rcb *DistributedCircuitBreaker[any]) error } func TestDistributedCircuitBreakerInitialization(t *testing.T) { - rcb, mr, _ := setupTestWithMiniredis() + dcb, mr, _ := setupTestWithMiniredis() defer mr.Close() ctx := context.Background() - assert.Equal(t, "TestBreaker", rcb.Name()) - assert.Equal(t, uint32(3), rcb.maxRequests) - assert.Equal(t, time.Second, rcb.interval) - assert.Equal(t, time.Second*2, rcb.timeout) - assert.NotNil(t, rcb.readyToTrip) + assert.Equal(t, "TestBreaker", dcb.Name()) + assert.Equal(t, uint32(3), dcb.maxRequests) + assert.Equal(t, time.Second, dcb.interval) + assert.Equal(t, time.Second*2, dcb.timeout) + assert.NotNil(t, dcb.readyToTrip) - state := rcb.State(ctx) + state := dcb.State(ctx) assert.Equal(t, StateClosed, state) } func TestDistributedCircuitBreakerStateTransitions(t *testing.T) { - rcb, mr, _ := setupTestWithMiniredis() + dcb, mr, _ := setupTestWithMiniredis() defer mr.Close() ctx := context.Background() // Check if initial state is closed - assert.Equal(t, StateClosed, rcb.State(ctx)) + assert.Equal(t, StateClosed, dcb.State(ctx)) // StateClosed to StateOpen for i := 0; i < 6; i++ { - assert.NoError(t, failRequest(ctx, rcb)) + assert.NoError(t, failRequest(ctx, dcb)) } - assert.Equal(t, StateOpen, rcb.State(ctx)) + assert.Equal(t, StateOpen, dcb.State(ctx)) // Ensure requests fail when circuit is open - err := failRequest(ctx, rcb) + err := failRequest(ctx, dcb) assert.Error(t, err) assert.Equal(t, ErrOpenState, err) // Wait for timeout to transition to half-open - pseudoSleepStorage(ctx, rcb, rcb.timeout) - assert.Equal(t, StateHalfOpen, rcb.State(ctx)) + pseudoSleepStorage(ctx, dcb, dcb.timeout) + assert.Equal(t, StateHalfOpen, dcb.State(ctx)) // StateHalfOpen to StateClosed - for i := 0; i < int(rcb.maxRequests); i++ { - assert.NoError(t, successRequest(ctx, rcb)) + for i := 0; i < int(dcb.maxRequests); i++ { + assert.NoError(t, successRequest(ctx, dcb)) } - assert.Equal(t, StateClosed, rcb.State(ctx)) + assert.Equal(t, StateClosed, dcb.State(ctx)) // StateClosed to StateOpen (again) for i := 0; i < 6; i++ { - assert.NoError(t, failRequest(ctx, rcb)) + assert.NoError(t, failRequest(ctx, dcb)) } - assert.Equal(t, StateOpen, rcb.State(ctx)) + assert.Equal(t, StateOpen, dcb.State(ctx)) } func TestDistributedCircuitBreakerExecution(t *testing.T) { - rcb, mr, _ := setupTestWithMiniredis() + dcb, mr, _ := setupTestWithMiniredis() defer mr.Close() ctx := context.Background() // Test successful execution - result, err := rcb.Execute(ctx, func() (interface{}, error) { + result, err := dcb.Execute(ctx, func() (interface{}, error) { return "success", nil }) assert.NoError(t, err) assert.Equal(t, "success", result) // Test failed execution - _, err = rcb.Execute(ctx, func() (interface{}, error) { + _, err = dcb.Execute(ctx, func() (interface{}, error) { return nil, errors.New("test error") }) assert.Error(t, err) @@ -165,25 +165,25 @@ func TestDistributedCircuitBreakerExecution(t *testing.T) { } func TestDistributedCircuitBreakerCounts(t *testing.T) { - rcb, mr, _ := setupTestWithMiniredis() + dcb, mr, _ := setupTestWithMiniredis() defer mr.Close() ctx := context.Background() for i := 0; i < 5; i++ { - assert.Nil(t, successRequest(ctx, rcb)) + assert.Nil(t, successRequest(ctx, dcb)) } - state, _ := rcb.store.GetState(ctx) + state, _ := dcb.store.GetState(ctx) assert.Equal(t, Counts{5, 5, 0, 5, 0}, state.Counts) - assert.Nil(t, failRequest(ctx, rcb)) - state, _ = rcb.store.GetState(ctx) + assert.Nil(t, failRequest(ctx, dcb)) + state, _ = dcb.store.GetState(ctx) assert.Equal(t, Counts{6, 5, 1, 0, 1}, state.Counts) } func TestDistributedCircuitBreakerFallback(t *testing.T) { - rcb, mr, _ := setupTestWithMiniredis() + dcb, mr, _ := setupTestWithMiniredis() defer mr.Close() ctx := context.Background() @@ -191,14 +191,14 @@ func TestDistributedCircuitBreakerFallback(t *testing.T) { // Test when Storage is unavailable mr.Close() // Simulate Storage being unavailable - rcb.store = nil + dcb.store = nil - state := rcb.State(ctx) + state := dcb.State(ctx) assert.Equal(t, StateClosed, state, "Should fallback to in-memory state when Storage is unavailable") // Ensure operations still work without Storage - assert.Nil(t, successRequest(ctx, rcb)) - assert.Nil(t, failRequest(ctx, rcb)) + assert.Nil(t, successRequest(ctx, dcb)) + assert.Nil(t, failRequest(ctx, dcb)) } func TestCustomDistributedCircuitBreaker(t *testing.T) { @@ -214,7 +214,7 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { storageClient := &storageAdapter{client: client} - customRCB = NewDistributedCircuitBreaker[any](storageClient, Settings{ + customDCB = NewDistributedCircuitBreaker[any](storageClient, Settings{ Name: "CustomBreaker", MaxRequests: 3, Interval: time.Second * 30, @@ -229,54 +229,54 @@ func TestCustomDistributedCircuitBreaker(t *testing.T) { ctx := context.Background() t.Run("Initialization", func(t *testing.T) { - assert.Equal(t, "CustomBreaker", customRCB.Name()) - assert.Equal(t, StateClosed, customRCB.State(ctx)) + assert.Equal(t, "CustomBreaker", customDCB.Name()) + assert.Equal(t, StateClosed, customDCB.State(ctx)) }) t.Run("Counts and State Transitions", func(t *testing.T) { // Perform 5 successful and 5 failed requests for i := 0; i < 5; i++ { - assert.NoError(t, successRequest(ctx, customRCB)) - assert.NoError(t, failRequest(ctx, customRCB)) + assert.NoError(t, successRequest(ctx, customDCB)) + assert.NoError(t, failRequest(ctx, customDCB)) } - state, err := customRCB.store.GetState(ctx) + state, err := customDCB.store.GetState(ctx) assert.NoError(t, err) assert.Equal(t, StateClosed, state.State) assert.Equal(t, Counts{10, 5, 5, 0, 1}, state.Counts) // Perform one more successful request - assert.NoError(t, successRequest(ctx, customRCB)) - state, err = customRCB.store.GetState(ctx) + assert.NoError(t, successRequest(ctx, customDCB)) + state, err = customDCB.store.GetState(ctx) assert.NoError(t, err) assert.Equal(t, Counts{11, 6, 5, 1, 0}, state.Counts) // Simulate time passing to reset counts - pseudoSleepStorage(ctx, customRCB, time.Second*30) + pseudoSleepStorage(ctx, customDCB, time.Second*30) // Perform requests to trigger StateOpen - assert.NoError(t, successRequest(ctx, customRCB)) - assert.NoError(t, failRequest(ctx, customRCB)) - assert.NoError(t, failRequest(ctx, customRCB)) + assert.NoError(t, successRequest(ctx, customDCB)) + assert.NoError(t, failRequest(ctx, customDCB)) + assert.NoError(t, failRequest(ctx, customDCB)) // Check if the circuit breaker is now open - assert.Equal(t, StateOpen, customRCB.State(ctx)) + assert.Equal(t, StateOpen, customDCB.State(ctx)) - state, err = customRCB.store.GetState(ctx) + state, err = customDCB.store.GetState(ctx) assert.NoError(t, err) assert.Equal(t, Counts{0, 0, 0, 0, 0}, state.Counts) }) t.Run("Timeout and Half-Open State", func(t *testing.T) { // Simulate timeout to transition to half-open state - pseudoSleepStorage(ctx, customRCB, time.Second*90) - assert.Equal(t, StateHalfOpen, customRCB.State(ctx)) + pseudoSleepStorage(ctx, customDCB, time.Second*90) + assert.Equal(t, StateHalfOpen, customDCB.State(ctx)) // Successful requests in half-open state should close the circuit for i := 0; i < 3; i++ { - assert.NoError(t, successRequest(ctx, customRCB)) + assert.NoError(t, successRequest(ctx, customDCB)) } - assert.Equal(t, StateClosed, customRCB.State(ctx)) + assert.Equal(t, StateClosed, customDCB.State(ctx)) }) } @@ -308,43 +308,43 @@ func TestCustomDistributedCircuitBreakerStateTransitions(t *testing.T) { storageClient := &storageAdapter{client: client} - cb := NewDistributedCircuitBreaker[any](storageClient, customSt) + dcb := NewDistributedCircuitBreaker[any](storageClient, customSt) ctx := context.Background() // Test case t.Run("Circuit Breaker State Transitions", func(t *testing.T) { // Initial state should be Closed - assert.Equal(t, StateClosed, cb.State(ctx)) + assert.Equal(t, StateClosed, dcb.State(ctx)) // Cause two consecutive failures to trip the circuit for i := 0; i < 2; i++ { - err := failRequest(ctx, cb) + err := failRequest(ctx, dcb) assert.NoError(t, err, "Fail request should not return an error") } // Circuit should now be Open - assert.Equal(t, StateOpen, cb.State(ctx)) + assert.Equal(t, StateOpen, dcb.State(ctx)) assert.Equal(t, StateChange{"cb", StateClosed, StateOpen}, stateChange) // Requests should fail immediately when circuit is Open - err := successRequest(ctx, cb) + err := successRequest(ctx, dcb) assert.Error(t, err) assert.Equal(t, ErrOpenState, err) // Simulate timeout to transition to Half-Open - pseudoSleepStorage(ctx, cb, 6*time.Second) - assert.Equal(t, StateHalfOpen, cb.State(ctx)) + pseudoSleepStorage(ctx, dcb, 6*time.Second) + assert.Equal(t, StateHalfOpen, dcb.State(ctx)) assert.Equal(t, StateChange{"cb", StateOpen, StateHalfOpen}, stateChange) // Successful requests in Half-Open state should close the circuit - for i := 0; i < int(cb.maxRequests); i++ { - err := successRequest(ctx, cb) + for i := 0; i < int(dcb.maxRequests); i++ { + err := successRequest(ctx, dcb) assert.NoError(t, err) } // Circuit should now be Closed - assert.Equal(t, StateClosed, cb.State(ctx)) + assert.Equal(t, StateClosed, dcb.State(ctx)) assert.Equal(t, StateChange{"cb", StateHalfOpen, StateClosed}, stateChange) }) }