From de3526954f409549153f46171123872f6dc41b16 Mon Sep 17 00:00:00 2001 From: Philip Gough Date: Tue, 12 Dec 2023 14:33:40 +0000 Subject: [PATCH 1/6] test: Add test for global rate limiting with load balancing --- functional_test.go | 104 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/functional_test.go b/functional_test.go index 4abd2e25..35235b70 100644 --- a/functional_test.go +++ b/functional_test.go @@ -25,6 +25,7 @@ import ( "os" "strings" "testing" + "time" guber "github.com/mailgun/gubernator/v2" "github.com/mailgun/gubernator/v2/cluster" @@ -34,6 +35,10 @@ import ( "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/resolver" json "google.golang.org/protobuf/encoding/protojson" ) @@ -933,6 +938,62 @@ func TestGlobalRateLimits(t *testing.T) { }) } +func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { + owner := cluster.PeerAt(2).GRPCAddress + peer := cluster.PeerAt(0).GRPCAddress + assert.NotEqual(t, owner, peer) + + dialOpts := []grpc.DialOption{ + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), + } + + address := fmt.Sprintf("static:///%s,%s", owner, peer) + conn, err := grpc.DialContext(context.Background(), address, dialOpts...) + require.NoError(t, err) + + client := guber.NewV1Client(conn) + + sendHit := func(status guber.Status, assertion func(resp *guber.RateLimitResp), i int) string { + ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_global", + UniqueKey: "account:12345", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + require.NoError(t, err, i) + gotResp := resp.Responses[0] + assert.Equal(t, "", gotResp.GetError(), i) + assert.Equal(t, status, gotResp.GetStatus(), i) + + if assertion != nil { + assertion(gotResp) + } + + return gotResp.GetMetadata()["owner"] + } + + // Send two hits that should be processed by the owner and the peer and deplete the limit + sendHit(guber.Status_UNDER_LIMIT, nil, 1) + sendHit(guber.Status_UNDER_LIMIT, nil, 2) + // sleep to ensure the async forward has occurred and state should be shared + time.Sleep(time.Second * 5) + + for i := 0; i < 10; i++ { + sendHit(guber.Status_OVER_LIMIT, nil, i+2) + } +} + func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { resp, err := http.Get(url) require.NoError(t, err) @@ -1347,3 +1408,46 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { } return nil } + +// staticBuilder implements the `resolver.Builder` interface. +type staticBuilder struct{} + +func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + + } + r, err := newStaticResolver(cc, resolverAddrs) + if err != nil { + return nil, err + } + return r, nil +} + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func newStaticResolver(cc resolver.ClientConn, addresses []resolver.Address) (resolver.Resolver, error) { + err := cc.UpdateState(resolver.State{Addresses: addresses}) + if err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} From 8715cb79ffc304d3871f4e110f13313e7f5f72cf Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Tue, 23 Jan 2024 15:17:58 -0600 Subject: [PATCH 2/6] fix global update behavior --- functional_test.go | 85 ++++++++++++++++++++++++++++++++++++---------- global.go | 51 +++++++++++----------------- gubernator.go | 22 +++++++++--- 3 files changed, 104 insertions(+), 54 deletions(-) diff --git a/functional_test.go b/functional_test.go index 35235b70..aad2aab1 100644 --- a/functional_test.go +++ b/functional_test.go @@ -938,7 +938,7 @@ func TestGlobalRateLimits(t *testing.T) { }) } -func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { +func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { owner := cluster.PeerAt(2).GRPCAddress peer := cluster.PeerAt(0).GRPCAddress assert.NotEqual(t, owner, peer) @@ -946,24 +946,23 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { dialOpts := []grpc.DialOption{ grpc.WithResolvers(newStaticBuilder()), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), } - address := fmt.Sprintf("static:///%s,%s", owner, peer) + address := fmt.Sprintf("static:///%s", peer) conn, err := grpc.DialContext(context.Background(), address, dialOpts...) require.NoError(t, err) client := guber.NewV1Client(conn) - sendHit := func(status guber.Status, assertion func(resp *guber.RateLimitResp), i int) string { + sendHit := func(status guber.Status, i int) string { ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { - Name: "test_global", + Name: "test_global_token_limit", UniqueKey: "account:12345", - Algorithm: guber.Algorithm_LEAKY_BUCKET, + Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, Hits: 1, @@ -976,22 +975,74 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { assert.Equal(t, "", gotResp.GetError(), i) assert.Equal(t, status, gotResp.GetStatus(), i) - if assertion != nil { - assertion(gotResp) - } - return gotResp.GetMetadata()["owner"] } - // Send two hits that should be processed by the owner and the peer and deplete the limit - sendHit(guber.Status_UNDER_LIMIT, nil, 1) - sendHit(guber.Status_UNDER_LIMIT, nil, 2) - // sleep to ensure the async forward has occurred and state should be shared - time.Sleep(time.Second * 5) + // Send two hits that should be processed by the owner and the peer and deplete the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner + // to respond with OVER_LIMIT. + sendHit(guber.Status_OVER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // The status should still be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) +} - for i := 0; i < 10; i++ { - sendHit(guber.Status_OVER_LIMIT, nil, i+2) +func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { + owner := cluster.PeerAt(2).GRPCAddress + peer := cluster.PeerAt(0).GRPCAddress + assert.NotEqual(t, owner, peer) + + dialOpts := []grpc.DialOption{ + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials()), } + + address := fmt.Sprintf("static:///%s", peer) + conn, err := grpc.DialContext(context.Background(), address, dialOpts...) + require.NoError(t, err) + + client := guber.NewV1Client(conn) + + sendHit := func(status guber.Status, i int) string { + ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_global_leaky_limit", + UniqueKey: "account:12345", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 5, + Hits: 1, + Limit: 2, + }, + }, + }) + require.NoError(t, err, i) + gotResp := resp.Responses[0] + assert.Equal(t, "", gotResp.GetError(), i) + assert.Equal(t, status, gotResp.GetStatus(), i) + + return gotResp.GetMetadata()["owner"] + } + + // Send two hits that should be processed by the owner and the peer and deplete the remaining + sendHit(guber.Status_UNDER_LIMIT, 1) + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // Since the peer must wait for the owner to say it's over the limit, this will return under the limit. + sendHit(guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peer + time.Sleep(time.Second * 3) + // The status should now be OVER_LIMIT + sendHit(guber.Status_OVER_LIMIT, 0) } func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { diff --git a/global.go b/global.go index cd113108..fc6c7983 100644 --- a/global.go +++ b/global.go @@ -21,14 +21,13 @@ import ( "github.com/mailgun/holster/v4/syncutil" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/protobuf/proto" ) // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { hitsQueue chan *RateLimitReq - updatesQueue chan *RateLimitReq + broadcastQueue chan *UpdatePeerGlobal wg syncutil.WaitGroup conf BehaviorConfig log FieldLogger @@ -41,11 +40,11 @@ type globalManager struct { func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager { gm := globalManager{ - log: instance.log, - hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - updatesQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - instance: instance, - conf: conf, + log: instance.log, + hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + broadcastQueue: make(chan *UpdatePeerGlobal, conf.GlobalBatchLimit), + instance: instance, + conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_global_send_duration", Help: "The duration of GLOBAL async sends in seconds.", @@ -74,8 +73,12 @@ func (gm *globalManager) QueueHit(r *RateLimitReq) { gm.hitsQueue <- r } -func (gm *globalManager) QueueUpdate(r *RateLimitReq) { - gm.updatesQueue <- r +func (gm *globalManager) QueueUpdate(req *RateLimitReq, resp *RateLimitResp) { + gm.broadcastQueue <- &UpdatePeerGlobal{ + Key: req.HashKey(), + Algorithm: req.Algorithm, + Status: resp, + } } // runAsyncHits collects async hit requests in a forever loop, @@ -179,18 +182,18 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // and in a periodic frequency determined by GlobalSyncWait. func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]*RateLimitReq) + updates := make(map[string]*UpdatePeerGlobal) gm.wg.Until(func(done chan struct{}) bool { select { - case r := <-gm.updatesQueue: - updates[r.HashKey()] = r + case updateReq := <-gm.broadcastQueue: + updates[updateReq.Key] = updateReq // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { gm.metricBroadcastCounter.WithLabelValues("queue_full").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) return true } @@ -204,7 +207,7 @@ func (gm *globalManager) runBroadcasts() { if len(updates) != 0 { gm.metricBroadcastCounter.WithLabelValues("timer").Inc() gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*UpdatePeerGlobal) } else { gm.metricGlobalQueueLength.Set(0) } @@ -216,30 +219,14 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*UpdatePeerGlobal) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() var req UpdatePeerGlobalsReq gm.metricGlobalQueueLength.Set(float64(len(updates))) for _, r := range updates { - // Copy the original since we are removing the GLOBAL behavior - rl := proto.Clone(r).(*RateLimitReq) - // We are only sending the status of the rate limit so, we - // clear the behavior flag, so we don't get queued for update again. - SetBehavior(&rl.Behavior, Behavior_GLOBAL, false) - rl.Hits = 0 - - status, err := gm.instance.getLocalRateLimit(ctx, rl) - if err != nil { - gm.log.WithError(err).Errorf("while getting local rate limit for: '%s'", rl.HashKey()) - continue - } - req.Globals = append(req.Globals, &UpdatePeerGlobal{ - Algorithm: rl.Algorithm, - Key: rl.HashKey(), - Status: status, - }) + req.Globals = append(req.Globals, r) } fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) diff --git a/gubernator.go b/gubernator.go index 59c26eca..0d094a7c 100644 --- a/gubernator.go +++ b/gubernator.go @@ -405,9 +405,23 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) // Global rate limits are always stored as RateLimitResp regardless of algorithm rl, ok := item.Value.(*RateLimitResp) if ok { + // In the case we are not the owner, global behavior dictates that we respond with + // what ever the owner has broadcast to use as the response. However, in the case + // of TOKEN_BUCKET it makes little sense to wait for the owner to respond with OVER_LIMIT + // if we already know the remainder is 0. So we check for a remainder of 0 here and set + // OVER_LIMIT only if there are actual hits and this is not a RESET_REMAINING request and + // it's a TOKEN_BUCKET. + // + // We cannot preform this for LEAKY_BUCKET as we don't know how much time or what other requests + // might have influenced the leak rate at the owning peer. + // (Maybe we should preform the leak calculation here?????) + if rl.Remaining == 0 && req.Hits > 0 && !HasBehavior(req.Behavior, Behavior_RESET_REMAINING) && + req.Algorithm == Algorithm_TOKEN_BUCKET { + rl.Status = Status_OVER_LIMIT + } return rl, nil } - // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and + // We get here if the owning node hasn't asynchronously forwarded its updates to us yet and // our cache still holds the rate limit we created on the first hit. } @@ -569,11 +583,9 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq) (_ } metricGetRateLimitCounter.WithLabelValues("local").Inc() - - // If global behavior and owning peer, broadcast update to all peers. - // Assuming that this peer does not own the ratelimit. + // If global behavior, then broadcast update to all peers. if HasBehavior(r.Behavior, Behavior_GLOBAL) { - s.global.QueueUpdate(r) + s.global.QueueUpdate(r, resp) } return resp, nil From 1e7bf81f1f2333a5c9b33f2b54a9abeaade510d4 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Wed, 24 Jan 2024 11:12:01 -0600 Subject: [PATCH 3/6] Added findNonOwningPeer() and getClientToNonOwningPeer() --- .github/workflows/on-pull-request.yml | 2 +- Makefile | 2 +- cluster/cluster.go | 9 ++ functional_test.go | 130 ++++++++++++++------------ 4 files changed, 82 insertions(+), 61 deletions(-) diff --git a/.github/workflows/on-pull-request.yml b/.github/workflows/on-pull-request.yml index d89825f7..23854e74 100644 --- a/.github/workflows/on-pull-request.yml +++ b/.github/workflows/on-pull-request.yml @@ -50,7 +50,7 @@ jobs: skip-cache: true - name: Test - run: go test -v -race -p=1 -count=1 + run: go test -v -race -p=1 -count=1 -tags holster_test_mode go-bench: runs-on: ubuntu-latest timeout-minutes: 30 diff --git a/Makefile b/Makefile index 7c77cca8..75240d97 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ lint: $(GOLANGCI_LINT) .PHONY: test test: - (go test -v -race -p=1 -count=1 -coverprofile coverage.out ./...; ret=$$?; \ + (go test -v -race -p=1 -count=1 -tags holster_test_mode -coverprofile coverage.out ./...; ret=$$?; \ go tool cover -func coverage.out; \ go tool cover -html coverage.out -o coverage.html; \ exit $$ret) diff --git a/cluster/cluster.go b/cluster/cluster.go index 493aa71c..bacdde30 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -68,6 +68,15 @@ func PeerAt(idx int) gubernator.PeerInfo { return peers[idx] } +// FindOwningPeer finds the peer which owns the rate limit with the provided name and unique key +func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return gubernator.PeerInfo{}, err + } + return p.Info(), nil +} + // DaemonAt returns a specific daemon func DaemonAt(idx int) *gubernator.Daemon { return daemons[idx] diff --git a/functional_test.go b/functional_test.go index aad2aab1..ca2cbade 100644 --- a/functional_test.go +++ b/functional_test.go @@ -939,29 +939,23 @@ func TestGlobalRateLimits(t *testing.T) { } func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { - owner := cluster.PeerAt(2).GRPCAddress - peer := cluster.PeerAt(0).GRPCAddress - assert.NotEqual(t, owner, peer) - - dialOpts := []grpc.DialOption{ - grpc.WithResolvers(newStaticBuilder()), - grpc.WithTransportCredentials(insecure.NewCredentials()), - } + const ( + name = "test_global_token_limit" + key = "account:12345" + ) - address := fmt.Sprintf("static:///%s", peer) - conn, err := grpc.DialContext(context.Background(), address, dialOpts...) + // Make a connection to a peer in the cluster which does not own this rate limit + client, err := getClientToNonOwningPeer(name, key) require.NoError(t, err) - client := guber.NewV1Client(conn) - - sendHit := func(status guber.Status, i int) string { + sendHit := func(expectedStatus guber.Status, hits int) { ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { - Name: "test_global_token_limit", - UniqueKey: "account:12345", + Name: name, + UniqueKey: key, Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, @@ -970,15 +964,12 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { }, }, }) - require.NoError(t, err, i) - gotResp := resp.Responses[0] - assert.Equal(t, "", gotResp.GetError(), i) - assert.Equal(t, status, gotResp.GetStatus(), i) - - return gotResp.GetMetadata()["owner"] + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } - // Send two hits that should be processed by the owner and the peer and deplete the remaining + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining sendHit(guber.Status_UNDER_LIMIT, 1) sendHit(guber.Status_UNDER_LIMIT, 1) // Wait for the broadcast from the owner to the peer @@ -993,29 +984,23 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { } func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { - owner := cluster.PeerAt(2).GRPCAddress - peer := cluster.PeerAt(0).GRPCAddress - assert.NotEqual(t, owner, peer) - - dialOpts := []grpc.DialOption{ - grpc.WithResolvers(newStaticBuilder()), - grpc.WithTransportCredentials(insecure.NewCredentials()), - } + const ( + name = "test_global_token_limit_leaky" + key = "account:12345" + ) - address := fmt.Sprintf("static:///%s", peer) - conn, err := grpc.DialContext(context.Background(), address, dialOpts...) + // Make a connection to a peer in the cluster which does not own this rate limit + client, err := getClientToNonOwningPeer(name, key) require.NoError(t, err) - client := guber.NewV1Client(conn) - - sendHit := func(status guber.Status, i int) string { + sendHit := func(expectedStatus guber.Status, hits int) { ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { - Name: "test_global_leaky_limit", - UniqueKey: "account:12345", + Name: name, + UniqueKey: key, Algorithm: guber.Algorithm_LEAKY_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, @@ -1024,15 +1009,12 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { }, }, }) - require.NoError(t, err, i) - gotResp := resp.Responses[0] - assert.Equal(t, "", gotResp.GetError(), i) - assert.Equal(t, status, gotResp.GetStatus(), i) - - return gotResp.GetMetadata()["owner"] + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } - // Send two hits that should be processed by the owner and the peer and deplete the remaining + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining sendHit(guber.Status_UNDER_LIMIT, 1) sendHit(guber.Status_UNDER_LIMIT, 1) // Wait for the broadcast from the owner to the peer @@ -1460,11 +1442,12 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { return nil } -// staticBuilder implements the `resolver.Builder` interface. type staticBuilder struct{} -func newStaticBuilder() resolver.Builder { - return &staticBuilder{} +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" } func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { @@ -1474,31 +1457,60 @@ func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ Addr: address, ServerName: address, }) - } - r, err := newStaticResolver(cc, resolverAddrs) - if err != nil { + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { return nil, err } - return r, nil + return &staticResolver{cc: cc}, nil } -func (sb *staticBuilder) Scheme() string { - return "static" +// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. +func newStaticBuilder() resolver.Builder { + return &staticBuilder{} } type staticResolver struct { cc resolver.ClientConn } -func newStaticResolver(cc resolver.ClientConn, addresses []resolver.Address) (resolver.Resolver, error) { - err := cc.UpdateState(resolver.State{Addresses: addresses}) +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil) + +// findNonOwningPeer returns peer info for a peer in the cluster which does not +// own the rate limit for the name and key provided. +func findNonOwningPeer(name, key string) (guber.PeerInfo, error) { + owner, err := cluster.FindOwningPeer(name, key) if err != nil { - return nil, err + return guber.PeerInfo{}, err } - return &staticResolver{cc: cc}, nil + + for _, p := range cluster.GetPeers() { + if p.HashKey() != owner.HashKey() { + return p, nil + } + } + return guber.PeerInfo{}, fmt.Errorf("unable to find non-owning peer in '%d' node cluster", + len(cluster.GetPeers())) } -func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} +// getClientToNonOwningPeer returns a connection to a peer in the cluster which does not own +// the rate limit for the name and key provided. +func getClientToNonOwningPeer(name, key string) (guber.V1Client, error) { + p, err := findNonOwningPeer(name, key) + if err != nil { + return nil, err + } + conn, err := grpc.DialContext(context.Background(), + fmt.Sprintf("static:///%s", p.GRPCAddress), + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + return guber.NewV1Client(conn), nil -func (sr *staticResolver) Close() {} +} From 8b8a4b928a3c3a6733ea67fcff7c5e0e0d2c4b4d Mon Sep 17 00:00:00 2001 From: Yamil Asusta Date: Mon, 29 Jan 2024 17:41:35 -0400 Subject: [PATCH 4/6] Fix global mode --- algorithms.go | 9 +++---- functional_test.go | 9 +------ global.go | 2 ++ gubernator.go | 60 ++++++++++++++++++++++++---------------------- 4 files changed, 40 insertions(+), 40 deletions(-) diff --git a/algorithms.go b/algorithms.go index 1fb8f9dd..9b6d8325 100644 --- a/algorithms.go +++ b/algorithms.go @@ -388,16 +388,17 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * // If requested hits takes the remainder if int64(b.Remaining) == r.Hits { - b.Remaining -= float64(r.Hits) - rl.Remaining = 0 + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) rl.ResetTime = now + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - // If requested is more than available, then return over the limit - // without updating the bucket. + // If requested is more than available, drain bucket in order to converge as everything is returning OVER_LIMIT. if r.Hits > int64(b.Remaining) { metricOverLimitCounter.Add(1) + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) rl.Status = Status_OVER_LIMIT if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { // DRAIN_OVER_LIMIT behavior drains the remaining counter. diff --git a/functional_test.go b/functional_test.go index ca2cbade..066590af 100644 --- a/functional_test.go +++ b/functional_test.go @@ -1014,17 +1014,10 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } - // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining sendHit(guber.Status_UNDER_LIMIT, 1) sendHit(guber.Status_UNDER_LIMIT, 1) - // Wait for the broadcast from the owner to the peer time.Sleep(time.Second * 3) - // Since the peer must wait for the owner to say it's over the limit, this will return under the limit. - sendHit(guber.Status_UNDER_LIMIT, 1) - // Wait for the broadcast from the owner to the peer - time.Sleep(time.Second * 3) - // The status should now be OVER_LIMIT - sendHit(guber.Status_OVER_LIMIT, 0) + sendHit(guber.Status_OVER_LIMIT, 1) } func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { diff --git a/global.go b/global.go index fc6c7983..2f44cb8e 100644 --- a/global.go +++ b/global.go @@ -163,6 +163,7 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { fan.Run(func(in interface{}) error { p := in.(*pair) ctx, cancel := context.WithTimeout(context.Background(), gm.conf.GlobalTimeout) + gm.log.Infof("calling owner of key: %s with hits: %d", p.req.Requests[0].UniqueKey, p.req.Requests[0].Hits) _, err := p.client.GetPeerRateLimits(ctx, &p.req) cancel() @@ -239,6 +240,7 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string] fan.Run(func(in interface{}) error { peer := in.(*PeerClient) ctx, cancel := context.WithTimeout(ctx, gm.conf.GlobalTimeout) + gm.log.Infof("calling peer of key: %s with hits: %d", req.Globals[0].Key, req.Globals[0].Status.Remaining) _, err := peer.UpdatePeerGlobals(ctx, &req) cancel() diff --git a/gubernator.go b/gubernator.go index 0d094a7c..06653f60 100644 --- a/gubernator.go +++ b/gubernator.go @@ -396,35 +396,23 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) - if err != nil { - countError(err, "Error in workerPool.GetCacheItem") - return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") - } - if ok { - // Global rate limits are always stored as RateLimitResp regardless of algorithm - rl, ok := item.Value.(*RateLimitResp) + /* + item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) + if err != nil { + countError(err, "Error in workerPool.GetCacheItem") + return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") + } + if ok { - // In the case we are not the owner, global behavior dictates that we respond with - // what ever the owner has broadcast to use as the response. However, in the case - // of TOKEN_BUCKET it makes little sense to wait for the owner to respond with OVER_LIMIT - // if we already know the remainder is 0. So we check for a remainder of 0 here and set - // OVER_LIMIT only if there are actual hits and this is not a RESET_REMAINING request and - // it's a TOKEN_BUCKET. - // - // We cannot preform this for LEAKY_BUCKET as we don't know how much time or what other requests - // might have influenced the leak rate at the owning peer. - // (Maybe we should preform the leak calculation here?????) - if rl.Remaining == 0 && req.Hits > 0 && !HasBehavior(req.Behavior, Behavior_RESET_REMAINING) && - req.Algorithm == Algorithm_TOKEN_BUCKET { - rl.Status = Status_OVER_LIMIT + // Global rate limits are always stored as RateLimitResp regardless of algorithm + rl, ok := item.Value.(*RateLimitResp) + if ok { + return rl, nil } - return rl, nil + // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and + // our cache still holds the rate limit we created on the first hit. } - // We get here if the owning node hasn't asynchronously forwarded its updates to us yet and - // our cache still holds the rate limit we created on the first hit. - } - + */ cpy := proto.Clone(req).(*RateLimitReq) cpy.Behavior = Behavior_NO_BATCHING @@ -441,13 +429,29 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) // UpdatePeerGlobals updates the local cache with a list of global rate limits. This method should only // be called by a peer who is the owner of a global rate limit. func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) { + now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ - ExpireAt: g.Status.ResetTime, + ExpireAt: g.Status.ResetTime + 100000, Algorithm: g.Algorithm, - Value: g.Status, Key: g.Key, } + switch g.Algorithm { + case Algorithm_LEAKY_BUCKET: + item.Value = &LeakyBucketItem{ + Remaining: float64(g.Status.Remaining), + Limit: g.Status.Limit, + Burst: g.Status.Limit, + UpdatedAt: now, + } + case Algorithm_TOKEN_BUCKET: + item.Value = &TokenBucketItem{ + Status: g.Status.Status, + Limit: g.Status.Limit, + Remaining: g.Status.Remaining, + CreatedAt: now, + } + } err := s.workerPool.AddCacheItem(ctx, g.Key, item) if err != nil { return nil, errors.Wrap(err, "Error in workerPool.AddCacheItem") From b0502dee7f86bc8080d3564a58d103bae6fe5acd Mon Sep 17 00:00:00 2001 From: Maria Ines Parnisari Date: Thu, 8 Feb 2024 12:07:07 -0300 Subject: [PATCH 5/6] remove logs and add comment Co-authored-by: Yamil Asusta --- global.go | 2 -- gubernator.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/global.go b/global.go index 2f44cb8e..fc6c7983 100644 --- a/global.go +++ b/global.go @@ -163,7 +163,6 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { fan.Run(func(in interface{}) error { p := in.(*pair) ctx, cancel := context.WithTimeout(context.Background(), gm.conf.GlobalTimeout) - gm.log.Infof("calling owner of key: %s with hits: %d", p.req.Requests[0].UniqueKey, p.req.Requests[0].Hits) _, err := p.client.GetPeerRateLimits(ctx, &p.req) cancel() @@ -240,7 +239,6 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string] fan.Run(func(in interface{}) error { peer := in.(*PeerClient) ctx, cancel := context.WithTimeout(ctx, gm.conf.GlobalTimeout) - gm.log.Infof("calling peer of key: %s with hits: %d", req.Globals[0].Key, req.Globals[0].Status.Remaining) _, err := peer.UpdatePeerGlobals(ctx, &req) cancel() diff --git a/gubernator.go b/gubernator.go index 06653f60..89e875fb 100644 --- a/gubernator.go +++ b/gubernator.go @@ -432,7 +432,7 @@ func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobals now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ - ExpireAt: g.Status.ResetTime + 100000, + ExpireAt: g.Status.ResetTime + 1000, // account for clock drift from owner where `ResetTime` might already be less than current time of the local machine. Algorithm: g.Algorithm, Key: g.Key, } From d96b6751b696084e445c582945ec80a3a9c3af1e Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Fri, 9 Feb 2024 20:44:24 -0600 Subject: [PATCH 6/6] fixed global over-consume issue and improved global testing --- algorithms.go | 22 +- cluster/cluster.go | 42 +++- config.go | 2 + daemon.go | 77 ++++++- functional_test.go | 541 +++++++++++++++++++++++++++++++++------------ global.go | 6 +- gubernator.go | 31 +-- interval_test.go | 9 +- 8 files changed, 548 insertions(+), 182 deletions(-) diff --git a/algorithms.go b/algorithms.go index 9b6d8325..f2ed4a82 100644 --- a/algorithms.go +++ b/algorithms.go @@ -26,6 +26,13 @@ import ( "go.opentelemetry.io/otel/trace" ) +// ### NOTE ### +// The both token and leaky follow the same semantic which allows for requests of more than the limit +// to be rejected, but subsequent requests within the same window that are under the limit to succeed. +// IE: client attempts to send 1000 emails but 100 is their limit. The request is rejected as over the +// limit, but we do not set the remainder to 0 in the cache. The client can retry within the same window +// with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` + // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { @@ -82,12 +89,6 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * ResetTime: 0, }, nil } - - // The following semantic allows for requests of more than the limit to be rejected, but subsequent - // requests within the same duration that are under the limit to succeed. IE: client attempts to - // send 1000 emails but 100 is their limit. The request is rejected as over the limit, but since we - // don't store OVER_LIMIT in the cache the client can retry within the same rate limit duration with - // 100 emails and the request will succeed. t, ok := item.Value.(*TokenBucketItem) if !ok { // Client switched algorithms; perhaps due to a migration? @@ -394,17 +395,18 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * return rl, nil } - // If requested is more than available, drain bucket in order to converge as everything is returning OVER_LIMIT. + // If requested is more than available, then return over the limit + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. if r.Hits > int64(b.Remaining) { metricOverLimitCounter.Add(1) - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) rl.Status = Status_OVER_LIMIT + + // DRAIN_OVER_LIMIT behavior drains the remaining counter. if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. b.Remaining = 0 rl.Remaining = 0 } + return rl, nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index bacdde30..4c18efd6 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -77,6 +77,38 @@ func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { return p.Info(), nil } +// FindOwningDaemon finds the daemon which owns the rate limit with the provided name and unique key +func FindOwningDaemon(name, key string) (*gubernator.Daemon, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return &gubernator.Daemon{}, err + } + + for i, d := range daemons { + if d.PeerInfo.GRPCAddress == p.Info().GRPCAddress { + return daemons[i], nil + } + } + return &gubernator.Daemon{}, errors.New("unable to find owning daemon") +} + +// ListNonOwningDaemons returns a list of daemons in the cluster that do not own the rate limit +// for the name and key provided. +func ListNonOwningDaemons(name, key string) ([]*gubernator.Daemon, error) { + owner, err := FindOwningDaemon(name, key) + if err != nil { + return []*gubernator.Daemon{}, err + } + + var daemons []*gubernator.Daemon + for _, d := range GetDaemons() { + if d.PeerInfo.GRPCAddress != owner.PeerInfo.GRPCAddress { + daemons = append(daemons, d) + } + } + return daemons, nil +} + // DaemonAt returns a specific daemon func DaemonAt(idx int) *gubernator.Daemon { return daemons[idx] @@ -121,6 +153,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) d, err := gubernator.SpawnDaemon(ctx, gubernator.DaemonConfig{ Logger: logrus.WithField("instance", peer.GRPCAddress), + InstanceID: peer.GRPCAddress, GRPCListenAddress: peer.GRPCAddress, HTTPListenAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, @@ -136,12 +169,15 @@ func StartWith(localPeers []gubernator.PeerInfo) error { return errors.Wrapf(err, "while starting server for addr '%s'", peer.GRPCAddress) } - // Add the peers and daemons to the package level variables - peers = append(peers, gubernator.PeerInfo{ + p := gubernator.PeerInfo{ GRPCAddress: d.GRPCListeners[0].Addr().String(), HTTPAddress: d.HTTPListener.Addr().String(), DataCenter: peer.DataCenter, - }) + } + d.PeerInfo = p + + // Add the peers and daemons to the package level variables + peers = append(peers, p) daemons = append(daemons, d) } diff --git a/config.go b/config.go index 122ffa22..19f9f06f 100644 --- a/config.go +++ b/config.go @@ -71,6 +71,8 @@ type BehaviorConfig struct { // Config for a gubernator instance type Config struct { + InstanceID string + // (Required) A list of GRPC servers to register our instance with GRPCServers []*grpc.Server diff --git a/daemon.go b/daemon.go index a220136b..97602075 100644 --- a/daemon.go +++ b/daemon.go @@ -19,6 +19,7 @@ package gubernator import ( "context" "crypto/tls" + "fmt" "log" "net" "net/http" @@ -40,6 +41,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/resolver" "google.golang.org/protobuf/encoding/protojson" ) @@ -47,6 +49,8 @@ type Daemon struct { GRPCListeners []net.Listener HTTPListener net.Listener V1Server *V1Instance + InstanceID string + PeerInfo PeerInfo log FieldLogger pool PoolInterface @@ -59,6 +63,7 @@ type Daemon struct { promRegister *prometheus.Registry gwCancel context.CancelFunc instanceConf Config + client V1Client } // SpawnDaemon starts a new gubernator daemon according to the provided DaemonConfig. @@ -67,8 +72,9 @@ type Daemon struct { func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { s := &Daemon{ - log: conf.Logger, - conf: conf, + InstanceID: conf.InstanceID, + log: conf.Logger, + conf: conf, } return s, s.Start(ctx) } @@ -77,8 +83,8 @@ func (s *Daemon) Start(ctx context.Context) error { var err error setter.SetDefault(&s.log, logrus.WithFields(logrus.Fields{ - "instance-id": s.conf.InstanceID, - "category": "gubernator", + "instance": s.conf.InstanceID, + "category": "gubernator", })) s.promRegister = prometheus.NewRegistry() @@ -148,6 +154,7 @@ func (s *Daemon) Start(ctx context.Context) error { Behaviors: s.conf.Behaviors, CacheSize: s.conf.CacheSize, Workers: s.conf.Workers, + InstanceID: s.conf.InstanceID, } s.V1Server, err = NewV1Instance(s.instanceConf) @@ -411,6 +418,30 @@ func (s *Daemon) Peers() []PeerInfo { return peers } +func (s *Daemon) MustClient() V1Client { + c, err := s.Client() + if err != nil { + panic(fmt.Sprintf("[%s] failed to init daemon client - '%s'", s.InstanceID, err)) + } + return c +} + +func (s *Daemon) Client() (V1Client, error) { + if s.client != nil { + return s.client, nil + } + + conn, err := grpc.DialContext(context.Background(), + fmt.Sprintf("static:///%s", s.PeerInfo.GRPCAddress), + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + s.client = NewV1Client(conn) + return s.client, nil +} + // WaitForConnect returns nil if the list of addresses is listening // for connections; will block until context is cancelled. func WaitForConnect(ctx context.Context, addresses []string) error { @@ -451,3 +482,41 @@ func WaitForConnect(ctx context.Context, addresses []string) error { } return nil } + +type staticBuilder struct{} + +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + } + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. +func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil) diff --git a/functional_test.go b/functional_test.go index 066590af..b377e86a 100644 --- a/functional_test.go +++ b/functional_test.go @@ -35,21 +35,9 @@ import ( "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/resolver" json "google.golang.org/protobuf/encoding/protojson" ) -var algos = []struct { - Name string - Algorithm guber.Algorithm -}{ - {Name: "Token bucket", Algorithm: guber.Algorithm_TOKEN_BUCKET}, - {Name: "Leaky bucket", Algorithm: guber.Algorithm_LEAKY_BUCKET}, -} - // Setup and shutdown the mock gubernator cluster for the entire test suite func TestMain(m *testing.M) { if err := cluster.StartWith([]guber.PeerInfo{ @@ -410,8 +398,8 @@ func TestDrainOverLimit(t *testing.T) { }, } - for idx, algoCase := range algos { - t.Run(algoCase.Name, func(t *testing.T) { + for idx, algoCase := range []guber.Algorithm{guber.Algorithm_TOKEN_BUCKET, guber.Algorithm_LEAKY_BUCKET} { + t.Run(guber.Algorithm_name[int32(algoCase)], func(t *testing.T) { for _, test := range tests { ctx := context.Background() t.Run(test.Name, func(t *testing.T) { @@ -420,7 +408,7 @@ func TestDrainOverLimit(t *testing.T) { { Name: "test_drain_over_limit", UniqueKey: fmt.Sprintf("account:1234:%d", idx), - Algorithm: algoCase.Algorithm, + Algorithm: algoCase, Behavior: guber.Behavior_DRAIN_OVER_LIMIT, Duration: guber.Second * 30, Hits: test.Hits, @@ -442,6 +430,49 @@ func TestDrainOverLimit(t *testing.T) { } } +func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hit int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_token_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hit, + Limit: 2000, + }, + }, + }) + require.NoError(t, err, hit) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the bucket has and the remainder is still 1000. + // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestLeakyBucket(t *testing.T) { defer clock.Freeze(clock.Now()).Unfreeze() @@ -701,7 +732,7 @@ func TestLeakyBucketGregorian(t *testing.T) { Hits: 1, Remaining: 58, Status: guber.Status_UNDER_LIMIT, - Sleep: clock.Second, + Sleep: clock.Millisecond * 1200, }, { Name: "third hit; leak one hit", @@ -711,7 +742,12 @@ func TestLeakyBucketGregorian(t *testing.T) { }, } + // Truncate to the nearest minute now := clock.Now() + now = now.Truncate(1 * time.Minute) + // So we don't start on the minute boundary + now = now.Add(time.Millisecond * 100) + for _, test := range tests { t.Run(test.Name, func(t *testing.T) { resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ @@ -812,6 +848,50 @@ func TestLeakyBucketNegativeHits(t *testing.T) { } } +func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { + // Freeze time so we don't leak during the test + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hits int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_leaky_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hits, + Limit: 2000, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the rate limit has and the remainder is still 1000. + // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestMissingFields(t *testing.T) { client, errs := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) require.Nil(t, errs) @@ -876,12 +956,16 @@ func TestMissingFields(t *testing.T) { } func TestGlobalRateLimits(t *testing.T) { - peer := cluster.PeerAt(0).GRPCAddress - client, errs := guber.DialV1Server(peer, nil) - require.NoError(t, errs) + const ( + name = "test_global" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) - sendHit := func(status guber.Status, remain int64, i int) string { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*5) + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remain int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ @@ -890,52 +974,47 @@ func TestGlobalRateLimits(t *testing.T) { UniqueKey: "account:12345", Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, - Duration: guber.Second * 3, - Hits: 1, + Duration: guber.Minute * 3, + Hits: hits, Limit: 5, }, }, }) - require.NoError(t, err, i) - assert.Equal(t, "", resp.Responses[0].Error, i) - assert.Equal(t, status, resp.Responses[0].Status, i) - assert.Equal(t, remain, resp.Responses[0].Remaining, i) - assert.Equal(t, int64(5), resp.Responses[0].Limit, i) - - // ensure that we have a canonical host - assert.NotEmpty(t, resp.Responses[0].Metadata["owner"]) - - // name/key should ensure our connected peer is NOT the owner, - // the peer we are connected to should forward requests asynchronously to the owner. - assert.NotEqual(t, peer, resp.Responses[0].Metadata["owner"]) - - return resp.Responses[0].Metadata["owner"] + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, int64(5), resp.Responses[0].Limit) } - // Our first hit should create the request on the peer and queue for async forward - sendHit(guber.Status_UNDER_LIMIT, 4, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1, 4) // Our second should be processed as if we own it since the async forward hasn't occurred yet - sendHit(guber.Status_UNDER_LIMIT, 3, 2) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2) testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) { - // Inspect our metrics, ensure they collected the counts we expected during this test - d := cluster.DaemonAt(0) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_global_send_duration_count") + // Inspect peers metrics, ensure the peer sent the global rate limit to the owner + metricsURL := fmt.Sprintf("http://%s/metrics", peers[0].Config().HTTPListenAddress) + m, err := getMetricRequest(metricsURL, "gubernator_global_send_duration_count") + assert.NoError(t, err) assert.Equal(t, 1, int(m.Value)) + }) + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) - // Expect one peer (the owning peer) to indicate a broadcast. - var broadcastCount int - for i := 0; i < cluster.NumOfDaemons(); i++ { - d := cluster.DaemonAt(i) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_broadcast_duration_count") - broadcastCount += int(m.Value) - } + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) - assert.Equal(t, 1, broadcastCount) - }) + // Check different peers, they should have gotten the broadcast from the owner + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + + // Non owning peer should calculate the rate limit remaining before forwarding + // to the owner. + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 2, 0) + + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) + + sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) } func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { @@ -944,14 +1023,13 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { key = "account:12345" ) - // Make a connection to a peer in the cluster which does not own this rate limit - client, err := getClientToNonOwningPeer(name, key) + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(expectedStatus guber.Status, hits int) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + sendHit := func(expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { Name: name, @@ -959,7 +1037,7 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, - Hits: 1, + Hits: hits, Limit: 2, }, }, @@ -968,17 +1046,19 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining sendHit(guber.Status_UNDER_LIMIT, 1) sendHit(guber.Status_UNDER_LIMIT, 1) // Wait for the broadcast from the owner to the peer - time.Sleep(time.Second * 3) + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner // to respond with OVER_LIMIT. sendHit(guber.Status_OVER_LIMIT, 1) // Wait for the broadcast from the owner to the peer - time.Sleep(time.Second * 3) + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) // The status should still be OVER_LIMIT sendHit(guber.Status_OVER_LIMIT, 0) } @@ -989,12 +1069,11 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { key = "account:12345" ) - // Make a connection to a peer in the cluster which does not own this rate limit - client, err := getClientToNonOwningPeer(name, key) + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(expectedStatus guber.Status, hits int) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ @@ -1004,7 +1083,7 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { Algorithm: guber.Algorithm_LEAKY_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, - Hits: 1, + Hits: hits, Limit: 2, }, }, @@ -1013,18 +1092,228 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) - sendHit(guber.Status_UNDER_LIMIT, 1) - sendHit(guber.Status_UNDER_LIMIT, 1) - time.Sleep(time.Second * 3) - sendHit(guber.Status_OVER_LIMIT, 1) + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + // Ask a different peer if the status is over the limit + sendHit(peers[1].MustClient(), guber.Status_OVER_LIMIT, 1) } -func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { - resp, err := http.Get(url) +func TestGlobalRequestMoreThanAvailable(t *testing.T) { + const ( + name = "test_global_more_than_available" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Ensure GRPC has connections to each peer before we start, as we want + // the actual test requests to happen quite fast. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 0, 100) + } + + // Send a request for 50 hits from each non owning peer in the cluster. These requests + // will be queued and sent to the owner as accumulated hits. As a result of the async nature + // of `Behavior_GLOBAL` rate limit requests spread across peers like this will be allowed to + // over-consume their resource within the rate limit window until the owner is updated and + // a broadcast to all peers is received. + // + // The maximum number of resources that can be over-consumed can be calculated by multiplying + // the remainder by the number of peers in the cluster. For example: If you have a remainder of 100 + // and a cluster of 10 instances, then the maximum over-consumed resource is 1,000. If you need + // a more accurate remaining calculation, and wish to avoid over consuming a resource, then do + // not use `Behavior_GLOBAL`. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) +} + +func TestGlobalNegativeHits(t *testing.T) { + const ( + name = "test_global_negative_hits" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 100, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, status, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Send a negative hit on a rate limit with no hits + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, -1, 3) + + // Wait for the negative remaining to propagate + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // Send another negative hit to a different peer + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, -1, 4) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Should have 4 in the remainder + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 4, 0) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+3)) + + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 0, 0) +} + +func TestGlobalResetRemaining(t *testing.T) { + const ( + name = "test_global_reset" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit and remaining should be zero + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) + + // Now reset the remaining + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL | guber.Behavior_RESET_REMAINING, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, + }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + + // Wait for the reset to propagate. + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Check a different peer to ensure remaining has been reset + resp, err = peers[1].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, + }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + +} + +func getMetricRequest(url string, name string) (*model.Sample, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } defer resp.Body.Close() - return getMetric(t, resp.Body, name) + return getMetric(resp.Body, name) } func TestChangeLimit(t *testing.T) { @@ -1261,6 +1550,7 @@ func TestHealthCheck(t *testing.T) { } func TestLeakyBucketDivBug(t *testing.T) { + // Freeze time so we don't leak during the test defer clock.Freeze(clock.Now()).Unfreeze() client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) @@ -1408,7 +1698,7 @@ func TestGetPeerRateLimits(t *testing.T) { // TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil -func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { +func getMetric(in io.Reader, name string) (*model.Sample, error) { dec := expfmt.SampleDecoder{ Dec: expfmt.NewDecoder(in, expfmt.FmtText), Opts: &expfmt.DecodeOptions{ @@ -1423,87 +1713,58 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { if err == io.EOF { break } - assert.NoError(t, err) + if err != nil { + return nil, err + } all = append(all, smpls...) } for _, s := range all { if strings.Contains(s.Metric.String(), name) { - return s + return s, nil } } - return nil + return nil, nil } -type staticBuilder struct{} - -var _ resolver.Builder = (*staticBuilder)(nil) - -func (sb *staticBuilder) Scheme() string { - return "static" -} - -func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { - var resolverAddrs []resolver.Address - for _, address := range strings.Split(target.Endpoint(), ",") { - resolverAddrs = append(resolverAddrs, resolver.Address{ - Addr: address, - ServerName: address, - }) - } - if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { - return nil, err +// getBroadcastCount returns the current broadcast count for use with waitForBroadcast() +// TODO: Replace this with something else, we can call and reset via HTTP/GRPC calls in gubernator v3 +func getBroadcastCount(d *guber.Daemon) (int, error) { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return 0, err } - return &staticResolver{cc: cc}, nil -} - -// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC -// to connect a specific peer in the cluster. -func newStaticBuilder() resolver.Builder { - return &staticBuilder{} -} -type staticResolver struct { - cc resolver.ClientConn + return int(m.Value), nil } -func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} - -func (sr *staticResolver) Close() {} - -var _ resolver.Resolver = (*staticResolver)(nil) +// waitForBroadcast waits until the broadcast count for the daemon passed +// changes to the expected value. Returns an error if the expected value is +// not found before the context is cancelled. +func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() -// findNonOwningPeer returns peer info for a peer in the cluster which does not -// own the rate limit for the name and key provided. -func findNonOwningPeer(name, key string) (guber.PeerInfo, error) { - owner, err := cluster.FindOwningPeer(name, key) - if err != nil { - return guber.PeerInfo{}, err - } + for { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return err + } - for _, p := range cluster.GetPeers() { - if p.HashKey() != owner.HashKey() { - return p, nil + // It's possible a broadcast occurred twice if waiting for multiple peer to + // forward updates to the owner. + if int(m.Value) >= expect { + // Give the nodes some time to process the broadcasts + clock.Sleep(clock.Millisecond * 500) + return nil } - } - return guber.PeerInfo{}, fmt.Errorf("unable to find non-owning peer in '%d' node cluster", - len(cluster.GetPeers())) -} -// getClientToNonOwningPeer returns a connection to a peer in the cluster which does not own -// the rate limit for the name and key provided. -func getClientToNonOwningPeer(name, key string) (guber.V1Client, error) { - p, err := findNonOwningPeer(name, key) - if err != nil { - return nil, err - } - conn, err := grpc.DialContext(context.Background(), - fmt.Sprintf("static:///%s", p.GRPCAddress), - grpc.WithResolvers(newStaticBuilder()), - grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - return nil, err + select { + case <-clock.After(time.Millisecond * 800): + case <-ctx.Done(): + return ctx.Err() + } } - return guber.NewV1Client(conn), nil - } diff --git a/global.go b/global.go index fc6c7983..adbd8e44 100644 --- a/global.go +++ b/global.go @@ -98,6 +98,11 @@ func (gm *globalManager) runAsyncHits() { key := r.HashKey() _, ok := hits[key] if ok { + // If any of our hits includes a request to RESET_REMAINING + // ensure the owning peer gets this behavior + if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + SetBehavior(&hits[key].Behavior, Behavior_RESET_REMAINING, true) + } hits[key].Hits += r.Hits } else { hits[key] = r @@ -145,7 +150,6 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { gm.log.WithError(err).Errorf("while getting peer for hash key '%s'", r.HashKey()) continue } - p, ok := peerRequests[peer.Info().GRPCAddress] if ok { p.req.Requests = append(p.req.Requests, r) diff --git a/gubernator.go b/gubernator.go index 89e875fb..58f3f616 100644 --- a/gubernator.go +++ b/gubernator.go @@ -396,25 +396,9 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - /* - item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) - if err != nil { - countError(err, "Error in workerPool.GetCacheItem") - return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") - } - - if ok { - // Global rate limits are always stored as RateLimitResp regardless of algorithm - rl, ok := item.Value.(*RateLimitResp) - if ok { - return rl, nil - } - // We get here if the owning node hasn't asynchronously forwarded it's updates to us yet and - // our cache still holds the rate limit we created on the first hit. - } - */ cpy := proto.Clone(req).(*RateLimitReq) - cpy.Behavior = Behavior_NO_BATCHING + SetBehavior(&cpy.Behavior, Behavior_NO_BATCHING, true) + SetBehavior(&cpy.Behavior, Behavior_GLOBAL, false) // Process the rate limit like we own it resp, err = s.getLocalRateLimit(ctx, cpy) @@ -432,7 +416,7 @@ func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobals now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ - ExpireAt: g.Status.ResetTime + 1000, // account for clock drift from owner where `ResetTime` might already be less than current time of the local machine. + ExpireAt: g.Status.ResetTime, Algorithm: g.Algorithm, Key: g.Key, } @@ -503,6 +487,15 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits // Extract the propagated context from the metadata in the request prop := propagation.TraceContext{} ctx := prop.Extract(ctx, &MetadataCarrier{Map: rin.req.Metadata}) + + // Forwarded global requests must have DRAIN_OVER_LIMIT set so token and leaky algorithms + // drain the remaining in the event a peer asks for more than is remaining. + // This is needed because with GLOBAL behavior peers will accumulate hits, which could + // result in requesting more hits than is remaining. + if HasBehavior(rin.req.Behavior, Behavior_GLOBAL) { + SetBehavior(&rin.req.Behavior, Behavior_DRAIN_OVER_LIMIT, true) + } + rl, err := s.getLocalRateLimit(ctx, rin.req) if err != nil { // Return the error for this request diff --git a/interval_test.go b/interval_test.go index 89642c3e..68c8b40d 100644 --- a/interval_test.go +++ b/interval_test.go @@ -18,9 +18,8 @@ package gubernator_test import ( "testing" - "time" - gubernator "github.com/mailgun/gubernator/v2" + "github.com/mailgun/gubernator/v2" "github.com/mailgun/holster/v4/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,18 +27,18 @@ import ( func TestInterval(t *testing.T) { t.Run("Happy path", func(t *testing.T) { - interval := gubernator.NewInterval(10 * time.Millisecond) + interval := gubernator.NewInterval(10 * clock.Millisecond) defer interval.Stop() interval.Next() assert.Empty(t, interval.C) - time.Sleep(10 * time.Millisecond) + clock.Sleep(10 * clock.Millisecond) // Wait for tick. select { case <-interval.C: - case <-time.After(100 * time.Millisecond): + case <-clock.After(100 * clock.Millisecond): require.Fail(t, "timeout") } })