diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..22d0d82 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +vendor diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 0000000..0c4f583 --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,23 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + name = "github.com/go-redis/redis" + packages = [ + ".", + "internal", + "internal/consistenthash", + "internal/hashtag", + "internal/pool", + "internal/proto", + "internal/util" + ] + revision = "75795aa4236dc7341eefac3bbe945e68c99ef9df" + version = "v6.15.3" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "f0f820716152dc4e016e3c86b43ed6442fa064f6ed98320078c232b45dbc229c" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 0000000..63b9e4a --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,33 @@ +# Gopkg.toml example +# +# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + + +[[constraint]] + name = "github.com/go-redis/redis" + version = "6.15.3" + +[prune] + unused-packages = true diff --git a/redispattern/concurrentratelimiter/concurrent_rate_limiter.go b/redispattern/concurrentratelimiter/concurrent_rate_limiter.go new file mode 100644 index 0000000..41e6433 --- /dev/null +++ b/redispattern/concurrentratelimiter/concurrent_rate_limiter.go @@ -0,0 +1,141 @@ +package concurrentratelimiter + +import ( + "crypto/sha1" + "encoding/hex" + "io" + "strings" + "time" + + "github.com/xiaojiaoyu100/lizard/redispattern" + "github.com/xiaojiaoyu100/lizard/timekit" +) + +const ( + enterScript = ` +local key = KEYS[1] +local limit = tonumber(ARGV[1]) +local now = tonumber(ARGV[2]) +local random = ARGV[3] +local ttl = tonumber(ARGV[4]) + +redis.call('zremrangebyscore', key, '-inf', now - ttl) + +local count = redis.call("zcard", key) + +if count < limit then + redis.call("zadd", key, now, random) + return 1 +end + +return 0 +` + leaveScript = ` +local key = KEYS[1] +local random = ARGV[1] +local ret = redis.call("zrem", key, random) +return ret +` +) + +var ( + enterScriptDigest string + leaveScriptDigest string +) + +func init() { + e := sha1.New() + io.WriteString(e, enterScript) + enterScriptDigest = hex.EncodeToString(e.Sum(nil)) + + l := sha1.New() + io.WriteString(l, leaveScript) + leaveScriptDigest = hex.EncodeToString(l.Sum(nil)) +} + +type Setting func(o *Option) error + +type Option struct { + ttl int64 // time to live in millisecond + limit int64 // maximum running limit +} + +func WithTTL(ttl time.Duration) Setting { + return func(o *Option) error { + o.ttl = timekit.DurationToMillis(ttl) + return nil + } +} + +func WithLimit(limit int64) Setting { + return func(o *Option) error { + o.limit = limit + return nil + } +} + +type ConcurrentRateLimiter struct { + runner redispattern.Runner + key string + option Option +} + +func New(runner redispattern.Runner, key string, settings ...Setting) (*ConcurrentRateLimiter, error) { + c := &ConcurrentRateLimiter{ + runner: runner, + key: key, + } + o := Option{ + ttl: timekit.DurationToMillis(3 * time.Second), + limit: 10, + } + for _, setting := range settings { + if err := setting(&o); err != nil { + return nil, err + } + } + c.option = o + return c, nil +} + +func (c *ConcurrentRateLimiter) Enter(random string) (bool, error) { + ok, err := c.runner.EvaSha1(enterScriptDigest, + c.key, + c.option.limit, + timekit.NowInMillis(), + random, + c.option.ttl, + ) + if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT") { + ok, err := c.runner.Eva(enterScript, + c.key, + c.option.limit, + timekit.NowInMillis(), + random, + c.option.ttl, + ) + if err != nil { + return false, err + } + return ok == 1, nil + } + if err != nil { + return false, err + } + return ok == 1, nil +} + +func (c *ConcurrentRateLimiter) Leave(random string) error { + _, err := c.runner.EvaSha1(leaveScriptDigest, c.key, random) + if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT") { + _, err := c.runner.Eva(leaveScript, c.key, random) + if err != nil { + return err + } + return nil + } + if err != nil { + return err + } + return nil +} diff --git a/redispattern/tokenbucket/token_bucket.go b/redispattern/tokenbucket/token_bucket.go new file mode 100644 index 0000000..492e014 --- /dev/null +++ b/redispattern/tokenbucket/token_bucket.go @@ -0,0 +1,121 @@ +package tokenbucket + +import ( + "crypto/sha1" + "encoding/hex" + "errors" + "io" + "strings" + "time" + + "github.com/go-redis/redis" + "github.com/xiaojiaoyu100/lizard/timekit" +) + +const script = ` +local key = KEYS[1] +local rate = tonumber(ARGV[1]) +local tokenNum = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local num = tonumber(ARGV[4]) +local expiration = ARGV[5] +local obj = { +tn=tokenNum, +ts=now +} + +local value = redis.call("get", key) +if value then + obj = cjson.decode(value) +end + +local incr = math.floor((now - obj.ts) / rate) +if incr > 0 then + obj.tn = math.min(obj.tn + incr, tokenNum) + obj.ts = obj.ts + incr * rate +end + +if obj.tn >= num then + obj.tn = obj.tn - num + obj.ts = string.format("%.f", obj.ts) + if redis.call("set", key, cjson.encode(obj), "EX", expiration) then + return 1 + end +end + +return 0 +` + +var scriptDigest string + +func init() { + s := sha1.New() + io.WriteString(s, script) + scriptDigest = hex.EncodeToString(s.Sum(nil)) +} + +// TokenBucket stands for a token bucket. +type TokenBucket struct { + client *redis.Client // redis client + Key string // redis key + TokenNum int64 // token bucket size + Rate time.Duration // the rate of putting token into bucket + Expiration int64 // redis key expiration in seconds +} + +// New returns an instance of TokenBucket +func New(client *redis.Client, key string, tokenNum int64, rate time.Duration, expiration int64) (*TokenBucket, error) { + h := sha1.New() + _, err := io.WriteString(h, script) + if err != nil { + return nil, err + } + + if timekit.DurationToMillis(rate) == 0 { + return nil, errors.New("wrong rate") + } + + return &TokenBucket{ + client: client, + Key: key, + TokenNum: tokenNum, + Rate: rate, + Expiration: expiration, + }, nil +} + +func (tb *TokenBucket) eva(script string, key string, argv ...interface{}) (int64, error) { + ret, err := tb.client.Eval(script, []string{key}, argv...).Result() + if err != nil { + return 0, err + } + return ret.(int64), nil +} + +func (tb *TokenBucket) evaSha1(sha1 string, key string, argv ...interface{}) (int64, error) { + ret, err := tb.client.EvalSha(sha1, []string{key}, argv...).Result() + if err != nil { + return 0, err + } + return ret.(int64), nil +} + +// Consume consumes the number of token in the token bucket. +func (tb *TokenBucket) Consume(num int64) (bool, error) { + if num > tb.TokenNum { + return false, errors.New("token is not enough") + } + ok, err := tb.evaSha1(scriptDigest, tb.Key, timekit.DurationToMillis(tb.Rate), tb.TokenNum, timekit.NowInMillis(), num, tb.Expiration) + // NOSCRIPT 这个error是稳定的 see https://redis.io/commands/eval + if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT") { + ok, err := tb.eva(script, tb.Key, timekit.DurationToMillis(tb.Rate), tb.TokenNum, timekit.NowInMillis(), num, tb.Expiration) + if err != nil { + return false, err + } + return ok == 1, nil + } + if err != nil { + return false, err + } + return ok == 1, nil +} diff --git a/timekit/timekit.go b/timekit/timekit.go new file mode 100644 index 0000000..e280bc3 --- /dev/null +++ b/timekit/timekit.go @@ -0,0 +1,23 @@ +package timekit + +import "time" + +// DurationToMillis converts duration to milliseconds. +func DurationToMillis(d time.Duration) int64 { + return int64(d / time.Millisecond) +} + +// NowInMillis returns timestamp in milliseconds. +func NowInMillis() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} + +// NowInSecs returns timestamp in seconds. +func NowInSecs() int64 { + return time.Now().Unix() +} + +// UTCNowTime returns current time in utc. +func UTCNowTime() time.Time { + return time.Now().UTC() +}