From 6f1e32a5ad2d45dbc976f8187994e1ffe94fc308 Mon Sep 17 00:00:00 2001 From: Maria Ines Parnisari Date: Mon, 19 Feb 2024 18:58:08 -0300 Subject: [PATCH 1/2] docs: add docs in `global.go` (#213) * chore: rename for more clarity * undo function renames * another undo of a rename --- docs/architecture.md | 16 ++++++++-------- global.go | 41 +++++++++++++++++++++++------------------ 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/docs/architecture.md b/docs/architecture.md index f716bf2c..01481320 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -45,7 +45,7 @@ apply the new config immediately. ## Global Behavior Since Gubernator rate limits are hashed and handled by a single peer in the -cluster. Rate limits that apply to every request in a data center would result +cluster, rate limits that apply to every request in a data center could result in the rate limit request being handled by a single peer for the entirety of the data center. For example, consider a rate limit with `name=requests_per_datacenter` and a `unique_id=us-east-1`. Now imagine that a @@ -68,7 +68,7 @@ limit status from the owner. #### Side effects of global behavior Since Hits are batched and forwarded to the owning peer asynchronously, the immediate response to the client will not include the most accurate remaining -counts. As that count will only get updated after the async call to the owner +counts, as that count will only get updated after the async call to the owner peer is complete and the owning peer has had time to update all the peers in the cluster. As a result the use of GLOBAL allows for greater scale but at the cost of consistency. @@ -83,18 +83,18 @@ updates before all nodes have the `Hit` updated in their cache. To calculate the WORST case scenario, we total the number of network updates that must occur for each global rate limit. -Count 1 incoming request to the node -Count 1 request when forwarding to the owning node -Count 1 + (number of nodes in cluster) to update all the nodes with the current Hit count. +- Count 1 incoming request to the node +- Count 1 request when forwarding to the owning node +- Count 1 + (number of nodes in cluster) to update all the nodes with the current Hit count. -Remember this is the WORST case, as the node that recieved the request might be -the owning node thus no need to forward to the owner. Additionally we improve +Remember this is the WORST case, as the node that received the request might be +the owning node thus no need to forward to the owner. Additionally, we improve the worst case by having the owning node batch Hits when forwarding to all the nodes in the cluster. Such that 1,000 individual requests of Hit = 1 each for a unique key will result in batched request from the owner to each node with a single Hit = 1,000 update. -Additionally thousands of hits to different unique keys will also be batched +Additionally, thousands of hits to different unique keys will also be batched such that network usage doesn't increase until the number of requests in an update batch exceeds the `BehaviorConfig.GlobalBatchLimit` or when the number of nodes in the cluster increases. (thus more batch updates) When that occurs you diff --git a/global.go b/global.go index 78431960..cd113108 100644 --- a/global.go +++ b/global.go @@ -27,12 +27,12 @@ import ( // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. 
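//
// Rough data flow, as a reader's sketch of the two queues below (this sketch is
// not part of the original comment):
//
//	QueueHit(r)    -> hitsQueue    -> runAsyncHits()  -> sendHits()       // batch hits to the owning peer
//	QueueUpdate(r) -> updatesQueue -> runBroadcasts() -> broadcastPeers() // owner fans updates out to every peer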
type globalManager struct { - asyncQueue chan *RateLimitReq - broadcastQueue chan *RateLimitReq + hitsQueue chan *RateLimitReq + updatesQueue chan *RateLimitReq wg syncutil.WaitGroup conf BehaviorConfig log FieldLogger - instance *V1Instance + instance *V1Instance // todo circular import? V1Instance also holds a reference to globalManager metricGlobalSendDuration prometheus.Summary metricBroadcastDuration prometheus.Summary metricBroadcastCounter *prometheus.CounterVec @@ -41,11 +41,11 @@ type globalManager struct { func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager { gm := globalManager{ - log: instance.log, - asyncQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - broadcastQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - instance: instance, - conf: conf, + log: instance.log, + hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + updatesQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + instance: instance, + conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_global_send_duration", Help: "The duration of GLOBAL async sends in seconds.", @@ -71,15 +71,18 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager } func (gm *globalManager) QueueHit(r *RateLimitReq) { - gm.asyncQueue <- r + gm.hitsQueue <- r } func (gm *globalManager) QueueUpdate(r *RateLimitReq) { - gm.broadcastQueue <- r + gm.updatesQueue <- r } -// runAsyncHits collects async hit requests and queues them to -// be sent to their owning peers. +// runAsyncHits collects async hit requests in a forever loop, +// aggregates them in one request, and sends them to +// the owning peers. +// The updates are sent both when the batch limit is hit +// and in a periodic frequency determined by GlobalSyncWait. func (gm *globalManager) runAsyncHits() { var interval = NewInterval(gm.conf.GlobalSyncWait) hits := make(map[string]*RateLimitReq) @@ -87,7 +90,7 @@ func (gm *globalManager) runAsyncHits() { gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.asyncQueue: + case r := <-gm.hitsQueue: // Aggregate the hits into a single request key := r.HashKey() _, ok := hits[key] @@ -162,7 +165,7 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { if err != nil { gm.log.WithError(err). - Errorf("error sending global hits to '%s'", p.client.Info().GRPCAddress) + Errorf("while sending global hits to '%s'", p.client.Info().GRPCAddress) } return nil }, p) @@ -170,14 +173,17 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { fan.Wait() } -// runBroadcasts collects status changes for global rate limits and broadcasts the changes to each peer in the cluster. +// runBroadcasts collects status changes for global rate limits in a forever loop, +// and broadcasts the changes to each peer in the cluster. +// The updates are sent both when the batch limit is hit +// and in a periodic frequency determined by GlobalSyncWait. 
func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) updates := make(map[string]*RateLimitReq) gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.broadcastQueue: + case r := <-gm.updatesQueue: updates[r.HashKey()] = r // Send the hits if we reached our batch limit @@ -226,10 +232,9 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string] status, err := gm.instance.getLocalRateLimit(ctx, rl) if err != nil { - gm.log.WithError(err).Errorf("while broadcasting update to peers for: '%s'", rl.HashKey()) + gm.log.WithError(err).Errorf("while getting local rate limit for: '%s'", rl.HashKey()) continue } - // Build an UpdatePeerGlobalsReq req.Globals = append(req.Globals, &UpdatePeerGlobal{ Algorithm: rl.Algorithm, Key: rl.HashKey(), From a312ed73014107e1f9d58a04f1c3a156d61bf3c5 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Wed, 21 Feb 2024 07:49:29 -0700 Subject: [PATCH 2/2] Change global behavior (#219) * test: Add test for global rate limiting with load balancing * fix global update behavior * Added findNonOwningPeer() and getClientToNonOwningPeer() * Fix global mode * remove logs and add comment Co-authored-by: Yamil Asusta * fixed global over-consume issue and improved global testing --------- Co-authored-by: Philip Gough Co-authored-by: Yamil Asusta Co-authored-by: Maria Ines Parnisari --- .github/workflows/on-pull-request.yml | 2 +- Makefile | 2 +- algorithms.go | 23 +- cluster/cluster.go | 51 ++- config.go | 2 + daemon.go | 77 +++- functional_test.go | 533 +++++++++++++++++++++++--- global.go | 57 ++- gubernator.go | 51 ++- interval_test.go | 9 +- 10 files changed, 673 insertions(+), 134 deletions(-) diff --git a/.github/workflows/on-pull-request.yml b/.github/workflows/on-pull-request.yml index d89825f7..23854e74 100644 --- a/.github/workflows/on-pull-request.yml +++ b/.github/workflows/on-pull-request.yml @@ -50,7 +50,7 @@ jobs: skip-cache: true - name: Test - run: go test -v -race -p=1 -count=1 + run: go test -v -race -p=1 -count=1 -tags holster_test_mode go-bench: runs-on: ubuntu-latest timeout-minutes: 30 diff --git a/Makefile b/Makefile index 7c77cca8..75240d97 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ lint: $(GOLANGCI_LINT) .PHONY: test test: - (go test -v -race -p=1 -count=1 -coverprofile coverage.out ./...; ret=$$?; \ + (go test -v -race -p=1 -count=1 -tags holster_test_mode -coverprofile coverage.out ./...; ret=$$?; \ go tool cover -func coverage.out; \ go tool cover -html coverage.out -o coverage.html; \ exit $$ret) diff --git a/algorithms.go b/algorithms.go index 1fb8f9dd..f2ed4a82 100644 --- a/algorithms.go +++ b/algorithms.go @@ -26,6 +26,13 @@ import ( "go.opentelemetry.io/otel/trace" ) +// ### NOTE ### +// The both token and leaky follow the same semantic which allows for requests of more than the limit +// to be rejected, but subsequent requests within the same window that are under the limit to succeed. +// IE: client attempts to send 1000 emails but 100 is their limit. The request is rejected as over the +// limit, but we do not set the remainder to 0 in the cache. The client can retry within the same window +// with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` + // Implements token bucket algorithm for rate limiting. 
https://en.wikipedia.org/wiki/Token_bucket func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { @@ -82,12 +89,6 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * ResetTime: 0, }, nil } - - // The following semantic allows for requests of more than the limit to be rejected, but subsequent - // requests within the same duration that are under the limit to succeed. IE: client attempts to - // send 1000 emails but 100 is their limit. The request is rejected as over the limit, but since we - // don't store OVER_LIMIT in the cache the client can retry within the same rate limit duration with - // 100 emails and the request will succeed. t, ok := item.Value.(*TokenBucketItem) if !ok { // Client switched algorithms; perhaps due to a migration? @@ -388,22 +389,24 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * // If requested hits takes the remainder if int64(b.Remaining) == r.Hits { - b.Remaining -= float64(r.Hits) - rl.Remaining = 0 + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } // If requested is more than available, then return over the limit - // without updating the bucket. + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. if r.Hits > int64(b.Remaining) { metricOverLimitCounter.Add(1) rl.Status = Status_OVER_LIMIT + + // DRAIN_OVER_LIMIT behavior drains the remaining counter. if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. b.Remaining = 0 rl.Remaining = 0 } + return rl, nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index 493aa71c..4c18efd6 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -68,6 +68,47 @@ func PeerAt(idx int) gubernator.PeerInfo { return peers[idx] } +// FindOwningPeer finds the peer which owns the rate limit with the provided name and unique key +func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return gubernator.PeerInfo{}, err + } + return p.Info(), nil +} + +// FindOwningDaemon finds the daemon which owns the rate limit with the provided name and unique key +func FindOwningDaemon(name, key string) (*gubernator.Daemon, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return &gubernator.Daemon{}, err + } + + for i, d := range daemons { + if d.PeerInfo.GRPCAddress == p.Info().GRPCAddress { + return daemons[i], nil + } + } + return &gubernator.Daemon{}, errors.New("unable to find owning daemon") +} + +// ListNonOwningDaemons returns a list of daemons in the cluster that do not own the rate limit +// for the name and key provided. 
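+//
+// Illustrative usage, mirroring the functional tests later in this patch
+// (the name and key values are just examples):
+//
+//	peers, err := cluster.ListNonOwningDaemons("test_global", "account:12345")
+//	owner, err := cluster.FindOwningDaemon("test_global", "account:12345")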
+func ListNonOwningDaemons(name, key string) ([]*gubernator.Daemon, error) { + owner, err := FindOwningDaemon(name, key) + if err != nil { + return []*gubernator.Daemon{}, err + } + + var daemons []*gubernator.Daemon + for _, d := range GetDaemons() { + if d.PeerInfo.GRPCAddress != owner.PeerInfo.GRPCAddress { + daemons = append(daemons, d) + } + } + return daemons, nil +} + // DaemonAt returns a specific daemon func DaemonAt(idx int) *gubernator.Daemon { return daemons[idx] @@ -112,6 +153,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) d, err := gubernator.SpawnDaemon(ctx, gubernator.DaemonConfig{ Logger: logrus.WithField("instance", peer.GRPCAddress), + InstanceID: peer.GRPCAddress, GRPCListenAddress: peer.GRPCAddress, HTTPListenAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, @@ -127,12 +169,15 @@ func StartWith(localPeers []gubernator.PeerInfo) error { return errors.Wrapf(err, "while starting server for addr '%s'", peer.GRPCAddress) } - // Add the peers and daemons to the package level variables - peers = append(peers, gubernator.PeerInfo{ + p := gubernator.PeerInfo{ GRPCAddress: d.GRPCListeners[0].Addr().String(), HTTPAddress: d.HTTPListener.Addr().String(), DataCenter: peer.DataCenter, - }) + } + d.PeerInfo = p + + // Add the peers and daemons to the package level variables + peers = append(peers, p) daemons = append(daemons, d) } diff --git a/config.go b/config.go index 122ffa22..19f9f06f 100644 --- a/config.go +++ b/config.go @@ -71,6 +71,8 @@ type BehaviorConfig struct { // Config for a gubernator instance type Config struct { + InstanceID string + // (Required) A list of GRPC servers to register our instance with GRPCServers []*grpc.Server diff --git a/daemon.go b/daemon.go index a220136b..97602075 100644 --- a/daemon.go +++ b/daemon.go @@ -19,6 +19,7 @@ package gubernator import ( "context" "crypto/tls" + "fmt" "log" "net" "net/http" @@ -40,6 +41,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/resolver" "google.golang.org/protobuf/encoding/protojson" ) @@ -47,6 +49,8 @@ type Daemon struct { GRPCListeners []net.Listener HTTPListener net.Listener V1Server *V1Instance + InstanceID string + PeerInfo PeerInfo log FieldLogger pool PoolInterface @@ -59,6 +63,7 @@ type Daemon struct { promRegister *prometheus.Registry gwCancel context.CancelFunc instanceConf Config + client V1Client } // SpawnDaemon starts a new gubernator daemon according to the provided DaemonConfig. 
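//
// Sketch of how the test cluster spawns a daemon (simplified from
// cluster.StartWith above; only a subset of DaemonConfig fields is shown):
//
//	d, err := gubernator.SpawnDaemon(ctx, gubernator.DaemonConfig{
//		Logger:            logrus.WithField("instance", peer.GRPCAddress),
//		InstanceID:        peer.GRPCAddress,
//		GRPCListenAddress: peer.GRPCAddress,
//		HTTPListenAddress: peer.HTTPAddress,
//		DataCenter:        peer.DataCenter,
//	})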
@@ -67,8 +72,9 @@ type Daemon struct { func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { s := &Daemon{ - log: conf.Logger, - conf: conf, + InstanceID: conf.InstanceID, + log: conf.Logger, + conf: conf, } return s, s.Start(ctx) } @@ -77,8 +83,8 @@ func (s *Daemon) Start(ctx context.Context) error { var err error setter.SetDefault(&s.log, logrus.WithFields(logrus.Fields{ - "instance-id": s.conf.InstanceID, - "category": "gubernator", + "instance": s.conf.InstanceID, + "category": "gubernator", })) s.promRegister = prometheus.NewRegistry() @@ -148,6 +154,7 @@ func (s *Daemon) Start(ctx context.Context) error { Behaviors: s.conf.Behaviors, CacheSize: s.conf.CacheSize, Workers: s.conf.Workers, + InstanceID: s.conf.InstanceID, } s.V1Server, err = NewV1Instance(s.instanceConf) @@ -411,6 +418,30 @@ func (s *Daemon) Peers() []PeerInfo { return peers } +func (s *Daemon) MustClient() V1Client { + c, err := s.Client() + if err != nil { + panic(fmt.Sprintf("[%s] failed to init daemon client - '%s'", s.InstanceID, err)) + } + return c +} + +func (s *Daemon) Client() (V1Client, error) { + if s.client != nil { + return s.client, nil + } + + conn, err := grpc.DialContext(context.Background(), + fmt.Sprintf("static:///%s", s.PeerInfo.GRPCAddress), + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + s.client = NewV1Client(conn) + return s.client, nil +} + // WaitForConnect returns nil if the list of addresses is listening // for connections; will block until context is cancelled. func WaitForConnect(ctx context.Context, addresses []string) error { @@ -451,3 +482,41 @@ func WaitForConnect(ctx context.Context, addresses []string) error { } return nil } + +type staticBuilder struct{} + +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + } + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. 
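+//
+// Example of dialing a specific peer through this resolver, mirroring
+// Daemon.Client() above (the address below is illustrative):
+//
+//	conn, err := grpc.DialContext(ctx, "static:///127.0.0.1:9990",
+//		grpc.WithResolvers(newStaticBuilder()),
+//		grpc.WithTransportCredentials(insecure.NewCredentials()))
+//	if err != nil {
+//		return nil, err
+//	}
+//	client := NewV1Client(conn)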
+func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil) diff --git a/functional_test.go b/functional_test.go index 4abd2e25..b377e86a 100644 --- a/functional_test.go +++ b/functional_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "time" guber "github.com/mailgun/gubernator/v2" "github.com/mailgun/gubernator/v2/cluster" @@ -37,14 +38,6 @@ import ( json "google.golang.org/protobuf/encoding/protojson" ) -var algos = []struct { - Name string - Algorithm guber.Algorithm -}{ - {Name: "Token bucket", Algorithm: guber.Algorithm_TOKEN_BUCKET}, - {Name: "Leaky bucket", Algorithm: guber.Algorithm_LEAKY_BUCKET}, -} - // Setup and shutdown the mock gubernator cluster for the entire test suite func TestMain(m *testing.M) { if err := cluster.StartWith([]guber.PeerInfo{ @@ -405,8 +398,8 @@ func TestDrainOverLimit(t *testing.T) { }, } - for idx, algoCase := range algos { - t.Run(algoCase.Name, func(t *testing.T) { + for idx, algoCase := range []guber.Algorithm{guber.Algorithm_TOKEN_BUCKET, guber.Algorithm_LEAKY_BUCKET} { + t.Run(guber.Algorithm_name[int32(algoCase)], func(t *testing.T) { for _, test := range tests { ctx := context.Background() t.Run(test.Name, func(t *testing.T) { @@ -415,7 +408,7 @@ func TestDrainOverLimit(t *testing.T) { { Name: "test_drain_over_limit", UniqueKey: fmt.Sprintf("account:1234:%d", idx), - Algorithm: algoCase.Algorithm, + Algorithm: algoCase, Behavior: guber.Behavior_DRAIN_OVER_LIMIT, Duration: guber.Second * 30, Hits: test.Hits, @@ -437,6 +430,49 @@ func TestDrainOverLimit(t *testing.T) { } } +func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hit int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_token_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hit, + Limit: 2000, + }, + }, + }) + require.NoError(t, err, hit) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the bucket has and the remainder is still 1000. 
+ // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestLeakyBucket(t *testing.T) { defer clock.Freeze(clock.Now()).Unfreeze() @@ -696,7 +732,7 @@ func TestLeakyBucketGregorian(t *testing.T) { Hits: 1, Remaining: 58, Status: guber.Status_UNDER_LIMIT, - Sleep: clock.Second, + Sleep: clock.Millisecond * 1200, }, { Name: "third hit; leak one hit", @@ -706,7 +742,12 @@ func TestLeakyBucketGregorian(t *testing.T) { }, } + // Truncate to the nearest minute now := clock.Now() + now = now.Truncate(1 * time.Minute) + // So we don't start on the minute boundary + now = now.Add(time.Millisecond * 100) + for _, test := range tests { t.Run(test.Name, func(t *testing.T) { resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ @@ -807,6 +848,50 @@ func TestLeakyBucketNegativeHits(t *testing.T) { } } +func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { + // Freeze time so we don't leak during the test + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hits int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_leaky_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hits, + Limit: 2000, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the rate limit has and the remainder is still 1000. 
+ // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestMissingFields(t *testing.T) { client, errs := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) require.Nil(t, errs) @@ -871,12 +956,16 @@ func TestMissingFields(t *testing.T) { } func TestGlobalRateLimits(t *testing.T) { - peer := cluster.PeerAt(0).GRPCAddress - client, errs := guber.DialV1Server(peer, nil) - require.NoError(t, errs) + const ( + name = "test_global" + key = "account:12345" + ) - sendHit := func(status guber.Status, remain int64, i int) string { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*5) + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remain int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ @@ -885,59 +974,346 @@ func TestGlobalRateLimits(t *testing.T) { UniqueKey: "account:12345", Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, - Duration: guber.Second * 3, - Hits: 1, + Duration: guber.Minute * 3, + Hits: hits, Limit: 5, }, }, }) - require.NoError(t, err, i) - assert.Equal(t, "", resp.Responses[0].Error, i) - assert.Equal(t, status, resp.Responses[0].Status, i) - assert.Equal(t, remain, resp.Responses[0].Remaining, i) - assert.Equal(t, int64(5), resp.Responses[0].Limit, i) - - // ensure that we have a canonical host - assert.NotEmpty(t, resp.Responses[0].Metadata["owner"]) - - // name/key should ensure our connected peer is NOT the owner, - // the peer we are connected to should forward requests asynchronously to the owner. 
- assert.NotEqual(t, peer, resp.Responses[0].Metadata["owner"]) - - return resp.Responses[0].Metadata["owner"] + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, int64(5), resp.Responses[0].Limit) } - // Our first hit should create the request on the peer and queue for async forward - sendHit(guber.Status_UNDER_LIMIT, 4, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1, 4) // Our second should be processed as if we own it since the async forward hasn't occurred yet - sendHit(guber.Status_UNDER_LIMIT, 3, 2) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2) testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) { - // Inspect our metrics, ensure they collected the counts we expected during this test - d := cluster.DaemonAt(0) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_global_send_duration_count") + // Inspect peers metrics, ensure the peer sent the global rate limit to the owner + metricsURL := fmt.Sprintf("http://%s/metrics", peers[0].Config().HTTPListenAddress) + m, err := getMetricRequest(metricsURL, "gubernator_global_send_duration_count") + assert.NoError(t, err) assert.Equal(t, 1, int(m.Value)) + }) + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) - // Expect one peer (the owning peer) to indicate a broadcast. - var broadcastCount int - for i := 0; i < cluster.NumOfDaemons(); i++ { - d := cluster.DaemonAt(i) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_broadcast_duration_count") - broadcastCount += int(m.Value) - } + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + + // Check different peers, they should have gotten the broadcast from the owner + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + + // Non owning peer should calculate the rate limit remaining before forwarding + // to the owner. 
+ sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 2, 0) + + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) + + sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) +} + +func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { + const ( + name = "test_global_token_limit" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner + // to respond with OVER_LIMIT. + sendHit(guber.Status_OVER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) + // The status should still be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) +} + +func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { + const ( + name = "test_global_token_limit_leaky" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + // Ask a different peer if the status is over the limit + sendHit(peers[1].MustClient(), guber.Status_OVER_LIMIT, 1) +} + +func TestGlobalRequestMoreThanAvailable(t *testing.T) { + const ( + name = "test_global_more_than_available" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), 
clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Ensure GRPC has connections to each peer before we start, as we want + // the actual test requests to happen quite fast. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 0, 100) + } + + // Send a request for 50 hits from each non owning peer in the cluster. These requests + // will be queued and sent to the owner as accumulated hits. As a result of the async nature + // of `Behavior_GLOBAL` rate limit requests spread across peers like this will be allowed to + // over-consume their resource within the rate limit window until the owner is updated and + // a broadcast to all peers is received. + // + // The maximum number of resources that can be over-consumed can be calculated by multiplying + // the remainder by the number of peers in the cluster. For example: If you have a remainder of 100 + // and a cluster of 10 instances, then the maximum over-consumed resource is 1,000. If you need + // a more accurate remaining calculation, and wish to avoid over consuming a resource, then do + // not use `Behavior_GLOBAL`. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) +} + +func TestGlobalNegativeHits(t *testing.T) { + const ( + name = "test_global_negative_hits" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 100, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, status, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Send a negative hit on a rate limit with no hits + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, -1, 3) + + // Wait for the negative remaining to propagate + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // Send another negative hit to a different peer + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, -1, 4) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Should have 4 in the remainder + sendHit(peers[2].MustClient(), 
guber.Status_UNDER_LIMIT, 4, 0) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+3)) + + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 0, 0) +} + +func TestGlobalResetRemaining(t *testing.T) { + const ( + name = "test_global_reset" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit and remaining should be zero + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) - assert.Equal(t, 1, broadcastCount) + // Now reset the remaining + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL | guber.Behavior_RESET_REMAINING, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, + }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + + // Wait for the reset to propagate. 
+ require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Check a different peer to ensure remaining has been reset + resp, err = peers[1].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + } -func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { +func getMetricRequest(url string, name string) (*model.Sample, error) { resp, err := http.Get(url) - require.NoError(t, err) + if err != nil { + return nil, err + } defer resp.Body.Close() - return getMetric(t, resp.Body, name) + return getMetric(resp.Body, name) } func TestChangeLimit(t *testing.T) { @@ -1174,6 +1550,7 @@ func TestHealthCheck(t *testing.T) { } func TestLeakyBucketDivBug(t *testing.T) { + // Freeze time so we don't leak during the test defer clock.Freeze(clock.Now()).Unfreeze() client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) @@ -1321,7 +1698,7 @@ func TestGetPeerRateLimits(t *testing.T) { // TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil -func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { +func getMetric(in io.Reader, name string) (*model.Sample, error) { dec := expfmt.SampleDecoder{ Dec: expfmt.NewDecoder(in, expfmt.FmtText), Opts: &expfmt.DecodeOptions{ @@ -1336,14 +1713,58 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { if err == io.EOF { break } - assert.NoError(t, err) + if err != nil { + return nil, err + } all = append(all, smpls...) } for _, s := range all { if strings.Contains(s.Metric.String(), name) { - return s + return s, nil + } + } + return nil, nil +} + +// getBroadcastCount returns the current broadcast count for use with waitForBroadcast() +// TODO: Replace this with something else, we can call and reset via HTTP/GRPC calls in gubernator v3 +func getBroadcastCount(d *guber.Daemon) (int, error) { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return 0, err + } + + return int(m.Value), nil +} + +// waitForBroadcast waits until the broadcast count for the daemon passed +// changes to the expected value. Returns an error if the expected value is +// not found before the context is cancelled. +func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return err + } + + // It's possible a broadcast occurred twice if waiting for multiple peer to + // forward updates to the owner. 
+ if int(m.Value) >= expect { + // Give the nodes some time to process the broadcasts + clock.Sleep(clock.Millisecond * 500) + return nil + } + + select { + case <-clock.After(time.Millisecond * 800): + case <-ctx.Done(): + return ctx.Err() } } - return nil } diff --git a/global.go b/global.go index cd113108..adbd8e44 100644 --- a/global.go +++ b/global.go @@ -21,14 +21,13 @@ import ( "github.com/mailgun/holster/v4/syncutil" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/protobuf/proto" ) // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { hitsQueue chan *RateLimitReq - updatesQueue chan *RateLimitReq + broadcastQueue chan *UpdatePeerGlobal wg syncutil.WaitGroup conf BehaviorConfig log FieldLogger @@ -41,11 +40,11 @@ type globalManager struct { func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager { gm := globalManager{ - log: instance.log, - hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - updatesQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - instance: instance, - conf: conf, + log: instance.log, + hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + broadcastQueue: make(chan *UpdatePeerGlobal, conf.GlobalBatchLimit), + instance: instance, + conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_global_send_duration", Help: "The duration of GLOBAL async sends in seconds.", @@ -74,8 +73,12 @@ func (gm *globalManager) QueueHit(r *RateLimitReq) { gm.hitsQueue <- r } -func (gm *globalManager) QueueUpdate(r *RateLimitReq) { - gm.updatesQueue <- r +func (gm *globalManager) QueueUpdate(req *RateLimitReq, resp *RateLimitResp) { + gm.broadcastQueue <- &UpdatePeerGlobal{ + Key: req.HashKey(), + Algorithm: req.Algorithm, + Status: resp, + } } // runAsyncHits collects async hit requests in a forever loop, @@ -95,6 +98,11 @@ func (gm *globalManager) runAsyncHits() { key := r.HashKey() _, ok := hits[key] if ok { + // If any of our hits includes a request to RESET_REMAINING + // ensure the owning peer gets this behavior + if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + SetBehavior(&hits[key].Behavior, Behavior_RESET_REMAINING, true) + } hits[key].Hits += r.Hits } else { hits[key] = r @@ -142,7 +150,6 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { gm.log.WithError(err).Errorf("while getting peer for hash key '%s'", r.HashKey()) continue } - p, ok := peerRequests[peer.Info().GRPCAddress] if ok { p.req.Requests = append(p.req.Requests, r) @@ -179,18 +186,18 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // and in a periodic frequency determined by GlobalSyncWait. 
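//
// Sketch of the updated flow in this commit: the owner no longer recomputes
// the rate limit at broadcast time, because QueueUpdate() above captures the
// already-computed status as an UpdatePeerGlobal:
//
//	getLocalRateLimit() -> QueueUpdate(req, resp) -> broadcastQueue -> runBroadcasts() -> broadcastPeers()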
func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]*RateLimitReq) + updates := make(map[string]*UpdatePeerGlobal) gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.updatesQueue: - updates[r.HashKey()] = r + case updateReq := <-gm.broadcastQueue: + updates[updateReq.Key] = updateReq // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { gm.metricBroadcastCounter.WithLabelValues("queue_full").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) return true } @@ -204,7 +211,7 @@ func (gm *globalManager) runBroadcasts() { if len(updates) != 0 { gm.metricBroadcastCounter.WithLabelValues("timer").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) } else { gm.metricGlobalQueueLength.Set(0) } @@ -216,30 +223,14 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*UpdatePeerGlobal) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() var req UpdatePeerGlobalsReq gm.metricGlobalQueueLength.Set(float64(len(updates))) for _, r := range updates { - // Copy the original since we are removing the GLOBAL behavior - rl := proto.Clone(r).(*RateLimitReq) - // We are only sending the status of the rate limit so, we - // clear the behavior flag, so we don't get queued for update again. - SetBehavior(&rl.Behavior, Behavior_GLOBAL, false) - rl.Hits = 0 - - status, err := gm.instance.getLocalRateLimit(ctx, rl) - if err != nil { - gm.log.WithError(err).Errorf("while getting local rate limit for: '%s'", rl.HashKey()) - continue - } - req.Globals = append(req.Globals, &UpdatePeerGlobal{ - Algorithm: rl.Algorithm, - Key: rl.HashKey(), - Status: status, - }) + req.Globals = append(req.Globals, r) } fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) diff --git a/gubernator.go b/gubernator.go index 59c26eca..58f3f616 100644 --- a/gubernator.go +++ b/gubernator.go @@ -396,23 +396,9 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) - if err != nil { - countError(err, "Error in workerPool.GetCacheItem") - return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") - } - if ok { - // Global rate limits are always stored as RateLimitResp regardless of algorithm - rl, ok := item.Value.(*RateLimitResp) - if ok { - return rl, nil - } - // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and - // our cache still holds the rate limit we created on the first hit. - } - cpy := proto.Clone(req).(*RateLimitReq) - cpy.Behavior = Behavior_NO_BATCHING + SetBehavior(&cpy.Behavior, Behavior_NO_BATCHING, true) + SetBehavior(&cpy.Behavior, Behavior_GLOBAL, false) // Process the rate limit like we own it resp, err = s.getLocalRateLimit(ctx, cpy) @@ -427,13 +413,29 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) // UpdatePeerGlobals updates the local cache with a list of global rate limits. 
This method should only // be called by a peer who is the owner of a global rate limit. func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) { + now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ ExpireAt: g.Status.ResetTime, Algorithm: g.Algorithm, - Value: g.Status, Key: g.Key, } + switch g.Algorithm { + case Algorithm_LEAKY_BUCKET: + item.Value = &LeakyBucketItem{ + Remaining: float64(g.Status.Remaining), + Limit: g.Status.Limit, + Burst: g.Status.Limit, + UpdatedAt: now, + } + case Algorithm_TOKEN_BUCKET: + item.Value = &TokenBucketItem{ + Status: g.Status.Status, + Limit: g.Status.Limit, + Remaining: g.Status.Remaining, + CreatedAt: now, + } + } err := s.workerPool.AddCacheItem(ctx, g.Key, item) if err != nil { return nil, errors.Wrap(err, "Error in workerPool.AddCacheItem") @@ -485,6 +487,15 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits // Extract the propagated context from the metadata in the request prop := propagation.TraceContext{} ctx := prop.Extract(ctx, &MetadataCarrier{Map: rin.req.Metadata}) + + // Forwarded global requests must have DRAIN_OVER_LIMIT set so token and leaky algorithms + // drain the remaining in the event a peer asks for more than is remaining. + // This is needed because with GLOBAL behavior peers will accumulate hits, which could + // result in requesting more hits than is remaining. + if HasBehavior(rin.req.Behavior, Behavior_GLOBAL) { + SetBehavior(&rin.req.Behavior, Behavior_DRAIN_OVER_LIMIT, true) + } + rl, err := s.getLocalRateLimit(ctx, rin.req) if err != nil { // Return the error for this request @@ -569,11 +580,9 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ } metricGetRateLimitCounter.WithLabelValues("local").Inc() - - // If global behavior and owning peer, broadcast update to all peers. - // Assuming that this peer does not own the ratelimit. + // If global behavior, then broadcast update to all peers. if HasBehavior(r.Behavior, Behavior_GLOBAL) { - s.global.QueueUpdate(r) + s.global.QueueUpdate(r, resp) } return resp, nil diff --git a/interval_test.go b/interval_test.go index 89642c3e..68c8b40d 100644 --- a/interval_test.go +++ b/interval_test.go @@ -18,9 +18,8 @@ package gubernator_test import ( "testing" - "time" - gubernator "github.com/mailgun/gubernator/v2" + "github.com/mailgun/gubernator/v2" "github.com/mailgun/holster/v4/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,18 +27,18 @@ import ( func TestInterval(t *testing.T) { t.Run("Happy path", func(t *testing.T) { - interval := gubernator.NewInterval(10 * time.Millisecond) + interval := gubernator.NewInterval(10 * clock.Millisecond) defer interval.Stop() interval.Next() assert.Empty(t, interval.C) - time.Sleep(10 * time.Millisecond) + clock.Sleep(10 * clock.Millisecond) // Wait for tick. select { case <-interval.C: - case <-time.After(100 * time.Millisecond): + case <-clock.After(100 * clock.Millisecond): require.Fail(t, "timeout") } })