Skip to content

Commit

Permalink
change review
Browse files Browse the repository at this point in the history
  • Loading branch information
kai2321 committed Feb 24, 2025
1 parent 98070cd commit c3c2627
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 65 deletions.
11 changes: 8 additions & 3 deletions plugins/wasm-go/extensions/key-rate-limit/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ module key-rate-limit

go 1.19

replace github.com/alibaba/higress/plugins/wasm-go => ../..

require (
github.com/alibaba/higress/plugins/wasm-go v1.4.4-0.20250208094229-512385d22551
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.18.0
)

require (
github.com/alibaba/higress/plugins/wasm-go v1.4.4-0.20250208094229-512385d22551 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
Expand Down
109 changes: 47 additions & 62 deletions plugins/wasm-go/extensions/key-rate-limit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ import (
)

const (
secondNano = 1000 * 1000 * 1000
minuteNano = 60 * secondNano
hourNano = 60 * minuteNano
dayNano = 24 * hourNano
tickMilliseconds int64 = 500
maxGetTokenRetry int = 20
secondNano = 1000 * 1000 * 1000
minuteNano = 60 * secondNano
hourNano = 60 * minuteNano
dayNano = 24 * hourNano
tickMilliseconds int64 = 500
maxGetTokenRetry int = 20
initialValue uint64 = 0
tokenBucketPrefix string = "mse.token_bucket"
lastRefilledPrefix string = "mse.last_refilled"
)

type KeyRateLimitConfig struct {
Expand All @@ -37,12 +40,6 @@ type LimitItem struct {
maxTokens uint64
}

// Key-prefix for token bucket shared data.
var tokenBucketPrefix string = "mse.token_bucket"

// Key-prefix for token bucket last updated time.
var lastRefilledPrefix string = "mse.last_refilled"

var ruleId int = 0

func main() {
Expand All @@ -57,57 +54,35 @@ func parseConfig(json gjson.Result, config *KeyRateLimitConfig, log wrapper.Log)
//解析配置规则
config.limitKeys = make(map[string]LimitItem)
limitKeys := json.Get("limit_keys").Array()
limitMapping := map[string]int64{
"query_per_second": secondNano,
"query_per_minute": minuteNano,
"query_per_hour": hourNano,
"query_per_day": dayNano,
}
for _, item := range limitKeys {
key := item.Get("key")
if !key.Exists() || key.String() == "" {
return errors.New("key name is required")
}
qps := item.Get("query_per_second")
if qps.Exists() && qps.String() != "" {
config.limitKeys[key.String()] = LimitItem{
ruleId: ruleId,
key: key.String(),
tokensPerRefill: qps.Uint(),
refillIntervalNanosec: secondNano,
maxTokens: qps.Uint(),
}
continue
}
qpm := item.Get("query_per_minute")
if qpm.Exists() && qpm.String() != "" {
config.limitKeys[key.String()] = LimitItem{
ruleId: ruleId,
key: key.String(),
tokensPerRefill: qpm.Uint(),
refillIntervalNanosec: minuteNano,
maxTokens: qpm.Uint(),
}
continue
}
qph := item.Get("query_per_hour")
if qph.Exists() && qph.String() != "" {
config.limitKeys[key.String()] = LimitItem{
ruleId: ruleId,
key: key.String(),
tokensPerRefill: qph.Uint(),
refillIntervalNanosec: hourNano,
maxTokens: qph.Uint(),
matched := false
for field, nanoValue := range limitMapping {
qps := item.Get(field)
if qps.Exists() && qps.String() != "" {
config.limitKeys[key.String()] = LimitItem{
ruleId: ruleId,
key: key.String(),
tokensPerRefill: qps.Uint(),
refillIntervalNanosec: uint64(nanoValue),
maxTokens: qps.Uint(),
}
matched = true
break
}
continue
}
qpd := item.Get("query_per_day")
if qpd.Exists() && qpd.String() != "" {
config.limitKeys[key.String()] = LimitItem{
ruleId: ruleId,
key: key.String(),
tokensPerRefill: qpd.Uint(),
refillIntervalNanosec: dayNano,
maxTokens: qpd.Uint(),
}
continue
if !matched {
return errors.New("one of 'query_per_second', 'query_per_minute', 'query_per_hour' or 'query_per_day' must be set")
}
return errors.New("one of 'query_per_second', 'query_per_minute', " +
"'query_per_hour' or 'query_per_day' must be set")
}
if len(config.limitKeys) == 0 {
return errors.New("no limit keys found in configuration")
Expand Down Expand Up @@ -164,6 +139,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config KeyRateLimitConfig, lo
}
}
}
if key == "" {
return types.ActionContinue
}
limitKeys := config.limitKeys
_, exists := limitKeys[key]
if !exists {
Expand All @@ -183,10 +161,9 @@ func tooManyRequest() types.Action {
}

func initializeTokenBucket(rules map[string]LimitItem, log wrapper.Log) bool {
var initialValue uint64 = 0
for _, rule := range rules {
lastRefilledKey := strconv.Itoa(rule.ruleId) + lastRefilledPrefix + rule.key
tokenBucketKey := strconv.Itoa(rule.ruleId) + tokenBucketPrefix + rule.key
lastRefilledKey := rule.GenerateLastRefilledKey()
tokenBucketKey := rule.GenerateTokenBucketKey()
initialBuf := make([]byte, 8)
binary.LittleEndian.PutUint64(initialBuf, initialValue)
maxTokenBuf := make([]byte, 8)
Expand All @@ -202,7 +179,7 @@ func initializeTokenBucket(rules map[string]LimitItem, log wrapper.Log) bool {
for {
_, lastUpdateCas, err := proxywasm.GetSharedData(lastRefilledKey)
if err != nil {
log.Warnf("failed to get lastRefilled")
log.Warnf("failed to get lastRefilled, err: %v", err)
return false
}
err = proxywasm.SetSharedData(lastRefilledKey, initialBuf, lastUpdateCas)
Expand All @@ -214,7 +191,7 @@ func initializeTokenBucket(rules map[string]LimitItem, log wrapper.Log) bool {
for {
_, lastUpdateCas, err := proxywasm.GetSharedData(tokenBucketKey)
if err != nil {
log.Warnf("failed to get tokenBucket")
log.Warnf("failed to get tokenBucket, err: %v", err)
return false
}
err = proxywasm.SetSharedData(tokenBucketKey, maxTokenBuf, lastUpdateCas)
Expand All @@ -231,8 +208,8 @@ func initializeTokenBucket(rules map[string]LimitItem, log wrapper.Log) bool {

func refillToken(rules map[string]LimitItem, log wrapper.Log) {
for _, rule := range rules {
lastRefilledKey := strconv.Itoa(rule.ruleId) + lastRefilledPrefix + rule.key
tokenBucketKey := strconv.Itoa(rule.ruleId) + tokenBucketPrefix + rule.key
lastRefilledKey := rule.GenerateLastRefilledKey()
tokenBucketKey := rule.GenerateTokenBucketKey()
lastUpdateData, lastUpdateCas, err := proxywasm.GetSharedData(lastRefilledKey)
if err != nil {
log.Warnf("failed to get last update time of the local rate limit: %s", err)
Expand Down Expand Up @@ -299,3 +276,11 @@ func getToken(ruleId int, key string) bool {
}
return true
}

func (l *LimitItem) GenerateLastRefilledKey() string {
return strconv.Itoa(l.ruleId) + lastRefilledPrefix + l.key
}

func (l *LimitItem) GenerateTokenBucketKey() string {
return strconv.Itoa(l.ruleId) + tokenBucketPrefix + l.key
}

0 comments on commit c3c2627

Please sign in to comment.