forked from mantlenetworkio/mantle-v2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrontend_rate_limiter.go
139 lines (118 loc) · 3.7 KB
/
frontend_rate_limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package proxyd
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-redis/redis/v8"
)
// FrontendRateLimiter is the interface returned by the in-memory and
// Redis-backed rate limiter constructors in this file.
type FrontendRateLimiter interface {
// Take records one request against key. The maximum number of
// requests and the time interval are not arguments of Take; the
// implementations in this file receive them at construction time.
//
// It returns a boolean denoting if the limit could be taken, or
// an error if a failure occurred in the backing rate limit
// implementation.
//
// No error will be returned if the limit could not be taken
// as a result of the requestor being over the limit; in that
// case Take returns (false, nil).
Take(ctx context.Context, key string) (bool, error)
}
// limitedKeys holds the per-key request counters that belong to a
// single rate limit window. truncTS is the truncated timestamp that
// identifies the window, keys maps each rate limit key to the number
// of Take calls made for it, and mtx guards access to keys.
type limitedKeys struct {
	truncTS int64
	keys    map[string]int
	mtx     sync.Mutex
}

// newLimitedKeys returns an empty set of counters tagged with the
// truncated timestamp t.
func newLimitedKeys(t int64) *limitedKeys {
	lk := new(limitedKeys)
	lk.truncTS = t
	lk.keys = map[string]int{}
	return lk
}

// Take records one request for key and reports whether that request
// fits within max. The counter is incremented on every call, including
// calls that are over the limit; the result is true while the count
// seen before the increment is below max.
func (l *limitedKeys) Take(key string, max int) bool {
	l.mtx.Lock()
	defer l.mtx.Unlock()

	// A key that has not been seen yet reads as zero from the map.
	used := l.keys[key]
	l.keys[key] = used + 1
	return used < max
}
// MemoryFrontendRateLimiter is a rate limiter that keeps all of its
// state in local memory. It holds a single limitedKeys value tagged
// with the truncated timestamp at which it was created. When the
// current truncated timestamp differs from that tag, the value is
// replaced with a fresh one, which resets every counter. Otherwise
// the per-key counters in the existing value are incremented.
type MemoryFrontendRateLimiter struct {
	currGeneration *limitedKeys
	dur            time.Duration
	max            int
	mtx            sync.Mutex
}

// NewMemoryFrontendRateLimit returns an in-memory FrontendRateLimiter
// that allows max requests per key in each window of length dur.
func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
	limiter := new(MemoryFrontendRateLimiter)
	limiter.dur = dur
	limiter.max = max
	return limiter
}

// Take records one request for key in the current window and reports
// whether it is within the configured maximum. The returned error is
// always nil.
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
	// Select the counters for the current window while holding m.mtx,
	// replacing them if there are none yet or if they were created for
	// a different truncated timestamp.
	generation := func() *limitedKeys {
		m.mtx.Lock()
		defer m.mtx.Unlock()

		window := truncateNow(m.dur)
		if current := m.currGeneration; current != nil && current.truncTS == window {
			return current
		}
		m.currGeneration = newLimitedKeys(window)
		return m.currGeneration
	}()

	// m.mtx is released at this point; the increment is guarded by the
	// generation's own mutex.
	return generation.Take(key, m.max), nil
}
// RedisFrontendRateLimiter is a rate limiter that keeps its counters
// in Redis. It follows the basic rate limiter pattern described on the
// Redis best practices website:
// https://redis.com/redis-best-practices/basic-rate-limiting/.
type RedisFrontendRateLimiter struct {
	r      *redis.Client
	dur    time.Duration
	max    int
	prefix string
}

// NewRedisFrontendRateLimiter returns a FrontendRateLimiter that uses
// the client r and allows max requests per key in each window of
// length dur. prefix is embedded in every Redis key the limiter writes.
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
	limiter := new(RedisFrontendRateLimiter)
	limiter.r = r
	limiter.dur = dur
	limiter.max = max
	limiter.prefix = prefix
	return limiter
}

// Take increments the Redis counter for key in the current window and
// reports whether the request is within the configured maximum. If the
// pipeline fails, frontendRateLimitTakeErrors is incremented and the
// error is returned together with false.
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
	// The counter key is built from the prefix, the caller's key and the
	// truncated timestamp of the current window.
	window := truncateNow(r.dur)
	counterKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, window)
	ttl := r.dur - time.Millisecond

	// Issue the increment and the expiry through one pipeline.
	var counter *redis.IntCmd
	queue := func(pipe redis.Pipeliner) error {
		counter = pipe.Incr(ctx, counterKey)
		pipe.PExpire(ctx, counterKey, ttl)
		return nil
	}
	if _, err := r.r.Pipelined(ctx, queue); err != nil {
		frontendRateLimitTakeErrors.Inc()
		return false, err
	}

	// The command result is the value after the increment, so subtract
	// one to compare the count before this request against the maximum.
	return counter.Val()-1 < int64(r.max), nil
}
// noopFrontendRateLimiter is a rate limiter that holds no state and
// never rejects a request.
type noopFrontendRateLimiter struct{}

// NoopFrontendRateLimiter is a package-level noopFrontendRateLimiter
// instance.
var NoopFrontendRateLimiter = new(noopFrontendRateLimiter)

// Take ignores ctx and key and always returns (true, nil).
func (*noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
	return true, nil
}
// truncateNow returns the current time, rounded down to a multiple of
// dur, as a Unix timestamp in seconds.
//
// The result has one-second resolution, so for a dur shorter than a
// second, calls made within the same second return the same value.
func truncateNow(dur time.Duration) int64 {
	now := time.Now()
	return now.Truncate(dur).Unix()
}