From c0608d5a90e825d82fa8780edbc6d53b01edb205 Mon Sep 17 00:00:00 2001
From: Shawn Poulson
Date: Mon, 11 Mar 2024 19:16:25 -0400
Subject: [PATCH] Fix over-limit metric double counting on non-owner and owner.

---
 algorithms.go     | 40 ++++++++++++++++++++++++++--------------
 benchmark_test.go | 12 ++++++------
 global.go         |  3 ++-
 gubernator.go     | 28 ++++++++++++++++++----------
 peer_client.go    |  7 ++++---
 workers.go        | 17 +++++++++--------
 6 files changed, 65 insertions(+), 42 deletions(-)

diff --git a/algorithms.go b/algorithms.go
index 7d452fc3..4032fa4f 100644
--- a/algorithms.go
+++ b/algorithms.go
@@ -34,7 +34,7 @@ import (
 // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT`
 
 // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket
-func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
+func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
 	tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket"))
 	defer tokenBucketTimer.ObserveDuration()
 
@@ -99,7 +99,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 				s.Remove(ctx, hashKey)
 			}
 
-			return tokenBucketNewItem(ctx, s, c, r)
+			return tokenBucketNewItem(ctx, s, c, r, rs)
 		}
 
 		// Update the limit if it changed.
@@ -161,7 +161,9 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 		// If we are already at the limit.
 		if rl.Remaining == 0 && r.Hits > 0 {
 			trace.SpanFromContext(ctx).AddEvent("Already over the limit")
-			metricOverLimitCounter.Add(1)
+			if rs.IsOwner {
+				metricOverLimitCounter.Add(1)
+			}
 			rl.Status = Status_OVER_LIMIT
 			t.Status = rl.Status
 			return rl, nil
@@ -179,7 +181,9 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 		// without updating the cache.
 		if r.Hits > t.Remaining {
 			trace.SpanFromContext(ctx).AddEvent("Over the limit")
-			metricOverLimitCounter.Add(1)
+			if rs.IsOwner {
+				metricOverLimitCounter.Add(1)
+			}
 			rl.Status = Status_OVER_LIMIT
 			if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) {
 				// DRAIN_OVER_LIMIT behavior drains the remaining counter.
@@ -195,11 +199,11 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 	}
 
 	// Item is not found in cache or store, create new.
-	return tokenBucketNewItem(ctx, s, c, r)
+	return tokenBucketNewItem(ctx, s, c, r, rs)
 }
 
 // Called by tokenBucket() when adding a new item in the store.
-func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
+func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
 	requestTime := *r.RequestTime
 	expire := requestTime + r.Duration
 
@@ -235,7 +239,9 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
 	// Client could be requesting that we always return OVER_LIMIT.
 	if r.Hits > r.Limit {
 		trace.SpanFromContext(ctx).AddEvent("Over the limit")
-		metricOverLimitCounter.Add(1)
+		if rs.IsOwner {
+			metricOverLimitCounter.Add(1)
+		}
 		rl.Status = Status_OVER_LIMIT
 		rl.Remaining = r.Limit
 		t.Remaining = r.Limit
@@ -251,7 +257,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
 }
 
 // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket
-func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
+func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
 	leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket"))
 	defer leakyBucketTimer.ObserveDuration()
 
@@ -308,7 +314,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 				s.Remove(ctx, hashKey)
 			}
 
-			return leakyBucketNewItem(ctx, s, c, r)
+			return leakyBucketNewItem(ctx, s, c, r, rs)
 		}
 
 		if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) {
@@ -381,7 +387,9 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 
 		// If we are already at the limit
 		if int64(b.Remaining) == 0 && r.Hits > 0 {
-			metricOverLimitCounter.Add(1)
+			if rs.IsOwner {
+				metricOverLimitCounter.Add(1)
+			}
 			rl.Status = Status_OVER_LIMIT
 			return rl, nil
 		}
@@ -397,7 +405,9 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 		// If requested is more than available, then return over the limit
 		// without updating the bucket, unless `DRAIN_OVER_LIMIT` is set.
 		if r.Hits > int64(b.Remaining) {
-			metricOverLimitCounter.Add(1)
+			if rs.IsOwner {
+				metricOverLimitCounter.Add(1)
+			}
 			rl.Status = Status_OVER_LIMIT
 
 			// DRAIN_OVER_LIMIT behavior drains the remaining counter.
@@ -420,11 +430,11 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
 		return rl, nil
 	}
 
-	return leakyBucketNewItem(ctx, s, c, r)
+	return leakyBucketNewItem(ctx, s, c, r, rs)
 }
 
 // Called by leakyBucket() when adding a new item in the store.
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
+func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
 	requestTime := *r.RequestTime
 	duration := r.Duration
 	rate := float64(duration) / float64(r.Limit)
@@ -457,7 +467,9 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
 
 	// Client could be requesting that we start with the bucket OVER_LIMIT
 	if r.Hits > r.Burst {
-		metricOverLimitCounter.Add(1)
+		if rs.IsOwner {
+			metricOverLimitCounter.Add(1)
+		}
 		rl.Status = Status_OVER_LIMIT
 		rl.Remaining = 0
 		rl.ResetTime = requestTime + (rl.Limit-rl.Remaining)*int64(rate)
diff --git a/benchmark_test.go b/benchmark_test.go
index 20323dcd..5ceacf42 100644
--- a/benchmark_test.go
+++ b/benchmark_test.go
@@ -34,7 +34,7 @@ func BenchmarkServer(b *testing.B) {
 	require.NoError(b, err, "Error in conf.SetDefaults")
 	requestTime := epochMillis(clock.Now())
 
-	b.Run("GetPeerRateLimit() with no batching", func(b *testing.B) {
+	b.Run("GetPeerRateLimit", func(b *testing.B) {
 		client, err := guber.NewPeerClient(guber.PeerConfig{
 			Info:     cluster.GetRandomPeer(cluster.DataCenterNone),
 			Behavior: conf.Behaviors,
@@ -46,9 +46,9 @@ func BenchmarkServer(b *testing.B) {
 
 		for n := 0; n < b.N; n++ {
 			_, err := client.GetPeerRateLimit(ctx, &guber.RateLimitReq{
-				Name:      b.Name(),
-				UniqueKey: guber.RandomString(10),
-				Behavior:  guber.Behavior_NO_BATCHING,
+				Name:      b.Name(),
+				UniqueKey: guber.RandomString(10),
+				// Behavior: guber.Behavior_NO_BATCHING,
 				Limit:     10,
 				Duration:  5,
 				Hits:      1,
@@ -60,7 +60,7 @@ func BenchmarkServer(b *testing.B) {
 		}
 	})
 
-	b.Run("GetRateLimit()", func(b *testing.B) {
+	b.Run("GetRateLimits batching", func(b *testing.B) {
 		client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil)
 		require.NoError(b, err, "Error in guber.DialV1Server")
 		b.ResetTimer()
@@ -83,7 +83,7 @@ func BenchmarkServer(b *testing.B) {
 		}
 	})
 
-	b.Run("GetRateLimitGlobal()", func(b *testing.B) {
+	b.Run("GetRateLimits global", func(b *testing.B) {
 		client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil)
 		require.NoError(b, err, "Error in guber.DialV1Server")
 		b.ResetTimer()
diff --git a/global.go b/global.go
index 47703f6e..2300b971 100644
--- a/global.go
+++ b/global.go
@@ -234,6 +234,7 @@ func (gm *globalManager) runBroadcasts() {
 func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) {
 	defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration()
 	var req UpdatePeerGlobalsReq
+	reqState := RateLimitReqState{IsOwner: false}
 
 	gm.metricGlobalQueueLength.Set(float64(len(updates)))
 
@@ -241,7 +242,7 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]
 		// Get current rate limit state.
 		grlReq := proto.Clone(update).(*RateLimitReq)
 		grlReq.Hits = 0
-		status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq)
+		status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq, reqState)
 		if err != nil {
 			gm.log.WithError(err).Error("while retrieving rate limit status")
 			continue
diff --git a/gubernator.go b/gubernator.go
index d0869cd3..280821f3 100644
--- a/gubernator.go
+++ b/gubernator.go
@@ -53,6 +53,10 @@ type V1Instance struct {
 	workerPool *WorkerPool
 }
 
+type RateLimitReqState struct {
+	IsOwner bool
+}
+
 var (
 	metricGetRateLimitCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
 		Name: "gubernator_getratelimit_counter",
@@ -240,9 +244,10 @@ func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*G
 		}
 
 		// If our server instance is the owner of this rate limit
-		if peer.Info().IsOwner {
+		reqState := RateLimitReqState{IsOwner: peer.Info().IsOwner}
+		if reqState.IsOwner {
 			// Apply our rate limit algorithm to the request
-			resp.Responses[i], err = s.getLocalRateLimit(ctx, req)
+			resp.Responses[i], err = s.getLocalRateLimit(ctx, req, reqState)
 			if err != nil {
 				err = errors.Wrapf(err, "Error while apply rate limit for '%s'", key)
 				span := trace.SpanFromContext(ctx)
@@ -313,6 +318,7 @@ func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) {
 	funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.asyncRequest"))
 	defer funcTimer.ObserveDuration()
 
+	reqState := RateLimitReqState{IsOwner: false}
 	resp := AsyncResp{
 		Idx: req.Idx,
 	}
@@ -332,7 +338,7 @@ func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) {
 		// If we are attempting again, the owner of this rate limit might have changed to us!
 		if attempts != 0 {
 			if req.Peer.Info().IsOwner {
-				resp.Resp, err = s.getLocalRateLimit(ctx, req.Req)
+				resp.Resp, err = s.getLocalRateLimit(ctx, req.Req, reqState)
 				if err != nil {
 					s.log.WithContext(ctx).
 						WithError(err).
@@ -399,12 +405,13 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq)
 		tracing.EndScope(ctx, err)
 	}()
 
-	cpy := proto.Clone(req).(*RateLimitReq)
-	SetBehavior(&cpy.Behavior, Behavior_NO_BATCHING, true)
-	SetBehavior(&cpy.Behavior, Behavior_GLOBAL, false)
+	req2 := proto.Clone(req).(*RateLimitReq)
+	SetBehavior(&req2.Behavior, Behavior_NO_BATCHING, true)
+	SetBehavior(&req2.Behavior, Behavior_GLOBAL, false)
+	reqState := RateLimitReqState{IsOwner: false}
 
 	// Process the rate limit like we own it
-	resp, err = s.getLocalRateLimit(ctx, cpy)
+	resp, err = s.getLocalRateLimit(ctx, req2, reqState)
 	if err != nil {
 		return nil, errors.Wrap(err, "during in getLocalRateLimit")
 	}
@@ -476,6 +483,7 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits
 	respChan := make(chan respOut)
 	var respWg sync.WaitGroup
 	respWg.Add(1)
+	reqState := RateLimitReqState{IsOwner: true}
 
 	go func() {
 		// Capture each response and return in the same order
@@ -509,7 +517,7 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits
 				rin.req.RequestTime = &requestTime
 			}
 
-			rl, err := s.getLocalRateLimit(ctx, rin.req)
+			rl, err := s.getLocalRateLimit(ctx, rin.req, reqState)
 			if err != nil {
 				// Return the error for this request
 				err = errors.Wrap(err, "Error in getLocalRateLimit")
@@ -577,7 +585,7 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health
 	return health, nil
 }
 
-func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ *RateLimitResp, err error) {
+func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs RateLimitReqState) (_ *RateLimitResp, err error) {
 	ctx = tracing.StartNamedScope(ctx, "V1Instance.getLocalRateLimit", trace.WithAttributes(
 		attribute.String("ratelimit.key", r.UniqueKey),
 		attribute.String("ratelimit.name", r.Name),
@@ -587,7 +595,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_
 	defer func() { tracing.EndScope(ctx, err) }()
 	defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration()
 
-	resp, err := s.workerPool.GetRateLimit(ctx, r)
+	resp, err := s.workerPool.GetRateLimit(ctx, r, rs)
 	if err != nil {
 		return nil, errors.Wrap(err, "during workerPool.GetRateLimit")
 	}
diff --git a/peer_client.go b/peer_client.go
index 39c13c14..5e2fef15 100644
--- a/peer_client.go
+++ b/peer_client.go
@@ -66,9 +66,10 @@ type response struct {
 }
 
 type request struct {
-	request *RateLimitReq
-	resp    chan *response
-	ctx     context.Context
+	request  *RateLimitReq
+	reqState RateLimitReqState
+	resp     chan *response
+	ctx      context.Context
 }
 
 type PeerConfig struct {
diff --git a/workers.go b/workers.go
index f6ed60a9..d62071be 100644
--- a/workers.go
+++ b/workers.go
@@ -199,7 +199,7 @@ func (p *WorkerPool) dispatch(worker *Worker) {
 			}
 
 			resp := new(response)
-			resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, worker.cache)
+			resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, req.reqState, worker.cache)
 			select {
 			case req.resp <- resp:
 				// Success.
@@ -258,16 +258,17 @@ func (p *WorkerPool) dispatch(worker *Worker) {
 }
 
 // GetRateLimit sends a GetRateLimit request to worker pool.
-func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq) (*RateLimitResp, error) {
+func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, rs RateLimitReqState) (*RateLimitResp, error) {
 	// Delegate request to assigned channel based on request key.
 	worker := p.getWorker(rlRequest.HashKey())
 	queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name)
 	queueGauge.Inc()
 	defer queueGauge.Dec()
 	handlerRequest := request{
-		ctx:     ctx,
-		resp:    make(chan *response, 1),
-		request: rlRequest,
+		ctx:      ctx,
+		resp:     make(chan *response, 1),
+		request:  rlRequest,
+		reqState: rs,
 	}
 
 	// Send request.
@@ -289,14 +290,14 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq)
 }
 
 // Handle request received by worker.
-func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, cache Cache) (*RateLimitResp, error) {
+func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, rs RateLimitReqState, cache Cache) (*RateLimitResp, error) {
 	defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration()
 	var rlResponse *RateLimitResp
 	var err error
 
 	switch req.Algorithm {
 	case Algorithm_TOKEN_BUCKET:
-		rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req)
+		rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, rs)
 		if err != nil {
 			msg := "Error in tokenBucket"
 			countError(err, msg)
@@ -305,7 +306,7 @@ func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq,
 		}
 
 	case Algorithm_LEAKY_BUCKET:
-		rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req)
+		rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, rs)
 		if err != nil {
 			msg := "Error in leakyBucket"
 			countError(err, msg)
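
Editor's note (illustration, not part of the patch): the change threads a RateLimitReqState value from each entry point down to the token and leaky bucket algorithms and increments metricOverLimitCounter only when rs.IsOwner is true. GetPeerRateLimits passes IsOwner: true, while the non-owner paths (asyncRequest, getGlobalRateLimit, and the global broadcast) pass IsOwner: false, so an over-limit hit evaluated first on a non-owner and again on the owner is counted once. A minimal Go sketch of that gating pattern, using hypothetical names (overLimitCounter, checkOverLimit) rather than the real gubernator internals:

package example

import "github.com/prometheus/client_golang/prometheus"

// Hypothetical stand-in for gubernator's metricOverLimitCounter.
var overLimitCounter = prometheus.NewCounter(prometheus.CounterOpts{
	Name: "example_over_limit_total",
	Help: "Requests that exceeded their limit, counted on the owning peer only.",
})

// RateLimitReqState mirrors the struct introduced by this patch.
type RateLimitReqState struct {
	IsOwner bool
}

// checkOverLimit is an illustrative stand-in for the algorithm functions:
// every caller learns the over-limit decision, but only the peer that owns
// the bucket records the metric, so forwarded requests are not counted twice.
func checkOverLimit(hits, remaining int64, rs RateLimitReqState) bool {
	if hits > remaining {
		if rs.IsOwner {
			overLimitCounter.Add(1)
		}
		return true
	}
	return false
}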