This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Commit

Fix over-limit metric double-counting on non-owner and owner peers.
Baliedge committed Mar 11, 2024
1 parent 2229596 commit c0608d5
Showing 6 changed files with 65 additions and 42 deletions.
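
The change threads a RateLimitReqState value through the token- and leaky-bucket algorithms so that metricOverLimitCounter is incremented only on the peer that owns the rate limit; previously the forwarding (non-owner) peer and the owner could both count the same over-limit event. A minimal sketch of the guard pattern, using a standalone Prometheus counter rather than the project's metric wiring:

package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

// RateLimitReqState mirrors the struct added in gubernator.go: it records
// whether the local peer owns the rate limit being evaluated.
type RateLimitReqState struct {
	IsOwner bool
}

// overLimitCounter stands in for metricOverLimitCounter in algorithms.go.
var overLimitCounter = prometheus.NewCounter(prometheus.CounterOpts{
	Name: "example_over_limit_total",
	Help: "Over-limit responses, counted on the owning peer only.",
})

// recordOverLimit applies the guard this commit adds around every
// metricOverLimitCounter.Add(1): only the owner increments the metric,
// so a request forwarded by a non-owner is not counted twice.
func recordOverLimit(rs RateLimitReqState) {
	if rs.IsOwner {
		overLimitCounter.Add(1)
	}
}

func main() {
	prometheus.MustRegister(overLimitCounter)
	recordOverLimit(RateLimitReqState{IsOwner: false}) // non-owner: no increment
	recordOverLimit(RateLimitReqState{IsOwner: true})  // owner: counted once
	fmt.Println("over-limit recorded on the owner only")
}
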
40 changes: 26 additions & 14 deletions algorithms.go
@@ -34,7 +34,7 @@ import (
// with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT`

// Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket
func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket"))
defer tokenBucketTimer.ObserveDuration()

@@ -99,7 +99,7 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
s.Remove(ctx, hashKey)
}

return tokenBucketNewItem(ctx, s, c, r)
return tokenBucketNewItem(ctx, s, c, r, rs)
}

// Update the limit if it changed.
@@ -161,7 +161,9 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
// If we are already at the limit.
if rl.Remaining == 0 && r.Hits > 0 {
trace.SpanFromContext(ctx).AddEvent("Already over the limit")
metricOverLimitCounter.Add(1)
if rs.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
t.Status = rl.Status
return rl, nil
@@ -179,7 +181,9 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
// without updating the cache.
if r.Hits > t.Remaining {
trace.SpanFromContext(ctx).AddEvent("Over the limit")
metricOverLimitCounter.Add(1)
if rs.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) {
// DRAIN_OVER_LIMIT behavior drains the remaining counter.
@@ -195,11 +199,11 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
}

// Item is not found in cache or store, create new.
return tokenBucketNewItem(ctx, s, c, r)
return tokenBucketNewItem(ctx, s, c, r, rs)
}

// Called by tokenBucket() when adding a new item in the store.
func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
requestTime := *r.RequestTime
expire := requestTime + r.Duration

@@ -235,7 +239,9 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
// Client could be requesting that we always return OVER_LIMIT.
if r.Hits > r.Limit {
trace.SpanFromContext(ctx).AddEvent("Over the limit")
metricOverLimitCounter.Add(1)
if rs.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
rl.Remaining = r.Limit
t.Remaining = r.Limit
@@ -251,7 +257,7 @@ func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)
}

// Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket
func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket"))
defer leakyBucketTimer.ObserveDuration()

@@ -308,7 +314,7 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
s.Remove(ctx, hashKey)
}

return leakyBucketNewItem(ctx, s, c, r)
return leakyBucketNewItem(ctx, s, c, r, rs)
}

if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) {
@@ -381,7 +387,9 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *

// If we are already at the limit
if int64(b.Remaining) == 0 && r.Hits > 0 {
metricOverLimitCounter.Add(1)
if rs.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
return rl, nil
}
@@ -397,7 +405,9 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
// If requested is more than available, then return over the limit
// without updating the bucket, unless `DRAIN_OVER_LIMIT` is set.
if r.Hits > int64(b.Remaining) {
metricOverLimitCounter.Add(1)
if rs.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT

// DRAIN_OVER_LIMIT behavior drains the remaining counter.
@@ -420,11 +430,11 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *
return rl, nil
}

return leakyBucketNewItem(ctx, s, c, r)
return leakyBucketNewItem(ctx, s, c, r, rs)
}

// Called by leakyBucket() when adding a new item in the store.
func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) {
func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, rs RateLimitReqState) (resp *RateLimitResp, err error) {
requestTime := *r.RequestTime
duration := r.Duration
rate := float64(duration) / float64(r.Limit)
@@ -457,7 +467,9 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq)

// Client could be requesting that we start with the bucket OVER_LIMIT
if r.Hits > r.Burst {
metricOverLimitCounter.Add(1)
if rs.IsOwner {
metricOverLimitCounter.Add(1)
}
rl.Status = Status_OVER_LIMIT
rl.Remaining = 0
rl.ResetTime = requestTime + (rl.Limit-rl.Remaining)*int64(rate)
12 changes: 6 additions & 6 deletions benchmark_test.go
@@ -34,7 +34,7 @@ func BenchmarkServer(b *testing.B) {
require.NoError(b, err, "Error in conf.SetDefaults")
requestTime := epochMillis(clock.Now())

b.Run("GetPeerRateLimit() with no batching", func(b *testing.B) {
b.Run("GetPeerRateLimit", func(b *testing.B) {
client, err := guber.NewPeerClient(guber.PeerConfig{
Info: cluster.GetRandomPeer(cluster.DataCenterNone),
Behavior: conf.Behaviors,
@@ -46,9 +46,9 @@ func BenchmarkServer(b *testing.B) {

for n := 0; n < b.N; n++ {
_, err := client.GetPeerRateLimit(ctx, &guber.RateLimitReq{
Name: b.Name(),
UniqueKey: guber.RandomString(10),
Behavior: guber.Behavior_NO_BATCHING,
Name: b.Name(),
UniqueKey: guber.RandomString(10),
// Behavior: guber.Behavior_NO_BATCHING,
Limit: 10,
Duration: 5,
Hits: 1,
@@ -60,7 +60,7 @@ func BenchmarkServer(b *testing.B) {
}
})

b.Run("GetRateLimit()", func(b *testing.B) {
b.Run("GetRateLimits batching", func(b *testing.B) {
client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil)
require.NoError(b, err, "Error in guber.DialV1Server")
b.ResetTimer()
@@ -83,7 +83,7 @@ func BenchmarkServer(b *testing.B) {
}
})

b.Run("GetRateLimitGlobal()", func(b *testing.B) {
b.Run("GetRateLimits global", func(b *testing.B) {
client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil)
require.NoError(b, err, "Error in guber.DialV1Server")
b.ResetTimer()
3 changes: 2 additions & 1 deletion global.go
@@ -234,14 +234,15 @@ func (gm *globalManager) runBroadcasts() {
func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) {
defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration()
var req UpdatePeerGlobalsReq
reqState := RateLimitReqState{IsOwner: false}

gm.metricGlobalQueueLength.Set(float64(len(updates)))

for _, update := range updates {
// Get current rate limit state.
grlReq := proto.Clone(update).(*RateLimitReq)
grlReq.Hits = 0
status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq)
status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq, reqState)
if err != nil {
gm.log.WithError(err).Error("while retrieving rate limit status")
continue
28 changes: 18 additions & 10 deletions gubernator.go
@@ -53,6 +53,10 @@ type V1Instance struct {
workerPool *WorkerPool
}

type RateLimitReqState struct {
IsOwner bool
}

var (
metricGetRateLimitCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "gubernator_getratelimit_counter",
@@ -240,9 +244,10 @@ func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*G
}

// If our server instance is the owner of this rate limit
if peer.Info().IsOwner {
reqState := RateLimitReqState{IsOwner: peer.Info().IsOwner}
if reqState.IsOwner {
// Apply our rate limit algorithm to the request
resp.Responses[i], err = s.getLocalRateLimit(ctx, req)
resp.Responses[i], err = s.getLocalRateLimit(ctx, req, reqState)
if err != nil {
err = errors.Wrapf(err, "Error while apply rate limit for '%s'", key)
span := trace.SpanFromContext(ctx)
@@ -313,6 +318,7 @@ func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) {
funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.asyncRequest"))
defer funcTimer.ObserveDuration()

reqState := RateLimitReqState{IsOwner: false}
resp := AsyncResp{
Idx: req.Idx,
}
@@ -332,7 +338,7 @@ func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) {
// If we are attempting again, the owner of this rate limit might have changed to us!
if attempts != 0 {
if req.Peer.Info().IsOwner {
resp.Resp, err = s.getLocalRateLimit(ctx, req.Req)
resp.Resp, err = s.getLocalRateLimit(ctx, req.Req, reqState)
if err != nil {
s.log.WithContext(ctx).
WithError(err).
@@ -399,12 +405,13 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq)
tracing.EndScope(ctx, err)
}()

cpy := proto.Clone(req).(*RateLimitReq)
SetBehavior(&cpy.Behavior, Behavior_NO_BATCHING, true)
SetBehavior(&cpy.Behavior, Behavior_GLOBAL, false)
req2 := proto.Clone(req).(*RateLimitReq)
SetBehavior(&req2.Behavior, Behavior_NO_BATCHING, true)
SetBehavior(&req2.Behavior, Behavior_GLOBAL, false)
reqState := RateLimitReqState{IsOwner: false}

// Process the rate limit like we own it
resp, err = s.getLocalRateLimit(ctx, cpy)
resp, err = s.getLocalRateLimit(ctx, req2, reqState)
if err != nil {
return nil, errors.Wrap(err, "during in getLocalRateLimit")
}
@@ -476,6 +483,7 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits
respChan := make(chan respOut)
var respWg sync.WaitGroup
respWg.Add(1)
reqState := RateLimitReqState{IsOwner: true}

go func() {
// Capture each response and return in the same order
@@ -509,7 +517,7 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits
rin.req.RequestTime = &requestTime
}

rl, err := s.getLocalRateLimit(ctx, rin.req)
rl, err := s.getLocalRateLimit(ctx, rin.req, reqState)
if err != nil {
// Return the error for this request
err = errors.Wrap(err, "Error in getLocalRateLimit")
@@ -577,7 +585,7 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health
return health, nil
}

func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ *RateLimitResp, err error) {
func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, rs RateLimitReqState) (_ *RateLimitResp, err error) {
ctx = tracing.StartNamedScope(ctx, "V1Instance.getLocalRateLimit", trace.WithAttributes(
attribute.String("ratelimit.key", r.UniqueKey),
attribute.String("ratelimit.name", r.Name),
@@ -587,7 +595,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_
defer func() { tracing.EndScope(ctx, err) }()
defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration()

resp, err := s.workerPool.GetRateLimit(ctx, r)
resp, err := s.workerPool.GetRateLimit(ctx, r, rs)
if err != nil {
return nil, errors.Wrap(err, "during workerPool.GetRateLimit")
}
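
The call sites in gubernator.go set the flag once per path: GetPeerRateLimits evaluates as the owner (IsOwner: true), GetRateLimits derives it from peer.Info().IsOwner, and asyncRequest, getGlobalRateLimit and the global broadcast evaluate as non-owner. A simplified sketch of that dispatch decision, with placeholder helpers that are not part of the repository:

package main

import "fmt"

// RateLimitReqState mirrors the struct introduced in gubernator.go.
type RateLimitReqState struct{ IsOwner bool }

// handle sketches the dispatch in V1Instance.GetRateLimits: the owner applies
// the algorithm locally (and its metrics count), while a non-owner forwards
// the request to the owning peer. applyLocally and forwardToOwner are
// illustrative placeholders, not functions from the repository.
func handle(isOwner bool) string {
	reqState := RateLimitReqState{IsOwner: isOwner}
	if reqState.IsOwner {
		return applyLocally(reqState)
	}
	return forwardToOwner()
}

func applyLocally(rs RateLimitReqState) string { return "applied locally (owner)" }
func forwardToOwner() string                   { return "forwarded to the owning peer" }

func main() {
	fmt.Println(handle(true))  // owner path
	fmt.Println(handle(false)) // non-owner path
}
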
7 changes: 4 additions & 3 deletions peer_client.go
@@ -66,9 +66,10 @@ type response {
}

type request struct {
request *RateLimitReq
resp chan *response
ctx context.Context
request *RateLimitReq
reqState RateLimitReqState
resp chan *response
ctx context.Context
}

type PeerConfig struct {
17 changes: 9 additions & 8 deletions workers.go
@@ -199,7 +199,7 @@ func (p *WorkerPool) dispatch(worker *Worker) {
}

resp := new(response)
resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, worker.cache)
resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, req.reqState, worker.cache)
select {
case req.resp <- resp:
// Success.
@@ -258,16 +258,17 @@ func (p *WorkerPool) dispatch(worker *Worker) {
}

// GetRateLimit sends a GetRateLimit request to worker pool.
func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq) (*RateLimitResp, error) {
func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, rs RateLimitReqState) (*RateLimitResp, error) {
// Delegate request to assigned channel based on request key.
worker := p.getWorker(rlRequest.HashKey())
queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name)
queueGauge.Inc()
defer queueGauge.Dec()
handlerRequest := request{
ctx: ctx,
resp: make(chan *response, 1),
request: rlRequest,
ctx: ctx,
resp: make(chan *response, 1),
request: rlRequest,
reqState: rs,
}

// Send request.
@@ -289,14 +290,14 @@ func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq)
}

// Handle request received by worker.
func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, cache Cache) (*RateLimitResp, error) {
func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, rs RateLimitReqState, cache Cache) (*RateLimitResp, error) {
defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration()
var rlResponse *RateLimitResp
var err error

switch req.Algorithm {
case Algorithm_TOKEN_BUCKET:
rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req)
rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, rs)
if err != nil {
msg := "Error in tokenBucket"
countError(err, msg)
@@ -305,7 +306,7 @@ func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq,
}

case Algorithm_LEAKY_BUCKET:
rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req)
rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, rs)
if err != nil {
msg := "Error in leakyBucket"
countError(err, msg)
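
workers.go forwards the new state through the worker pool's request channel so the algorithm running inside a worker can apply the owner-only guard. A self-contained sketch of that channel pattern, with simplified stand-ins for the pool's actual types:

package main

import "fmt"

// RateLimitReqState mirrors the struct introduced in gubernator.go.
type RateLimitReqState struct{ IsOwner bool }

// request mirrors the shape of the worker request in workers.go: the
// ownership state rides alongside the payload so the handler can apply
// the owner-only metrics guard. The fields here are simplified.
type request struct {
	key      string
	reqState RateLimitReqState
	resp     chan string
}

// worker drains the channel the way WorkerPool.dispatch hands requests to
// handleGetRateLimit, passing the state through unchanged.
func worker(in <-chan request) {
	for req := range in {
		role := "non-owner"
		if req.reqState.IsOwner {
			role = "owner"
		}
		req.resp <- fmt.Sprintf("handled %q as %s", req.key, role)
	}
}

func main() {
	in := make(chan request)
	go worker(in)

	resp := make(chan string, 1)
	in <- request{key: "account_1234", reqState: RateLimitReqState{IsOwner: true}, resp: resp}
	fmt.Println(<-resp)
	close(in)
}
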
