test: Add test for global rate limiting with load balancing
philipgough committed Dec 12, 2023
1 parent 885519d commit 9074374
Showing 1 changed file with 125 additions and 0 deletions.
125 changes: 125 additions & 0 deletions functional_test.go
@@ -34,6 +34,10 @@ import (
"github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
json "google.golang.org/protobuf/encoding/protojson"
)

@@ -859,6 +863,84 @@ func TestGlobalRateLimits(t *testing.T) {
})
}

func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) {
peer := cluster.PeerAt(0).GRPCAddress
peerTwo := cluster.PeerAt(1).GRPCAddress

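// Dial using a custom "static" scheme resolver plus a round-robin load
// balancer so that requests alternate between the two peers.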
dialOpts := []grpc.DialOption{
grpc.WithResolvers(newStaticBuilder()),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`),
}

address := fmt.Sprintf("static:///%s,%s", peer, peerTwo)
conn, err := grpc.DialContext(context.Background(), address, dialOpts...)
require.NoError(t, err)

client := guber.NewV1Client(conn)

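// sendHit sends a single hit against the global rate limit and returns the
// owning peer advertised in the response metadata; i labels the hit in
// assertion messages.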
sendHit := func(status guber.Status, remain int64, i int) string {
ctx, cancel := context.WithTimeout(context.Background(), clock.Second*5)
defer cancel()
resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{
Requests: []*guber.RateLimitReq{
{
Name: "test_global",
UniqueKey: "account:12345",
Algorithm: guber.Algorithm_TOKEN_BUCKET,
Behavior: guber.Behavior_GLOBAL,
Duration: guber.Second * 3,
Hits: 1,
Limit: 5,
},
},
})
require.NoError(t, err, i)
assert.Equal(t, "", resp.Responses[0].Error, i)
assert.Equal(t, status, resp.Responses[0].Status, i)
assert.Equal(t, remain, resp.Responses[0].Remaining, i)
assert.Equal(t, int64(5), resp.Responses[0].Limit, i)

// Ensure the response identifies a canonical owner host.
assert.NotEmpty(t, resp.Responses[0].Metadata["owner"])

// The name/key pair was chosen so that our connected peer is NOT the owner;
// the peer we are connected to should forward requests asynchronously to the owner.
assert.NotEqual(t, peer, resp.Responses[0].Metadata["owner"])

return resp.Responses[0].Metadata["owner"]
}

// Our first hit should create the rate limit on the connected peer and queue
// it for async forwarding to the owner.
owner := sendHit(guber.Status_UNDER_LIMIT, 4, 1)
assert.NotEqual(t, owner, peer)
assert.NotEqual(t, owner, peerTwo)
// Our second hit should be routed round-robin to the other peer, which has not
// seen the first hit yet, so it reports the same remainder.
sendHit(guber.Status_UNDER_LIMIT, 4, 2)
// Our third hit should be routed back to the first peer, which decrements its
// local count again since the async forward to the owner hasn't occurred yet.
sendHit(guber.Status_UNDER_LIMIT, 3, 3)

testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) {
// Inspect our metrics, ensure they collected the counts we expected during this test
d := cluster.DaemonAt(0)
metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress)
m := getMetricRequest(t, metricsURL, "gubernator_global_send_duration_count")
assert.Equal(t, 1, int(m.Value))

// Expect one peer (the owning peer) to indicate a broadcast.
var broadcastCount int
for i := 0; i < cluster.NumOfDaemons(); i++ {
d := cluster.DaemonAt(i)
metricsURL := fmt.Sprintf("http://%s/metrics", d.Config().HTTPListenAddress)
m := getMetricRequest(t, metricsURL, "gubernator_broadcast_duration_count")
broadcastCount += int(m.Value)
}

assert.Equal(t, 1, broadcastCount)
})
}

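// getMetricRequest fetches the Prometheus metrics endpoint at url and returns
// the sample with the given name.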
func getMetricRequest(t testutil.TestingT, url string, name string) *model.Sample {
resp, err := http.Get(url)
require.NoError(t, err)
@@ -1273,3 +1355,46 @@ func getMetric(t testutil.TestingT, in io.Reader, name string) *model.Sample {
}
return nil
}

// staticBuilder implements the `resolver.Builder` interface.
type staticBuilder struct{}

func newStaticBuilder() resolver.Builder {
return &staticBuilder{}
}

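// Build splits the comma-separated endpoint list out of the target URI and
// hands the resulting addresses to a static resolver.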
func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
var resolverAddrs []resolver.Address
for _, address := range strings.Split(target.Endpoint(), ",") {
resolverAddrs = append(resolverAddrs, resolver.Address{
Addr: address,
ServerName: address,
})
}
r, err := newStaticResolver(cc, resolverAddrs)
if err != nil {
return nil, err
}
return r, nil
}

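// Scheme reports the URI scheme this builder handles, matching the
// "static:///" target used when dialing.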
func (sb *staticBuilder) Scheme() string {
return "static"
}

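// staticResolver serves a fixed address list and never re-resolves.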
type staticResolver struct {
cc resolver.ClientConn
}

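// newStaticResolver pushes the fixed address list to the ClientConn once at
// construction time.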
func newStaticResolver(cc resolver.ClientConn, addresses []resolver.Address) (resolver.Resolver, error) {
err := cc.UpdateState(resolver.State{Addresses: addresses})
if err != nil {
return nil, err
}
return &staticResolver{cc: cc}, nil
}

func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {}

func (sr *staticResolver) Close() {}
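// Note: the builder is supplied per-connection via grpc.WithResolvers above,
// which keeps the "static" scheme scoped to this test's ClientConn. A
// plausible alternative (not used here) would be to register it process-wide:
//
//	func init() { resolver.Register(newStaticBuilder()) }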
