Skip to content

Commit

Permalink
Add tests and fix concurrency race condition bug
Browse files Browse the repository at this point in the history
  • Loading branch information
n0str committed Jan 24, 2025
1 parent 8503d88 commit 433a30f
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 42 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/go-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: Go Tests

on: [push, pull_request]

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.21'

- name: Install dependencies
run: go mod download

- name: Run tests
run: go test -v ./...
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module github.com/codex-team/hawk.collector

require (
github.com/alicebob/miniredis/v2 v2.34.0
github.com/caarlos0/env/v6 v6.6.0
github.com/cenkalti/backoff/v4 v4.1.0
github.com/codex-team/hawk.go v1.0.5
Expand All @@ -14,6 +15,7 @@ require (
github.com/savsgio/gotils v0.0.0-20210520110740-c57c45b83e0a // indirect
github.com/sirupsen/logrus v1.8.1
github.com/streadway/amqp v1.0.0
github.com/stretchr/testify v1.7.0
github.com/tidwall/gjson v1.8.0
github.com/tidwall/sjson v1.1.6
github.com/valyala/fasthttp v1.25.0
Expand Down
13 changes: 13 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E=
github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
Expand Down Expand Up @@ -38,6 +42,9 @@ github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInq
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8=
Expand Down Expand Up @@ -212,8 +219,10 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxv
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
Expand Down Expand Up @@ -389,6 +398,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg=
go.mongodb.org/mongo-driver v1.7.1 h1:jwqTeEM3x6L9xDXrCxN0Hbg7vdGfPBOTIkr0+/LYZDA=
Expand Down Expand Up @@ -475,6 +486,7 @@ golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down Expand Up @@ -565,6 +577,7 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
Expand Down
92 changes: 50 additions & 42 deletions pkg/redis/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package redis
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -177,54 +175,64 @@ func (r *RedisClient) CheckAvailability() bool {
return pong == "PONG"
}

// UpdateRateLimit checks and updates the rate limit for a project
// Returns true if rate is within limit, false otherwise
// UpdateRateLimit checks and updates the rate limit for a project using a Lua script
func (r *RedisClient) UpdateRateLimit(projectID string, eventsLimit int64, eventsPeriod int64) (bool, error) {
// If eventsLimit is 0, we don't need to update the rate limit
if eventsLimit == 0 {
return true, nil
}

// Key format: "project_id" -> "timestamp:count"
now := time.Now().Unix()
// Lua script for atomic rate limit check and update
script := `
local key = KEYS[1]
local field = ARGV[1]
local now = tonumber(ARGV[2])
local limit = tonumber(ARGV[3])
local period = tonumber(ARGV[4])
local current = redis.call('HGET', key, field)
if not current then
-- No existing record, create new window
redis.call('HSET', key, field, now .. ':1')
return 1
end
local timestamp, count = string.match(current, '(%d+):(%d+)')
timestamp = tonumber(timestamp)
count = tonumber(count)
-- Check if we're in a new time window
if now - timestamp >= period then
-- Reset for new window
redis.call('HSET', key, field, now .. ':1')
return 1
end
-- Check if incrementing would exceed limit
if count + 1 > limit then
return 0
end
-- Increment counter
redis.call('HSET', key, field, timestamp .. ':' .. (count + 1))
return 1
`

// Run the script
result, err := r.rdb.Eval(
r.ctx,
script,
[]string{"rate_limits"}, // KEYS
projectID, // field (ARGV[1])
time.Now().Unix(), // now (ARGV[2])
eventsLimit, // limit (ARGV[3])
eventsPeriod, // period (ARGV[4])
).Result()

// Get current window data
val, err := r.rdb.HGet(r.ctx, "rate_limits", projectID).Result()
if err != nil && err != redis.Nil {
return false, fmt.Errorf("failed to get rate limit data: %w", err)
}

var timestamp int64
var count int64

if val != "" {
// Parse existing "timestamp:count" value
parts := strings.Split(val, ":")
timestamp, _ = strconv.ParseInt(parts[0], 10, 64)
count, _ = strconv.ParseInt(parts[1], 10, 64)

// Reset count if we're in a new window
if now-timestamp >= eventsPeriod {
count = 0
timestamp = now
}
} else {
// Initialize new window
timestamp = now
count = 0
}

// Check if incrementing would exceed limit
if count+1 > eventsLimit {
return false, nil
}

// Update the counter
newVal := fmt.Sprintf("%d:%d", timestamp, count+1)
err = r.rdb.HSet(r.ctx, "rate_limits", projectID, newVal).Err()
if err != nil {
return false, fmt.Errorf("failed to update rate limit: %w", err)
return false, fmt.Errorf("failed to execute rate limit script: %w", err)
}

return true, nil
// Script returns 1 if rate limit is not exceeded, 0 if it is
return result.(int64) == 1, nil
}
187 changes: 187 additions & 0 deletions pkg/redis/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package redis

import (
"context"
"fmt"
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
)

func setupTestRedis(t *testing.T) (*RedisClient, *miniredis.Miniredis) {
// Create a mock Redis server
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("Failed to create mock redis: %v", err)
}

// Create Redis client connected to mock server
client := &RedisClient{
rdb: redis.NewClient(&redis.Options{
Addr: mr.Addr(),
}),
ctx: context.Background(),
}

return client, mr
}

func TestUpdateRateLimit(t *testing.T) {
client, mr := setupTestRedis(t)
defer mr.Close()

tests := []struct {
name string
projectID string
eventsLimit int64
eventsPeriod int64
setup func()
calls int
wantAllowed bool
wantErr bool
}{
{
name: "should allow when no previous events",
projectID: "project1",
eventsLimit: 10,
eventsPeriod: 60,
calls: 1,
wantAllowed: true,
wantErr: false,
},
{
name: "should allow when under limit",
projectID: "project2",
eventsLimit: 10,
eventsPeriod: 60,
setup: func() {
client.rdb.HSet(client.ctx, "rate_limits", "project2",
fmt.Sprintf("%d:%d", time.Now().Unix()-30, 5))
},
calls: 1,
wantAllowed: true,
wantErr: false,
},
{
name: "should deny when at limit",
projectID: "project3",
eventsLimit: 5,
eventsPeriod: 60,
setup: func() {
client.rdb.HSet(client.ctx, "rate_limits", "project3",
fmt.Sprintf("%d:%d", time.Now().Unix()-30, 5))
},
calls: 1,
wantAllowed: false,
wantErr: false,
},
{
name: "should reset count after period expires",
projectID: "project4",
eventsLimit: 5,
eventsPeriod: 60,
setup: func() {
client.rdb.HSet(client.ctx, "rate_limits", "project4",
fmt.Sprintf("%d:%d", time.Now().Unix()-61, 5))
},
calls: 1,
wantAllowed: true,
wantErr: false,
},
{
name: "should allow all when limit is 0",
projectID: "project5",
eventsLimit: 0,
eventsPeriod: 60,
calls: 5,
wantAllowed: true,
wantErr: false,
},
{
name: "should handle multiple calls up to limit",
projectID: "project6",
eventsLimit: 3,
eventsPeriod: 60,
calls: 4,
wantAllowed: false, // Last call should be denied
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Run setup if provided
if tt.setup != nil {
tt.setup()
}

var lastAllowed bool
var lastErr error

// Make the specified number of calls
for i := 0; i < tt.calls; i++ {
lastAllowed, lastErr = client.UpdateRateLimit(tt.projectID, tt.eventsLimit, tt.eventsPeriod)
}

if tt.wantErr {
assert.Error(t, lastErr)
} else {
assert.NoError(t, lastErr)
}
assert.Equal(t, tt.wantAllowed, lastAllowed)
})
}
}

func TestUpdateRateLimitConcurrent(t *testing.T) {
client, mr := setupTestRedis(t)
defer mr.Close()

const (
projectID = "concurrent-project"
eventsLimit = 90
eventsPeriod = 60
goroutines = 10
callsPerRoutine = 20
)

var rejectedCount int = 0

done := make(chan bool)

// Launch multiple goroutines to test concurrent access
for i := 0; i < goroutines; i++ {
go func() {
for j := 0; j < callsPerRoutine; j++ {
allowed, err := client.UpdateRateLimit(projectID, eventsLimit, eventsPeriod)
assert.NoError(t, err)
if !allowed {
rejectedCount++
}
}
done <- true
}()
}

// Wait for all goroutines to complete
for i := 0; i < goroutines; i++ {
<-done
}

// Verify the total number of successful updates doesn't exceed the limit
val, err := client.rdb.HGet(client.ctx, "rate_limits", projectID).Result()
assert.NoError(t, err)
assert.NotEmpty(t, val)

// The total count should not exceed the events limit
count := 0
_, err = fmt.Sscanf(val, "%d:%d", &count, &count)
assert.NoError(t, err)
assert.Equal(t, count, eventsLimit)
assert.Equal(t, rejectedCount, goroutines*callsPerRoutine-eventsLimit)
t.Logf("count: %d", count)
t.Logf("rejectedCount: %d", rejectedCount)
}

0 comments on commit 433a30f

Please sign in to comment.