diff --git a/middleware/jwt/jwt.go b/middleware/jwt/jwt.go index 0d09fc1..cd684de 100644 --- a/middleware/jwt/jwt.go +++ b/middleware/jwt/jwt.go @@ -1,11 +1,11 @@ package jwt import ( + "github.com/dgrijalva/jwt-go" "github.com/gin-gonic/gin" "k8s-webshell/pkg/e" "k8s-webshell/pkg/utils" "net/http" - "time" ) func JWT() gin.HandlerFunc { @@ -15,6 +15,7 @@ func JWT() gin.HandlerFunc { code = e.SUCCESS token := c.Query("token") + if token == "" { code = e.INVALID_PARAMS @@ -22,17 +23,25 @@ func JWT() gin.HandlerFunc { claims, err := utils.ParseToken(token) if err != nil { - code = e.ERROR_AUTH_CHECK_TOKEN_FAIL - } else if time.Now().Unix() > claims.ExpiresAt { - code = e.ERROR_AUTH_CHECK_TOKEN_TIMEOUT + + switch err.(*jwt.ValidationError).Errors { + + case jwt.ValidationErrorExpired: + code = e.ERROR_AUTH_CHECK_TOKEN_TIMEOUT + default: + code = e.ERROR_AUTH_CHECK_TOKEN_FAIL + + } + + } else { + c.Set("podNs", claims.PodNs) + c.Set("podName", claims.PodName) + c.Set("containerName", claims.ContainerName) + c.Set("paasUser", claims.PaasUser) + } - c.Set("podNs", claims.PodNs) - c.Set("podName", claims.PodName) - c.Set("containerName", claims.ContainerName) - c.Set("paasUser", claims.PaasUser) } - if code != e.SUCCESS { c.JSON(http.StatusUnauthorized, gin.H{ "code": code, diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 0a727e4..86d469f 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -1,8 +1,10 @@ package api import ( + "bytes" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" "k8s-webshell/pkg/common" "k8s-webshell/pkg/utils" "k8s-webshell/pkg/ws" @@ -12,6 +14,7 @@ import ( "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" "k8s.io/client-go/tools/remotecommand" + "regexp" ) var ( @@ -35,6 +38,8 @@ type streamHandler struct { podName *string podNs *string paasUser *string + logBuff *bytes.Buffer + moveCursor int } // web终端发来的包 @@ -52,8 +57,122 @@ func (handler *streamHandler) Next() (size *remotecommand.TerminalSize) { return } -// executor 回调读取web端的输入 +func (handler *streamHandler) RuneSliceDeleteStr() { + + defer func() { + if r := recover(); r != nil { + utils.Logger.Warn("Recovered in RuneSliceDeleteStr:", r) + } + }() + runeSlice := []rune(handler.logBuff.String()) + + if len(runeSlice) > handler.moveCursor { + deleteIndex := len(runeSlice) - handler.moveCursor + runeSlice = append(runeSlice[:deleteIndex-1], runeSlice[deleteIndex:]...) + handler.logBuff.Reset() + handler.logBuff.WriteString(string(runeSlice)) + } + + runeSlice = nil +} +func (handler *streamHandler) RuneSliceInsertStr(insertString *string) { + defer func() { + if r := recover(); r != nil { + utils.Logger.Warn("Recovered in RuneSliceInsertStr:", r) + } + }() + runeSlice := []rune(handler.logBuff.String()) + insertIndex := len(runeSlice) - handler.moveCursor + runeSlice = append(runeSlice[:insertIndex], append([]rune(*insertString), runeSlice[insertIndex:]...)...) + handler.logBuff.Reset() + handler.logBuff.WriteString(string(runeSlice)) + + runeSlice = nil + +} + +func (handler *streamHandler) RecordCommand(inputString *string) { + defer func() { + if r := recover(); r != nil { + utils.Logger.Warn("Recovered in RecordCommand:", r) + } + }() + var ( + n int + ) + + if len(*inputString) > 0 { + n = len(*inputString) - 1 + } + invalidChart, _ := regexp.MatchString(`\s?\[\d{1,100}\;9R`, *inputString) + leftMoveCursor, _ := regexp.MatchString(`\s?\[D`, *inputString) + rightMoveCursor, _ := regexp.MatchString(`\s?\[C`, *inputString) + switch { + case invalidChart: + utils.Logger.Info("enter >> :", []rune(*inputString)) + + case leftMoveCursor: + cmdLens := len([]rune(handler.logBuff.String())) + if cmdLens-handler.moveCursor != 0 { + handler.moveCursor += 1 + } + + case rightMoveCursor: + handler.moveCursor -= 1 + + case []byte(*inputString)[n] == 12: //12 FF (NP form feed, new page) + utils.Logger.WithFields(logrus.Fields{ + "PassUser": handler.paasUser, + "PodName": handler.podName, + "NameSpace": handler.podNs, + "command": "clear screen", + }).Info("record input") + + case []byte(*inputString)[n] == 13: // 13 CR (carriage return) + handler.moveCursor = 0 // cursor flag reset + if len(*inputString) > 1 { + handler.logBuff.WriteString(*inputString) + } + if len(handler.logBuff.String()) > 0 { + utils.Logger.WithFields(logrus.Fields{ + "PassUser": handler.paasUser, + "PodName": handler.podName, + "NameSpace": handler.podNs, + "command": handler.logBuff.String(), + }).Info("record input") + } + + handler.logBuff.Reset() + //case []byte(*inputString)[n] == 37: + // utils.Logger.Info("fangxiangjian", []byte(*inputString)[n]) + case []byte(*inputString)[n] == 127: // 127 DEL + + if len([]rune(handler.logBuff.String())) > 0 { + handler.RuneSliceDeleteStr() + } + + case []byte(*inputString)[n] == 3: + utils.Logger.WithFields(logrus.Fields{ + "PassUser": handler.paasUser, + "PodName": handler.podName, + "NameSpace": handler.podNs, + "command": "ctrl + c", + }).Info("record input") + handler.logBuff.Reset() + + default: + if handler.moveCursor != 0 { + handler.RuneSliceInsertStr(inputString) + } else { + handler.logBuff.WriteString(*inputString) + } + + } + +} + +// executor 回调读取web端的输入 func (handler *streamHandler) Read(p []byte) (size int, err error) { var ( msg *ws.WsMessage @@ -64,20 +183,21 @@ func (handler *streamHandler) Read(p []byte) (size int, err error) { if msg, err = handler.wsConn.WsRead(); err != nil { return } - // 解析客户端请求 + if err = json.Unmarshal(msg.Data, &xtermMsg); err != nil { return } - // web ssh 调整了终端大小 - if xtermMsg.MsgType == "resize" { + switch xtermMsg.MsgType { + case "resize": // 放到channel里, 等remotecommand executor 调用我们的Next取走 handler.resizeEvent <- remotecommand.TerminalSize{Width: xtermMsg.Cols, Height: xtermMsg.Rows} - } else if xtermMsg.MsgType == "input" { // web ssh 终端输入了字符 - // copy 到p数组中 + case "input": + // web ssh 终端输入了字符 size = len(xtermMsg.Input) + // copy 到p数组中 copy(p, xtermMsg.Input) - + handler.RecordCommand(&xtermMsg.Input) } return @@ -88,6 +208,7 @@ func (handler *streamHandler) Read(p []byte) (size int, err error) { func (handler *streamHandler) Write(p []byte) (size int, err error) { size = len(p) copy := append(make([]byte, 0, size), p...) // 解决 发送数据丢失的问题 + err = handler.wsConn.WsWrite(websocket.BinaryMessage, copy) return } @@ -116,7 +237,9 @@ func WsHandler(c *gin.Context) { utils.Logger.Info("up to ws error:", err) return } + //var logBuff bytes.Buffer + logBuff := bytes.NewBufferString("") // 获取k8s rest client 配置 if restConf, err = common.GetRestConf(); err != nil { utils.Logger.Info("get kubeconfig error ", err) @@ -142,7 +265,17 @@ func WsHandler(c *gin.Context) { } // 配置与容器之间的数据流处理回调 - handler = &streamHandler{wsConn: wsConn, resizeEvent: make(chan remotecommand.TerminalSize), podName: &podName, podNs: &podNs, paasUser: &paasUser} + + handler = &streamHandler{ + wsConn: wsConn, + resizeEvent: make(chan remotecommand.TerminalSize), + podName: &podName, + podNs: &podNs, + paasUser: &paasUser, + logBuff: logBuff, + moveCursor: 0, + } + utils.Logger.Infof("Start to exec command from pod:%s,", podName) if err = executor.Stream(remotecommand.StreamOptions{ Stdin: handler, Stdout: handler, @@ -152,6 +285,9 @@ func WsHandler(c *gin.Context) { }); err != nil { goto END } + + defer handler.logBuff.Reset() + return END: diff --git a/pkg/utils/jwt.go b/pkg/utils/jwt.go index 23afd44..51e62f6 100644 --- a/pkg/utils/jwt.go +++ b/pkg/utils/jwt.go @@ -21,7 +21,7 @@ type Claims struct { func GenerateToken(secretKey, paasUser, podNs, podName, containerName string) (string, error) { nowTime := time.Now() - expireTime := nowTime.Add(3 * time.Hour) + expireTime := nowTime.Add(20 * time.Minute) claims := Claims{ secretKey,