From 433a30fc9b41ccd91243fe718a7b21ace67d0ad0 Mon Sep 17 00:00:00 2001 From: n0str Date: Fri, 24 Jan 2025 16:25:37 +0300 Subject: [PATCH] Add tests and fix concurrency race condition bug --- .github/workflows/go-test.yml | 20 ++++ go.mod | 2 + go.sum | 13 +++ pkg/redis/client.go | 92 +++++++++-------- pkg/redis/client_test.go | 187 ++++++++++++++++++++++++++++++++++ 5 files changed, 272 insertions(+), 42 deletions(-) create mode 100644 .github/workflows/go-test.yml create mode 100644 pkg/redis/client_test.go diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml new file mode 100644 index 0000000..c72ab49 --- /dev/null +++ b/.github/workflows/go-test.yml @@ -0,0 +1,20 @@ +name: Go Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install dependencies + run: go mod download + + - name: Run tests + run: go test -v ./... \ No newline at end of file diff --git a/go.mod b/go.mod index c1f6196..9230088 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/codex-team/hawk.collector require ( + github.com/alicebob/miniredis/v2 v2.34.0 github.com/caarlos0/env/v6 v6.6.0 github.com/cenkalti/backoff/v4 v4.1.0 github.com/codex-team/hawk.go v1.0.5 @@ -14,6 +15,7 @@ require ( github.com/savsgio/gotils v0.0.0-20210520110740-c57c45b83e0a // indirect github.com/sirupsen/logrus v1.8.1 github.com/streadway/amqp v1.0.0 + github.com/stretchr/testify v1.7.0 github.com/tidwall/gjson v1.8.0 github.com/tidwall/sjson v1.1.6 github.com/valyala/fasthttp v1.25.0 diff --git a/go.sum b/go.sum index e1a172f..e4ca661 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,10 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= +github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0= +github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8= github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E= github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= @@ -38,6 +42,9 @@ github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= @@ -212,8 +219,10 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= @@ -389,6 +398,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= go.mongodb.org/mongo-driver v1.7.1 h1:jwqTeEM3x6L9xDXrCxN0Hbg7vdGfPBOTIkr0+/LYZDA= @@ -475,6 +486,7 @@ golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -565,6 +577,7 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= diff --git a/pkg/redis/client.go b/pkg/redis/client.go index 539591b..4024bbe 100644 --- a/pkg/redis/client.go +++ b/pkg/redis/client.go @@ -3,8 +3,6 @@ package redis import ( "context" "fmt" - "strconv" - "strings" "sync" "time" @@ -177,54 +175,64 @@ func (r *RedisClient) CheckAvailability() bool { return pong == "PONG" } -// UpdateRateLimit checks and updates the rate limit for a project -// Returns true if rate is within limit, false otherwise +// UpdateRateLimit checks and updates the rate limit for a project using a Lua script func (r *RedisClient) UpdateRateLimit(projectID string, eventsLimit int64, eventsPeriod int64) (bool, error) { // If eventsLimit is 0, we don't need to update the rate limit if eventsLimit == 0 { return true, nil } - // Key format: "project_id" -> "timestamp:count" - now := time.Now().Unix() + // Lua script for atomic rate limit check and update + script := ` + local key = KEYS[1] + local field = ARGV[1] + local now = tonumber(ARGV[2]) + local limit = tonumber(ARGV[3]) + local period = tonumber(ARGV[4]) + + local current = redis.call('HGET', key, field) + if not current then + -- No existing record, create new window + redis.call('HSET', key, field, now .. ':1') + return 1 + end + + local timestamp, count = string.match(current, '(%d+):(%d+)') + timestamp = tonumber(timestamp) + count = tonumber(count) + + -- Check if we're in a new time window + if now - timestamp >= period then + -- Reset for new window + redis.call('HSET', key, field, now .. ':1') + return 1 + end + + -- Check if incrementing would exceed limit + if count + 1 > limit then + return 0 + end + + -- Increment counter + redis.call('HSET', key, field, timestamp .. ':' .. (count + 1)) + return 1 + ` + + // Run the script + result, err := r.rdb.Eval( + r.ctx, + script, + []string{"rate_limits"}, // KEYS + projectID, // field (ARGV[1]) + time.Now().Unix(), // now (ARGV[2]) + eventsLimit, // limit (ARGV[3]) + eventsPeriod, // period (ARGV[4]) + ).Result() - // Get current window data - val, err := r.rdb.HGet(r.ctx, "rate_limits", projectID).Result() - if err != nil && err != redis.Nil { - return false, fmt.Errorf("failed to get rate limit data: %w", err) - } - - var timestamp int64 - var count int64 - - if val != "" { - // Parse existing "timestamp:count" value - parts := strings.Split(val, ":") - timestamp, _ = strconv.ParseInt(parts[0], 10, 64) - count, _ = strconv.ParseInt(parts[1], 10, 64) - - // Reset count if we're in a new window - if now-timestamp >= eventsPeriod { - count = 0 - timestamp = now - } - } else { - // Initialize new window - timestamp = now - count = 0 - } - - // Check if incrementing would exceed limit - if count+1 > eventsLimit { - return false, nil - } - - // Update the counter - newVal := fmt.Sprintf("%d:%d", timestamp, count+1) - err = r.rdb.HSet(r.ctx, "rate_limits", projectID, newVal).Err() if err != nil { - return false, fmt.Errorf("failed to update rate limit: %w", err) + return false, fmt.Errorf("failed to execute rate limit script: %w", err) } - return true, nil + // Script returns 1 if rate limit is not exceeded, 0 if it is + return result.(int64) == 1, nil } diff --git a/pkg/redis/client_test.go b/pkg/redis/client_test.go new file mode 100644 index 0000000..2a11d2b --- /dev/null +++ b/pkg/redis/client_test.go @@ -0,0 +1,187 @@ +package redis + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" +) + +func setupTestRedis(t *testing.T) (*RedisClient, *miniredis.Miniredis) { + // Create a mock Redis server + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("Failed to create mock redis: %v", err) + } + + // Create Redis client connected to mock server + client := &RedisClient{ + rdb: redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }), + ctx: context.Background(), + } + + return client, mr +} + +func TestUpdateRateLimit(t *testing.T) { + client, mr := setupTestRedis(t) + defer mr.Close() + + tests := []struct { + name string + projectID string + eventsLimit int64 + eventsPeriod int64 + setup func() + calls int + wantAllowed bool + wantErr bool + }{ + { + name: "should allow when no previous events", + projectID: "project1", + eventsLimit: 10, + eventsPeriod: 60, + calls: 1, + wantAllowed: true, + wantErr: false, + }, + { + name: "should allow when under limit", + projectID: "project2", + eventsLimit: 10, + eventsPeriod: 60, + setup: func() { + client.rdb.HSet(client.ctx, "rate_limits", "project2", + fmt.Sprintf("%d:%d", time.Now().Unix()-30, 5)) + }, + calls: 1, + wantAllowed: true, + wantErr: false, + }, + { + name: "should deny when at limit", + projectID: "project3", + eventsLimit: 5, + eventsPeriod: 60, + setup: func() { + client.rdb.HSet(client.ctx, "rate_limits", "project3", + fmt.Sprintf("%d:%d", time.Now().Unix()-30, 5)) + }, + calls: 1, + wantAllowed: false, + wantErr: false, + }, + { + name: "should reset count after period expires", + projectID: "project4", + eventsLimit: 5, + eventsPeriod: 60, + setup: func() { + client.rdb.HSet(client.ctx, "rate_limits", "project4", + fmt.Sprintf("%d:%d", time.Now().Unix()-61, 5)) + }, + calls: 1, + wantAllowed: true, + wantErr: false, + }, + { + name: "should allow all when limit is 0", + projectID: "project5", + eventsLimit: 0, + eventsPeriod: 60, + calls: 5, + wantAllowed: true, + wantErr: false, + }, + { + name: "should handle multiple calls up to limit", + projectID: "project6", + eventsLimit: 3, + eventsPeriod: 60, + calls: 4, + wantAllowed: false, // Last call should be denied + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Run setup if provided + if tt.setup != nil { + tt.setup() + } + + var lastAllowed bool + var lastErr error + + // Make the specified number of calls + for i := 0; i < tt.calls; i++ { + lastAllowed, lastErr = client.UpdateRateLimit(tt.projectID, tt.eventsLimit, tt.eventsPeriod) + } + + if tt.wantErr { + assert.Error(t, lastErr) + } else { + assert.NoError(t, lastErr) + } + assert.Equal(t, tt.wantAllowed, lastAllowed) + }) + } +} + +func TestUpdateRateLimitConcurrent(t *testing.T) { + client, mr := setupTestRedis(t) + defer mr.Close() + + const ( + projectID = "concurrent-project" + eventsLimit = 90 + eventsPeriod = 60 + goroutines = 10 + callsPerRoutine = 20 + ) + + var rejectedCount int = 0 + + done := make(chan bool) + + // Launch multiple goroutines to test concurrent access + for i := 0; i < goroutines; i++ { + go func() { + for j := 0; j < callsPerRoutine; j++ { + allowed, err := client.UpdateRateLimit(projectID, eventsLimit, eventsPeriod) + assert.NoError(t, err) + if !allowed { + rejectedCount++ + } + } + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < goroutines; i++ { + <-done + } + + // Verify the total number of successful updates doesn't exceed the limit + val, err := client.rdb.HGet(client.ctx, "rate_limits", projectID).Result() + assert.NoError(t, err) + assert.NotEmpty(t, val) + + // The total count should not exceed the events limit + count := 0 + _, err = fmt.Sscanf(val, "%d:%d", &count, &count) + assert.NoError(t, err) + assert.Equal(t, count, eventsLimit) + assert.Equal(t, rejectedCount, goroutines*callsPerRoutine-eventsLimit) + t.Logf("count: %d", count) + t.Logf("rejectedCount: %d", rejectedCount) +}