From 0e07ad70c32c0fec1cd06d80769f0e9c7b68a95b Mon Sep 17 00:00:00 2001 From: icey-yu <119291641+icey-yu@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:25:28 +0800 Subject: [PATCH] fix: admin token limit (#2871) --- internal/rpc/auth/auth.go | 10 +++- pkg/common/storage/controller/auth.go | 81 +++++++++++++++++---------- pkg/common/storage/controller/msg.go | 2 +- 3 files changed, 59 insertions(+), 34 deletions(-) diff --git a/internal/rpc/auth/auth.go b/internal/rpc/auth/auth.go index 62df74d214..a1acfd9313 100644 --- a/internal/rpc/auth/auth.go +++ b/internal/rpc/auth/auth.go @@ -16,6 +16,7 @@ package auth import ( "context" + "errors" "github.com/openimsdk/open-im-server/v3/pkg/common/config" redis2 "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" @@ -66,6 +67,7 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg config.Share.Secret, config.RpcConfig.TokenPolicy.Expire, config.Share.MultiLogin, + config.Share.IMAdminUserID, ), config: config, }) @@ -129,6 +131,10 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim if err != nil { return nil, errs.Wrap(err) } + isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID) + if isAdmin { + return claims, nil + } m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID) if err != nil { return nil, err @@ -190,7 +196,7 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID } m, err := s.authDatabase.GetTokensWithoutError(ctx, userID, int(platformID)) - if err != nil && err != redis.Nil { + if err != nil && errors.Is(err, redis.Nil) { return err } for k := range m { @@ -208,7 +214,7 @@ func (s *authServer) forceKickOff(ctx context.Context, userID string, platformID func (s *authServer) InvalidateToken(ctx context.Context, req *pbauth.InvalidateTokenReq) (*pbauth.InvalidateTokenResp, error) { m, err := s.authDatabase.GetTokensWithoutError(ctx, req.UserID, int(req.PlatformID)) - if err != nil && err != redis.Nil { + if err != nil && errors.Is(err, redis.Nil) { return nil, err } if m == nil { diff --git a/pkg/common/storage/controller/auth.go b/pkg/common/storage/controller/auth.go index 95c479a8d9..5f2e4840cb 100644 --- a/pkg/common/storage/controller/auth.go +++ b/pkg/common/storage/controller/auth.go @@ -34,14 +34,26 @@ type authDatabase struct { accessSecret string accessExpire int64 multiLogin multiLoginConfig + adminUserIDs []string } -func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin) AuthDatabase { +func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin, adminUserIDs []string) AuthDatabase { return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{ Policy: multiLogin.Policy, MaxNumOneEnd: multiLogin.MaxNumOneEnd, - }, - adminUserIDs: adminUserIDs, + CustomizeLoginNum: map[int]int{ + constant.IOSPlatformID: multiLogin.CustomizeLoginNum.IOS, + constant.AndroidPlatformID: multiLogin.CustomizeLoginNum.Android, + constant.WindowsPlatformID: multiLogin.CustomizeLoginNum.Windows, + constant.OSXPlatformID: multiLogin.CustomizeLoginNum.OSX, + constant.WebPlatformID: multiLogin.CustomizeLoginNum.Web, + constant.MiniWebPlatformID: multiLogin.CustomizeLoginNum.MiniWeb, + constant.LinuxPlatformID: multiLogin.CustomizeLoginNum.Linux, + constant.AndroidPadPlatformID: multiLogin.CustomizeLoginNum.APad, + constant.IPadPlatformID: multiLogin.CustomizeLoginNum.IPad, + constant.AdminPlatformID: multiLogin.CustomizeLoginNum.Admin, + }, + }, adminUserIDs: adminUserIDs, } } @@ -79,27 +91,31 @@ func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []st // Create Token. func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { - tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID) - if err != nil { - return "", err - } - deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) - if err != nil { - return "", err - } - if len(deleteTokenKey) != 0 { - err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) + isAdmin := authverify.IsManagerUserID(userID, a.adminUserIDs) + if !isAdmin { + tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID) if err != nil { return "", err } - } - if len(kickedTokenKey) != 0 { - for _, k := range kickedTokenKey { - err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) + + deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID) + if err != nil { + return "", err + } + if len(deleteTokenKey) != 0 { + err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey) if err != nil { return "", err } - log.ZDebug(ctx, "kicked token in create token", "token", k) + } + if len(kickedTokenKey) != 0 { + for _, k := range kickedTokenKey { + err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken) + if err != nil { + return "", err + } + log.ZDebug(ctx, "kicked token in create token", "token", k) + } } } @@ -110,9 +126,12 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI return "", errs.WrapMsg(err, "token.SignedString") } - if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil { - return "", err + if !isAdmin { + if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil { + return "", err + } } + return tokenString, nil } @@ -215,16 +234,16 @@ func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string return nil, nil, errs.New("unknown multiLogin policy").Wrap() } - var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd - if a.multiLogin.Policy == constant.Customize { - adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID] - } - l := len(adminToken) - if platformID == constant.AdminPlatformID { - l++ - } - if l > adminTokenMaxNum { - kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) - } + //var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd + //if a.multiLogin.Policy == constant.Customize { + // adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID] + //} + //l := len(adminToken) + //if platformID == constant.AdminPlatformID { + // l++ + //} + //if l > adminTokenMaxNum { + // kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...) + //} return deleteToken, kickToken, nil } diff --git a/pkg/common/storage/controller/msg.go b/pkg/common/storage/controller/msg.go index 598026eee1..789adb1f66 100644 --- a/pkg/common/storage/controller/msg.go +++ b/pkg/common/storage/controller/msg.go @@ -490,7 +490,7 @@ func (db *commonMsgDatabase) GetMsgBySeqs(ctx context.Context, userID string, co } successMsgs, failedSeqs, err := db.msg.GetMessagesBySeq(ctx, conversationID, newSeqs) if err != nil { - if err != redis.Nil { + if errors.Is(err, redis.Nil) { log.ZError(ctx, "get message from redis exception", err, "failedSeqs", failedSeqs, "conversationID", conversationID) } }