From f5a99cb458fc7e119a9212dd3e07a0321283d9fb Mon Sep 17 00:00:00 2001 From: icey-yu <1186114839@qq.com> Date: Wed, 12 Jun 2024 18:29:18 +0800 Subject: [PATCH] fix:create token can set expire time --- internal/rpc/admin/token.go | 4 ++-- pkg/common/db/cache/token.go | 21 ++++++++++++++++++++- pkg/common/db/database/admin.go | 17 ++++++++++++++--- pkg/common/tokenverify/token_verify.go | 8 ++++---- 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/internal/rpc/admin/token.go b/internal/rpc/admin/token.go index 41240738..8ac7c08c 100644 --- a/internal/rpc/admin/token.go +++ b/internal/rpc/admin/token.go @@ -21,11 +21,11 @@ import ( ) func (o *adminServer) CreateToken(ctx context.Context, req *admin.CreateTokenReq) (*admin.CreateTokenResp, error) { - token, err := o.Token.CreateToken(req.UserID, req.UserType) + token, expire, err := o.Token.CreateToken(req.UserID, req.UserType) if err != nil { return nil, err } - err = o.Database.CacheToken(ctx, req.UserID, token) + err = o.Database.CacheToken(ctx, req.UserID, token, expire) if err != nil { return nil, err } diff --git a/pkg/common/db/cache/token.go b/pkg/common/db/cache/token.go index a8c09b45..e9b4c332 100644 --- a/pkg/common/db/cache/token.go +++ b/pkg/common/db/cache/token.go @@ -17,6 +17,7 @@ package cache import ( "context" "github.com/openimsdk/tools/utils/stringutil" + "time" "github.com/openimsdk/tools/errs" "github.com/redis/go-redis/v9" @@ -28,11 +29,13 @@ const ( type TokenInterface interface { AddTokenFlag(ctx context.Context, userID string, token string, flag int) error + AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) } type TokenCacheRedis struct { - rdb redis.UniversalClient + rdb redis.UniversalClient + accessExpire int64 } func NewTokenInterface(rdb redis.UniversalClient) *TokenCacheRedis { @@ -44,6 +47,22 @@ func (t *TokenCacheRedis) AddTokenFlag(ctx context.Context, userID string, token return errs.Wrap(t.rdb.HSet(ctx, key, token, flag).Err()) } +func (t *TokenCacheRedis) AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) { + key := chatToken + userID + isSet, err := t.rdb.HSetNX(ctx, key, token, flag).Result() + if err != nil { + return false, errs.Wrap(err) + } + if !isSet { + // key already exists + return false, nil + } + if err = t.rdb.Expire(ctx, key, expire).Err(); err != nil { + return false, errs.Wrap(err) + } + return isSet, nil +} + func (t *TokenCacheRedis) GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) { key := chatToken + userID m, err := t.rdb.HGetAll(ctx, key).Result() diff --git a/pkg/common/db/database/admin.go b/pkg/common/db/database/admin.go index 7920555e..c8c3e854 100644 --- a/pkg/common/db/database/admin.go +++ b/pkg/common/db/database/admin.go @@ -16,6 +16,7 @@ package database import ( "context" + "time" "github.com/openimsdk/chat/pkg/common/db/cache" "github.com/openimsdk/protocol/constant" @@ -74,7 +75,7 @@ type AdminDatabaseInterface interface { DelUserLimitLogin(ctx context.Context, ms []*admindb.LimitUserLoginIP) error CountLimitUserLoginIP(ctx context.Context, userID string) (uint32, error) GetLimitUserLoginIP(ctx context.Context, userID string, ip string) (*admindb.LimitUserLoginIP, error) - CacheToken(ctx context.Context, userID string, token string) error + CacheToken(ctx context.Context, userID string, token string, expire time.Duration) error GetTokens(ctx context.Context, userID string) (map[string]int32, error) } @@ -324,8 +325,18 @@ func (o *AdminDatabase) GetLimitUserLoginIP(ctx context.Context, userID string, return o.limitUserLoginIP.Take(ctx, userID, ip) } -func (o *AdminDatabase) CacheToken(ctx context.Context, userID string, token string) error { - return o.cache.AddTokenFlag(ctx, userID, token, constant.NormalToken) +func (o *AdminDatabase) CacheToken(ctx context.Context, userID string, token string, expire time.Duration) error { + isSet, err := o.cache.AddTokenFlagNXEx(ctx, userID, token, constant.NormalToken, expire) + if err != nil { + return err + } + if !isSet { + // already exists, update + if err = o.cache.AddTokenFlag(ctx, userID, token, constant.NormalToken); err != nil { + return err + } + } + return nil } func (o *AdminDatabase) GetTokens(ctx context.Context, userID string) (map[string]int32, error) { diff --git a/pkg/common/tokenverify/token_verify.go b/pkg/common/tokenverify/token_verify.go index 9fdfbac5..3c35ffa7 100644 --- a/pkg/common/tokenverify/token_verify.go +++ b/pkg/common/tokenverify/token_verify.go @@ -86,16 +86,16 @@ func (t *Token) getToken(str string) (string, int32, error) { } } -func (t *Token) CreateToken(UserID string, userType int32) (string, error) { +func (t *Token) CreateToken(UserID string, userType int32) (string, time.Duration, error) { if !(userType == TokenUser || userType == TokenAdmin) { - return "", errs.ErrTokenUnknown.WrapMsg("token type unknown") + return "", 0, errs.ErrTokenUnknown.WrapMsg("token type unknown") } token := jwt.NewWithClaims(jwt.SigningMethodHS256, t.buildClaims(UserID, userType)) str, err := token.SignedString([]byte(t.Secret)) if err != nil { - return "", errs.Wrap(err) + return "", 0, errs.Wrap(err) } - return str, nil + return str, t.Expires, nil } func (t *Token) GetToken(token string) (string, int32, error) {