From d96b6751b696084e445c582945ec80a3a9c3af1e Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Fri, 9 Feb 2024 20:44:24 -0600 Subject: [PATCH] fixed global over-consume issue and improved global testing --- algorithms.go | 22 +- cluster/cluster.go | 42 +++- config.go | 2 + daemon.go | 77 ++++++- functional_test.go | 541 +++++++++++++++++++++++++++++++++------------ global.go | 6 +- gubernator.go | 31 +-- interval_test.go | 9 +- 8 files changed, 548 insertions(+), 182 deletions(-) diff --git a/algorithms.go b/algorithms.go index 9b6d8325..f2ed4a82 100644 --- a/algorithms.go +++ b/algorithms.go @@ -26,6 +26,13 @@ import ( "go.opentelemetry.io/otel/trace" ) +// ### NOTE ### +// The both token and leaky follow the same semantic which allows for requests of more than the limit +// to be rejected, but subsequent requests within the same window that are under the limit to succeed. +// IE: client attempts to send 1000 emails but 100 is their limit. The request is rejected as over the +// limit, but we do not set the remainder to 0 in the cache. The client can retry within the same window +// with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` + // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp *RateLimitResp, err error) { @@ -82,12 +89,6 @@ func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * ResetTime: 0, }, nil } - - // The following semantic allows for requests of more than the limit to be rejected, but subsequent - // requests within the same duration that are under the limit to succeed. IE: client attempts to - // send 1000 emails but 100 is their limit. 
The request is rejected as over the limit, but since we - // don't store OVER_LIMIT in the cache the client can retry within the same rate limit duration with - // 100 emails and the request will succeed. t, ok := item.Value.(*TokenBucketItem) if !ok { // Client switched algorithms; perhaps due to a migration? @@ -394,17 +395,18 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq) (resp * return rl, nil } - // If requested is more than available, drain bucket in order to converge as everything is returning OVER_LIMIT. + // If requested is more than available, then return over the limit + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. if r.Hits > int64(b.Remaining) { metricOverLimitCounter.Add(1) - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) rl.Status = Status_OVER_LIMIT + + // DRAIN_OVER_LIMIT behavior drains the remaining counter. if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. b.Remaining = 0 rl.Remaining = 0 } + return rl, nil } diff --git a/cluster/cluster.go b/cluster/cluster.go index bacdde30..4c18efd6 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -77,6 +77,38 @@ func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { return p.Info(), nil } +// FindOwningDaemon finds the daemon which owns the rate limit with the provided name and unique key +func FindOwningDaemon(name, key string) (*gubernator.Daemon, error) { + p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + if err != nil { + return &gubernator.Daemon{}, err + } + + for i, d := range daemons { + if d.PeerInfo.GRPCAddress == p.Info().GRPCAddress { + return daemons[i], nil + } + } + return &gubernator.Daemon{}, errors.New("unable to find owning daemon") +} + +// ListNonOwningDaemons returns a list of daemons in the cluster that do not own the rate limit +// for the name and key provided. 
+func ListNonOwningDaemons(name, key string) ([]*gubernator.Daemon, error) { + owner, err := FindOwningDaemon(name, key) + if err != nil { + return []*gubernator.Daemon{}, err + } + + var daemons []*gubernator.Daemon + for _, d := range GetDaemons() { + if d.PeerInfo.GRPCAddress != owner.PeerInfo.GRPCAddress { + daemons = append(daemons, d) + } + } + return daemons, nil +} + // DaemonAt returns a specific daemon func DaemonAt(idx int) *gubernator.Daemon { return daemons[idx] @@ -121,6 +153,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) d, err := gubernator.SpawnDaemon(ctx, gubernator.DaemonConfig{ Logger: logrus.WithField("instance", peer.GRPCAddress), + InstanceID: peer.GRPCAddress, GRPCListenAddress: peer.GRPCAddress, HTTPListenAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, @@ -136,12 +169,15 @@ func StartWith(localPeers []gubernator.PeerInfo) error { return errors.Wrapf(err, "while starting server for addr '%s'", peer.GRPCAddress) } - // Add the peers and daemons to the package level variables - peers = append(peers, gubernator.PeerInfo{ + p := gubernator.PeerInfo{ GRPCAddress: d.GRPCListeners[0].Addr().String(), HTTPAddress: d.HTTPListener.Addr().String(), DataCenter: peer.DataCenter, - }) + } + d.PeerInfo = p + + // Add the peers and daemons to the package level variables + peers = append(peers, p) daemons = append(daemons, d) } diff --git a/config.go b/config.go index 122ffa22..19f9f06f 100644 --- a/config.go +++ b/config.go @@ -71,6 +71,8 @@ type BehaviorConfig struct { // Config for a gubernator instance type Config struct { + InstanceID string + // (Required) A list of GRPC servers to register our instance with GRPCServers []*grpc.Server diff --git a/daemon.go b/daemon.go index a220136b..97602075 100644 --- a/daemon.go +++ b/daemon.go @@ -19,6 +19,7 @@ package gubernator import ( "context" "crypto/tls" + "fmt" "log" "net" "net/http" @@ -40,6 +41,7 @@ import ( 
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/resolver" "google.golang.org/protobuf/encoding/protojson" ) @@ -47,6 +49,8 @@ type Daemon struct { GRPCListeners []net.Listener HTTPListener net.Listener V1Server *V1Instance + InstanceID string + PeerInfo PeerInfo log FieldLogger pool PoolInterface @@ -59,6 +63,7 @@ type Daemon struct { promRegister *prometheus.Registry gwCancel context.CancelFunc instanceConf Config + client V1Client } // SpawnDaemon starts a new gubernator daemon according to the provided DaemonConfig. @@ -67,8 +72,9 @@ type Daemon struct { func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { s := &Daemon{ - log: conf.Logger, - conf: conf, + InstanceID: conf.InstanceID, + log: conf.Logger, + conf: conf, } return s, s.Start(ctx) } @@ -77,8 +83,8 @@ func (s *Daemon) Start(ctx context.Context) error { var err error setter.SetDefault(&s.log, logrus.WithFields(logrus.Fields{ - "instance-id": s.conf.InstanceID, - "category": "gubernator", + "instance": s.conf.InstanceID, + "category": "gubernator", })) s.promRegister = prometheus.NewRegistry() @@ -148,6 +154,7 @@ func (s *Daemon) Start(ctx context.Context) error { Behaviors: s.conf.Behaviors, CacheSize: s.conf.CacheSize, Workers: s.conf.Workers, + InstanceID: s.conf.InstanceID, } s.V1Server, err = NewV1Instance(s.instanceConf) @@ -411,6 +418,30 @@ func (s *Daemon) Peers() []PeerInfo { return peers } +func (s *Daemon) MustClient() V1Client { + c, err := s.Client() + if err != nil { + panic(fmt.Sprintf("[%s] failed to init daemon client - '%s'", s.InstanceID, err)) + } + return c +} + +func (s *Daemon) Client() (V1Client, error) { + if s.client != nil { + return s.client, nil + } + + conn, err := grpc.DialContext(context.Background(), + fmt.Sprintf("static:///%s", s.PeerInfo.GRPCAddress), + grpc.WithResolvers(newStaticBuilder()), + grpc.WithTransportCredentials(insecure.NewCredentials())) 
+ if err != nil { + return nil, err + } + s.client = NewV1Client(conn) + return s.client, nil +} + // WaitForConnect returns nil if the list of addresses is listening // for connections; will block until context is cancelled. func WaitForConnect(ctx context.Context, addresses []string) error { @@ -451,3 +482,41 @@ func WaitForConnect(ctx context.Context, addresses []string) error { } return nil } + +type staticBuilder struct{} + +var _ resolver.Builder = (*staticBuilder)(nil) + +func (sb *staticBuilder) Scheme() string { + return "static" +} + +func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { + var resolverAddrs []resolver.Address + for _, address := range strings.Split(target.Endpoint(), ",") { + resolverAddrs = append(resolverAddrs, resolver.Address{ + Addr: address, + ServerName: address, + }) + } + if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { + return nil, err + } + return &staticResolver{cc: cc}, nil +} + +// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC +// to connect a specific peer in the cluster. 
+func newStaticBuilder() resolver.Builder { + return &staticBuilder{} +} + +type staticResolver struct { + cc resolver.ClientConn +} + +func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} + +func (sr *staticResolver) Close() {} + +var _ resolver.Resolver = (*staticResolver)(nil) diff --git a/functional_test.go b/functional_test.go index 066590af..b377e86a 100644 --- a/functional_test.go +++ b/functional_test.go @@ -35,21 +35,9 @@ import ( "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/resolver" json "google.golang.org/protobuf/encoding/protojson" ) -var algos = []struct { - Name string - Algorithm guber.Algorithm -}{ - {Name: "Token bucket", Algorithm: guber.Algorithm_TOKEN_BUCKET}, - {Name: "Leaky bucket", Algorithm: guber.Algorithm_LEAKY_BUCKET}, -} - // Setup and shutdown the mock gubernator cluster for the entire test suite func TestMain(m *testing.M) { if err := cluster.StartWith([]guber.PeerInfo{ @@ -410,8 +398,8 @@ func TestDrainOverLimit(t *testing.T) { }, } - for idx, algoCase := range algos { - t.Run(algoCase.Name, func(t *testing.T) { + for idx, algoCase := range []guber.Algorithm{guber.Algorithm_TOKEN_BUCKET, guber.Algorithm_LEAKY_BUCKET} { + t.Run(guber.Algorithm_name[int32(algoCase)], func(t *testing.T) { for _, test := range tests { ctx := context.Background() t.Run(test.Name, func(t *testing.T) { @@ -420,7 +408,7 @@ func TestDrainOverLimit(t *testing.T) { { Name: "test_drain_over_limit", UniqueKey: fmt.Sprintf("account:1234:%d", idx), - Algorithm: algoCase.Algorithm, + Algorithm: algoCase, Behavior: guber.Behavior_DRAIN_OVER_LIMIT, Duration: guber.Second * 30, Hits: test.Hits, @@ -442,6 +430,49 @@ func TestDrainOverLimit(t *testing.T) { } } +func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := 
guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hit int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_token_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hit, + Limit: 2000, + }, + }, + }) + require.NoError(t, err, hit) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the bucket has and the remainder is still 1000. 
+ // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestLeakyBucket(t *testing.T) { defer clock.Freeze(clock.Now()).Unfreeze() @@ -701,7 +732,7 @@ func TestLeakyBucketGregorian(t *testing.T) { Hits: 1, Remaining: 58, Status: guber.Status_UNDER_LIMIT, - Sleep: clock.Second, + Sleep: clock.Millisecond * 1200, }, { Name: "third hit; leak one hit", @@ -711,7 +742,12 @@ func TestLeakyBucketGregorian(t *testing.T) { }, } + // Truncate to the nearest minute now := clock.Now() + now = now.Truncate(1 * time.Minute) + // So we don't start on the minute boundary + now = now.Add(time.Millisecond * 100) + for _, test := range tests { t.Run(test.Name, func(t *testing.T) { resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ @@ -812,6 +848,50 @@ func TestLeakyBucketNegativeHits(t *testing.T) { } } +func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { + // Freeze time so we don't leak during the test + defer clock.Freeze(clock.Now()).Unfreeze() + + client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + require.NoError(t, err) + + sendHit := func(status guber.Status, remain int64, hits int64) *guber.RateLimitResp { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: "test_leaky_more_than_available", + UniqueKey: "account:123456", + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Duration: guber.Millisecond * 1000, + Hits: hits, + Limit: 2000, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, status, 
resp.Responses[0].Status) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, int64(2000), resp.Responses[0].Limit) + return resp.Responses[0] + } + + // Use half of the bucket + sendHit(guber.Status_UNDER_LIMIT, 1000, 1000) + + // Ask for more than the rate limit has and the remainder is still 1000. + // See NOTE in algorithms.go + sendHit(guber.Status_OVER_LIMIT, 1000, 1500) + + // Now other clients can ask for some of the remaining until we hit our limit + sendHit(guber.Status_UNDER_LIMIT, 500, 500) + sendHit(guber.Status_UNDER_LIMIT, 100, 400) + sendHit(guber.Status_UNDER_LIMIT, 0, 100) + sendHit(guber.Status_OVER_LIMIT, 0, 1) +} + func TestMissingFields(t *testing.T) { client, errs := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) require.Nil(t, errs) @@ -876,12 +956,16 @@ func TestMissingFields(t *testing.T) { } func TestGlobalRateLimits(t *testing.T) { - peer := cluster.PeerAt(0).GRPCAddress - client, errs := guber.DialV1Server(peer, nil) - require.NoError(t, errs) + const ( + name = "test_global" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) - sendHit := func(status guber.Status, remain int64, i int) string { - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*5) + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remain int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ @@ -890,52 +974,47 @@ func TestGlobalRateLimits(t *testing.T) { UniqueKey: "account:12345", Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, - Duration: guber.Second * 3, - Hits: 1, + Duration: guber.Minute * 3, + Hits: hits, Limit: 5, }, }, }) - require.NoError(t, err, i) - assert.Equal(t, "", resp.Responses[0].Error, i) - assert.Equal(t, status, 
resp.Responses[0].Status, i) - assert.Equal(t, remain, resp.Responses[0].Remaining, i) - assert.Equal(t, int64(5), resp.Responses[0].Limit, i) - - // ensure that we have a canonical host - assert.NotEmpty(t, resp.Responses[0].Metadata["owner"]) - - // name/key should ensure our connected peer is NOT the owner, - // the peer we are connected to should forward requests asynchronously to the owner. - assert.NotEqual(t, peer, resp.Responses[0].Metadata["owner"]) - - return resp.Responses[0].Metadata["owner"] + require.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].Error) + assert.Equal(t, remain, resp.Responses[0].Remaining) + assert.Equal(t, status, resp.Responses[0].Status) + assert.Equal(t, int64(5), resp.Responses[0].Limit) } - // Our first hit should create the request on the peer and queue for async forward - sendHit(guber.Status_UNDER_LIMIT, 4, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1, 4) // Our second should be processed as if we own it since the async forward hasn't occurred yet - sendHit(guber.Status_UNDER_LIMIT, 3, 2) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2) testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) { - // Inspect our metrics, ensure they collected the counts we expected during this test - d := cluster.DaemonAt(0) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_global_send_duration_count") + // Inspect peers metrics, ensure the peer sent the global rate limit to the owner + metricsURL := fmt.Sprintf("http://%s/metrics", peers[0].Config().HTTPListenAddress) + m, err := getMetricRequest(metricsURL, "gubernator_global_send_duration_count") + assert.NoError(t, err) assert.Equal(t, 1, int(m.Value)) + }) + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) - // Expect one peer (the owning peer) to indicate a broadcast. 
- var broadcastCount int - for i := 0; i < cluster.NumOfDaemons(); i++ { - d := cluster.DaemonAt(i) - metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress) - m := getMetricRequest(t, metricsURL, "gubernator_broadcast_duration_count") - broadcastCount += int(m.Value) - } + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) - assert.Equal(t, 1, broadcastCount) - }) + // Check different peers, they should have gotten the broadcast from the owner + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 0, 2) + + // Non owning peer should calculate the rate limit remaining before forwarding + // to the owner. + sendHit(peers[3].MustClient(), guber.Status_UNDER_LIMIT, 2, 0) + + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) + + sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) } func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { @@ -944,14 +1023,13 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { key = "account:12345" ) - // Make a connection to a peer in the cluster which does not own this rate limit - client, err := getClientToNonOwningPeer(name, key) + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(expectedStatus guber.Status, hits int) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + sendHit := func(expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ { Name: name, @@ -959,7 +1037,7 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, - Hits: 1, + Hits: hits, Limit: 2, }, }, @@ -968,17 
+1046,19 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining sendHit(guber.Status_UNDER_LIMIT, 1) sendHit(guber.Status_UNDER_LIMIT, 1) // Wait for the broadcast from the owner to the peer - time.Sleep(time.Second * 3) + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) // Since the remainder is 0, the peer should set OVER_LIMIT instead of waiting for the owner // to respond with OVER_LIMIT. sendHit(guber.Status_OVER_LIMIT, 1) // Wait for the broadcast from the owner to the peer - time.Sleep(time.Second * 3) + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 2)) // The status should still be OVER_LIMIT sendHit(guber.Status_OVER_LIMIT, 0) } @@ -989,12 +1069,11 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { key = "account:12345" ) - // Make a connection to a peer in the cluster which does not own this rate limit - client, err := getClientToNonOwningPeer(name, key) + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(expectedStatus guber.Status, hits int) { - ctx, cancel := context.WithTimeout(context.Background(), clock.Hour*5) + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ Requests: []*guber.RateLimitReq{ @@ -1004,7 +1083,7 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t *testing.T) { Algorithm: guber.Algorithm_LEAKY_BUCKET, Behavior: guber.Behavior_GLOBAL, Duration: guber.Minute * 5, - Hits: 1, + Hits: hits, Limit: 2, }, }, @@ -1013,18 +1092,228 @@ func TestGlobalRateLimitsPeerOverLimitLeaky(t 
*testing.T) { assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) - sendHit(guber.Status_UNDER_LIMIT, 1) - sendHit(guber.Status_UNDER_LIMIT, 1) - time.Sleep(time.Second * 3) - sendHit(guber.Status_OVER_LIMIT, 1) + // Send two hits that should be processed by the owner and the broadcast to peer, depleting the remaining + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 1) + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) + // Ask a different peer if the status is over the limit + sendHit(peers[1].MustClient(), guber.Status_OVER_LIMIT, 1) } -func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample { - resp, err := http.Get(url) +func TestGlobalRequestMoreThanAvailable(t *testing.T) { + const ( + name = "test_global_more_than_available" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Ensure GRPC has connections to each peer before 
we start, as we want + // the actual test requests to happen quite fast. + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 0, 100) + } + + // Send a request for 50 hits from each non owning peer in the cluster. These requests + // will be queued and sent to the owner as accumulated hits. As a result of the async nature + // of `Behavior_GLOBAL` rate limit requests spread across peers like this will be allowed to + // over-consume their resource within the rate limit window until the owner is updated and + // a broadcast to all peers is received. + // + // The maximum number of resources that can be over-consumed can be calculated by multiplying + // the remainder by the number of peers in the cluster. For example: If you have a remainder of 100 + // and a cluster of 10 instances, then the maximum over-consumed resource is 1,000. If you need + // a more accurate remaining calculation, and wish to avoid over consuming a resource, then do + // not use `Behavior_GLOBAL`. 
+ for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) +} + +func TestGlobalNegativeHits(t *testing.T) { + const ( + name = "test_global_negative_hits" + key = "account:12345" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) + require.NoError(t, err) + + sendHit := func(client guber.V1Client, status guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 100, + Hits: hits, + Limit: 2, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, status, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + // Send a negative hit on a rate limit with no hits + sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, -1, 3) + + // Wait for the negative remaining to propagate + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // Send another negative hit to a different peer + sendHit(peers[1].MustClient(), guber.Status_UNDER_LIMIT, -1, 4) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Should have 4 in the remainder + sendHit(peers[2].MustClient(), guber.Status_UNDER_LIMIT, 4, 0) + + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+3)) + + sendHit(peers[3].MustClient(), 
guber.Status_UNDER_LIMIT, 0, 0) +} + +func TestGlobalResetRemaining(t *testing.T) { + const ( + name = "test_global_reset" + key = "account:123456" + ) + + peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) + + sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: hits, + Limit: 100, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, "", resp.Responses[0].GetError()) + assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) + assert.Equal(t, remaining, resp.Responses[0].Remaining) + } + owner, err := cluster.FindOwningDaemon(name, key) + require.NoError(t, err) + prev, err := getBroadcastCount(owner) + require.NoError(t, err) + + for _, p := range peers { + sendHit(p.MustClient(), guber.Status_UNDER_LIMIT, 50, 50) + } + + // Wait for the broadcast from the owner to the peers + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+1)) + + // We should be over the limit and remaining should be zero + sendHit(peers[0].MustClient(), guber.Status_OVER_LIMIT, 1, 0) + + // Now reset the remaining + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + defer cancel() + resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL | guber.Behavior_RESET_REMAINING, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, + }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + + // Wait for the reset to 
propagate. + require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) + + // Check a different peer to ensure remaining has been reset + resp, err = peers[1].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ + Requests: []*guber.RateLimitReq{ + { + Name: name, + UniqueKey: key, + Algorithm: guber.Algorithm_LEAKY_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute * 1_000, + Hits: 0, + Limit: 100, + }, + }, + }) + require.NoError(t, err) + assert.NotEqual(t, 100, resp.Responses[0].Remaining) + +} + +func getMetricRequest(url string, name string) (*model.Sample, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } defer resp.Body.Close() - return getMetric(t, resp.Body, name) + return getMetric(resp.Body, name) } func TestChangeLimit(t *testing.T) { @@ -1261,6 +1550,7 @@ func TestHealthCheck(t *testing.T) { } func TestLeakyBucketDivBug(t *testing.T) { + // Freeze time so we don't leak during the test defer clock.Freeze(clock.Now()).Unfreeze() client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) @@ -1408,7 +1698,7 @@ func TestGetPeerRateLimits(t *testing.T) { // TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil -func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { +func getMetric(in io.Reader, name string) (*model.Sample, error) { dec := expfmt.SampleDecoder{ Dec: expfmt.NewDecoder(in, expfmt.FmtText), Opts: &expfmt.DecodeOptions{ @@ -1423,87 +1713,58 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample { if err == io.EOF { break } - assert.NoError(t, err) + if err != nil { + return nil, err + } all = append(all, smpls...) 
} for _, s := range all { if strings.Contains(s.Metric.String(), name) { - return s + return s, nil } } - return nil + return nil, nil } -type staticBuilder struct{} - -var _ resolver.Builder = (*staticBuilder)(nil) - -func (sb *staticBuilder) Scheme() string { - return "static" -} - -func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { - var resolverAddrs []resolver.Address - for _, address := range strings.Split(target.Endpoint(), ",") { - resolverAddrs = append(resolverAddrs, resolver.Address{ - Addr: address, - ServerName: address, - }) - } - if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { - return nil, err +// getBroadcastCount returns the current broadcast count for use with waitForBroadcast() +// TODO: Replace this with something else, we can call and reset via HTTP/GRPC calls in gubernator v3 +func getBroadcastCount(d *guber.Daemon) (int, error) { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return 0, err } - return &staticResolver{cc: cc}, nil -} - -// newStaticBuilder returns a builder which returns a staticResolver that tells GRPC -// to connect a specific peer in the cluster. -func newStaticBuilder() resolver.Builder { - return &staticBuilder{} -} -type staticResolver struct { - cc resolver.ClientConn + return int(m.Value), nil } -func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} - -func (sr *staticResolver) Close() {} - -var _ resolver.Resolver = (*staticResolver)(nil) +// waitForBroadcast waits until the broadcast count for the daemon passed +// changes to the expected value. Returns an error if the expected value is +// not found before the context is cancelled. 
+func waitForBroadcast(timeout clock.Duration, d *guber.Daemon, expect int) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() -// findNonOwningPeer returns peer info for a peer in the cluster which does not -// own the rate limit for the name and key provided. -func findNonOwningPeer(name, key string) (guber.PeerInfo, error) { - owner, err := cluster.FindOwningPeer(name, key) - if err != nil { - return guber.PeerInfo{}, err - } + for { + m, err := getMetricRequest(fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress), + "gubernator_broadcast_duration_count") + if err != nil { + return err + } - for _, p := range cluster.GetPeers() { - if p.HashKey() != owner.HashKey() { - return p, nil + // It's possible a broadcast occurred twice if waiting for multiple peers to + // forward updates to the owner. + if int(m.Value) >= expect { + // Give the nodes some time to process the broadcasts + clock.Sleep(clock.Millisecond * 500) + return nil } - } - return guber.PeerInfo{}, fmt.Errorf("unable to find non-owning peer in '%d' node cluster", - len(cluster.GetPeers())) -} -// getClientToNonOwningPeer returns a connection to a peer in the cluster which does not own -// the rate limit for the name and key provided.
-func getClientToNonOwningPeer(name, key string) (guber.V1Client, error) { - p, err := findNonOwningPeer(name, key) - if err != nil { - return nil, err - } - conn, err := grpc.DialContext(context.Background(), - fmt.Sprintf("static:///%s", p.GRPCAddress), - grpc.WithResolvers(newStaticBuilder()), - grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - return nil, err + select { + case <-clock.After(time.Millisecond * 800): + case <-ctx.Done(): + return ctx.Err() + } } - return guber.NewV1Client(conn), nil - } diff --git a/global.go b/global.go index fc6c7983..adbd8e44 100644 --- a/global.go +++ b/global.go @@ -98,6 +98,11 @@ func (gm *globalManager) runAsyncHits() { key := r.HashKey() _, ok := hits[key] if ok { + // If any of our hits includes a request to RESET_REMAINING + // ensure the owning peer gets this behavior + if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + SetBehavior(&hits[key].Behavior, Behavior_RESET_REMAINING, true) + } hits[key].Hits += r.Hits } else { hits[key] = r @@ -145,7 +150,6 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { gm.log.WithError(err).Errorf("while getting peer for hash key '%s'", r.HashKey()) continue } - p, ok := peerRequests[peer.Info().GRPCAddress] if ok { p.req.Requests = append(p.req.Requests, r) diff --git a/gubernator.go b/gubernator.go index 89e875fb..58f3f616 100644 --- a/gubernator.go +++ b/gubernator.go @@ -396,25 +396,9 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - /* - item, ok, err := s.workerPool.GetCacheItem(ctx, req.HashKey()) - if err != nil { - countError(err, "Error in workerPool.GetCacheItem") - return nil, errors.Wrap(err, "during in workerPool.GetCacheItem") - } - - if ok { - // Global rate limits are always stored as RateLimitResp regardless of algorithm - rl, ok := item.Value.(*RateLimitResp) - if ok { - return rl, nil - } - // We get here if the owning node hasn't asynchronously 
forwarded it's updates to us yet and - // our cache still holds the rate limit we created on the first hit. - } - */ cpy := proto.Clone(req).(*RateLimitReq) - cpy.Behavior = Behavior_NO_BATCHING + SetBehavior(&cpy.Behavior, Behavior_NO_BATCHING, true) + SetBehavior(&cpy.Behavior, Behavior_GLOBAL, false) // Process the rate limit like we own it resp, err = s.getLocalRateLimit(ctx, cpy) @@ -432,7 +416,7 @@ func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobals now := MillisecondNow() for _, g := range r.Globals { item := &CacheItem{ - ExpireAt: g.Status.ResetTime + 1000, // account for clock drift from owner where `ResetTime` might already be less than current time of the local machine. + ExpireAt: g.Status.ResetTime, Algorithm: g.Algorithm, Key: g.Key, } @@ -503,6 +487,15 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits // Extract the propagated context from the metadata in the request prop := propagation.TraceContext{} ctx := prop.Extract(ctx, &MetadataCarrier{Map: rin.req.Metadata}) + + // Forwarded global requests must have DRAIN_OVER_LIMIT set so token and leaky algorithms + // drain the remaining in the event a peer asks for more than is remaining. + // This is needed because with GLOBAL behavior peers will accumulate hits, which could + // result in requesting more hits than is remaining. 
+ if HasBehavior(rin.req.Behavior, Behavior_GLOBAL) { + SetBehavior(&rin.req.Behavior, Behavior_DRAIN_OVER_LIMIT, true) + } + rl, err := s.getLocalRateLimit(ctx, rin.req) if err != nil { // Return the error for this request diff --git a/interval_test.go b/interval_test.go index 89642c3e..68c8b40d 100644 --- a/interval_test.go +++ b/interval_test.go @@ -18,9 +18,8 @@ package gubernator_test import ( "testing" - "time" - gubernator "github.com/mailgun/gubernator/v2" + "github.com/mailgun/gubernator/v2" "github.com/mailgun/holster/v4/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -28,18 +27,18 @@ import ( func TestInterval(t *testing.T) { t.Run("Happy path", func(t *testing.T) { - interval := gubernator.NewInterval(10 * time.Millisecond) + interval := gubernator.NewInterval(10 * clock.Millisecond) defer interval.Stop() interval.Next() assert.Empty(t, interval.C) - time.Sleep(10 * time.Millisecond) + clock.Sleep(10 * clock.Millisecond) // Wait for tick. select { case <-interval.C: - case <-time.After(100 * time.Millisecond): + case <-clock.After(100 * clock.Millisecond): require.Fail(t, "timeout") } })