From c0cc87271d54ae2a9cd1c9e7fc390e7a75b14de9 Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sat, 16 Dec 2023 13:59:10 +0800 Subject: [PATCH] improve expect feature --- README.md | 13 ++- tssh/expect.go | 240 ++++++++++++++++++++++++++++++++++++++++--------- tssh/tokens.go | 7 +- 3 files changed, 212 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 9be09a1..d9a9e7c 100644 --- a/README.md +++ b/README.md @@ -229,16 +229,25 @@ _`~/` 代表 HOME 目录。在 Windows 中,请将下文的 `~/` 替换成 `C:\ ``` Host auto - #!! ExpectCount 2 # 配置自动交互的次数,默认是 0 即无自动交互 + #!! ExpectCount 3 # 配置自动交互的次数,默认是 0 即无自动交互 #!! ExpectTimeout 30 # 配置自动交互的超时时间(单位:秒),默认是 30 秒 #!! ExpectPattern1 *password # 配置第一个自动交互的匹配表达式 # 配置第一个自动输入(密文),填 tssh --enc-secret 编码后的字符串,会自动发送 \r 回车 #!! ExpectSendPass1 d7983b4a8ac204bd073ed04741913befd4fbf813ad405d7404cb7d779536f8b87e71106d7780b2 - #!! ExpectPattern2 $ # 配置第二个自动交互的匹配表达式 + #!! ExpectPattern2 hostname*$ # 配置第二个自动交互的匹配表达式 #!! ExpectSendText2 echo tssh expect\r # 配置第二个自动输入(明文),需要指定 \r 才会发送回车 # 以上 ExpectSendPass? 和 ExpectSendText? 只要二选一即可,若都配置则 ExpectSendPass? 的优先级更高 + # -------------------------------------------------- + # 在每个 ExpectPattern 匹配之前,可以配置一个或多个可选的匹配,用法如下: + #!! ExpectPattern3 hostname*$ # 配置第三个自动交互的匹配表达式 + #!! ExpectSendText3 ssh xxx\r # 配置第三个自动输入,也可以换成 ExpectSendPass3 然后配置密文 + #!! ExpectCaseSendText3 yes/no y\r # 在 ExpectPattern3 匹配之前,若遇到 yes/no 则发送 y 并回车 + #!! ExpectCaseSendText3 y/n yes\r # 在 ExpectPattern3 匹配之前,若遇到 y/n 则发送 yes 并回车 + #!! ExpectCaseSendPass3 token d7... # 在 ExpectPattern3 匹配之前,若遇到 token 则解码并发送 d7... ``` + 使用 `tssh --debug` 登录,可以看到 `expect` 捕获到的输出,以及其匹配结果和自动输入的交互。 + ## 记住密码 - 为了兼容标准 ssh ,密码可以单独配置在 `~/.ssh/password` 中,也可以在 `~/.ssh/config` 中加上 `#!!` 前缀。 diff --git a/tssh/expect.go b/tssh/expect.go index ca9f082..e6c886f 100644 --- a/tssh/expect.go +++ b/tssh/expect.go @@ -33,6 +33,7 @@ import ( "strconv" "strings" "time" + "unicode" ) const kDefaultExpectTimeout = 30 @@ -42,11 +43,12 @@ func decodeExpectText(text string) string { state := byte(0) for _, c := range text { if state == 0 { - if c == '\\' { + switch c { + case '\\': state = '\\' - continue + default: + buf.WriteRune(c) } - buf.WriteRune(c) continue } state = 0 @@ -68,88 +70,231 @@ func decodeExpectText(text string) string { return buf.String() } +func quoteExpectPattern(pattern string) string { + var buf strings.Builder + for _, c := range pattern { + switch c { + case '*': + buf.WriteString(".*") + case '?', '(', ')', '[', ']', '{', '}', '.', '+', ',', '-', '^', '$', '|', '\\': + buf.WriteRune('\\') + buf.WriteRune(c) + default: + buf.WriteRune(c) + } + } + return buf.String() +} + +type caseSend struct { + pattern string + display string + input []byte + re *regexp.Regexp + buffer strings.Builder +} + +type caseSendList struct { + writer io.Writer + list []*caseSend +} + +func (c *caseSendList) splitConfig(config string) (string, string, error) { + index := strings.IndexFunc(config, unicode.IsSpace) + if index <= 0 { + return "", "", fmt.Errorf("invalid expect case send: %s", config) + } + pattern := strings.TrimSpace(config[:index]) + send := strings.TrimSpace(config[index+1:]) + if pattern == "" || send == "" { + return "", "", fmt.Errorf("invalid expect case send: %s", config) + } + return pattern, send, nil +} + +func (c *caseSendList) addCase(re *regexp.Regexp, pattern, display, input string) { + c.list = append(c.list, &caseSend{ + pattern: pattern, + display: display, + input: []byte(input), + re: re, + }) +} + +func (c *caseSendList) addCaseSendPass(config string) error { + pattern, secret, err := c.splitConfig(config) + if err != nil { + return err + } + expr := quoteExpectPattern(pattern) + re, err := regexp.Compile(expr) + if err != nil { + return fmt.Errorf("compile expect expr [%s] failed: %v", expr, err) + } + pass, err := decodeSecret(secret) + if err != nil { + return fmt.Errorf("decode secret [%s] failed: %v", secret, err) + } + c.addCase(re, pattern, strings.Repeat("*", len(pass))+"\\r", pass+"\r") + return nil +} + +func (c *caseSendList) addCaseSendText(config string) error { + pattern, text, err := c.splitConfig(config) + if err != nil { + return err + } + expr := quoteExpectPattern(pattern) + re, err := regexp.Compile(expr) + if err != nil { + return fmt.Errorf("compile expect expr [%s] failed: %v", expr, err) + } + c.addCase(re, pattern, text, decodeExpectText(text)) + return nil +} + +func (c *caseSendList) handleOutput(output string) { + for _, cs := range c.list { + cs.buffer.WriteString(output) + if cs.re.MatchString(cs.buffer.String()) { + debug("expect case match: %s", cs.pattern) + debug("expect case send: %s", cs.display) + if err := writeAll(c.writer, cs.input); err != nil { + warning("expect send input failed: %v", err) + } + cs.buffer.Reset() + } else { + debug("expect case not match: %s", cs.pattern) + } + } +} + type sshExpect struct { - outputChan chan []byte - outputBuffer strings.Builder - expectContext context.Context + ctx context.Context + out chan []byte + err chan []byte } -func (e *sshExpect) wrapOutput(reader io.Reader, writer io.Writer) { - for { +func (e *sshExpect) captureOutput(reader io.Reader, ch chan<- []byte) ([]byte, error) { + defer close(ch) + for e.ctx.Err() == nil { buffer := make([]byte, 32*1024) n, err := reader.Read(buffer) if n > 0 { buf := buffer[:n] - if e.expectContext.Err() != nil { - if err := writeAll(writer, buf); err != nil { - warning("expect wrap output write failed: %v", err) - } - break + select { + case <-e.ctx.Done(): + return buf, nil + case ch <- buf: } - e.outputChan <- buf } if err == io.EOF { - return + return nil, err } if err != nil { - warning("expect wrap output read failed: %v", err) + warning("expect read output failed: %v", err) + return nil, err + } + } + return nil, nil +} + +func (e *sshExpect) wrapOutput(reader io.Reader, writer io.Writer, ch chan []byte) { + buf, err := e.captureOutput(reader, ch) + if err != nil { + return + } + for data := range ch { + if err := writeAll(writer, data); err != nil { + warning("expect write output failed: %v", err) + return + } + } + if buf != nil { + if err := writeAll(writer, buf); err != nil { + warning("expect write output failed: %v", err) return } } if _, err := io.Copy(writer, reader); err != nil && err != io.EOF { - warning("expect wrap output failed: %v", err) + warning("expect copy output failed: %v", err) } } -func (e *sshExpect) waitForPattern(pattern string) error { - expr := strings.ReplaceAll(pattern, "*", ".*") +func (e *sshExpect) waitForPattern(pattern string, caseSends *caseSendList) error { + expr := quoteExpectPattern(pattern) re, err := regexp.Compile(expr) if err != nil { warning("compile expect expr [%s] failed: %v", expr, err) return err } - e.outputBuffer.Reset() + var builder strings.Builder for { + var buf []byte select { - case <-e.expectContext.Done(): + case <-e.ctx.Done(): warning("expect timeout") - return e.expectContext.Err() - case buf := <-e.outputChan: - output := string(buf) - debug("expect output: %s", strconv.QuoteToASCII(output)) - e.outputBuffer.WriteString(output) + return e.ctx.Err() + case buf = <-e.out: + case buf = <-e.err: } - if re.MatchString(e.outputBuffer.String()) { + output := strconv.QuoteToASCII(string(buf)) + debug("expect output: %s", output) + caseSends.handleOutput(output[1 : len(output)-1]) + builder.WriteString(output[1 : len(output)-1]) + if re.MatchString(builder.String()) { debug("expect match: %s", pattern) - return nil + // cleanup for next expect + for { + select { + case buf = <-e.out: + case buf = <-e.err: + default: + return nil + } + debug("expect output: %s", strconv.QuoteToASCII(string(buf))) + } + } else { + debug("expect not match: %s", pattern) } } } -func (e *sshExpect) execInteractions(args *sshArgs, writer io.Writer, expectCount uint32) { +func (e *sshExpect) execInteractions(alias string, writer io.Writer, expectCount uint32) { for i := uint32(1); i <= expectCount; i++ { - pattern := getExOptionConfig(args, fmt.Sprintf("ExpectPattern%d", i)) + pattern := getExConfig(alias, fmt.Sprintf("ExpectPattern%d", i)) debug("expect pattern %d: %s", i, pattern) if pattern != "" { - if err := e.waitForPattern(pattern); err != nil { + caseSends := &caseSendList{writer: writer} + for _, cfg := range getAllExConfig(alias, fmt.Sprintf("ExpectCaseSendPass%d", i)) { + if err := caseSends.addCaseSendPass(cfg); err != nil { + warning("Invalid ExpectCaseSendPass%d: %v", i, err) + } + } + for _, cfg := range getAllExConfig(alias, fmt.Sprintf("ExpectCaseSendText%d", i)) { + if err := caseSends.addCaseSendText(cfg); err != nil { + warning("Invalid ExpectCaseSendText%d: %v", i, err) + } + } + if err := e.waitForPattern(pattern, caseSends); err != nil { return } } - if e.expectContext.Err() != nil { + if e.ctx.Err() != nil { return } var input string - pass := getExOptionConfig(args, fmt.Sprintf("ExpectSendPass%d", i)) - if pass != "" { - secret, err := decodeSecret(pass) + secret := getExConfig(alias, fmt.Sprintf("ExpectSendPass%d", i)) + if secret != "" { + pass, err := decodeSecret(secret) if err != nil { - warning("decode secret [%s] failed: %v", pass, err) + warning("decode secret [%s] failed: %v", secret, err) return } - debug("expect send %d: %s", i, strings.Repeat("*", len(secret))) - input = secret + "\r" + debug("expect send %d: %s\\r", i, strings.Repeat("*", len(pass))) + input = pass + "\r" } else { - text := getExOptionConfig(args, fmt.Sprintf("ExpectSendText%d", i)) + text := getExConfig(alias, fmt.Sprintf("ExpectSendText%d", i)) if text == "" { continue } @@ -208,11 +353,20 @@ func execExpectInteractions(args *sshArgs, serverIn io.Writer, } defer cancel() - expect := &sshExpect{outputChan: make(chan []byte, 1), expectContext: ctx} - go expect.wrapOutput(serverOut, outWriter) - go expect.wrapOutput(serverErr, errWriter) + expect := &sshExpect{ + ctx: ctx, + out: make(chan []byte, 10), + err: make(chan []byte, 10), + } + go expect.wrapOutput(serverOut, outWriter, expect.out) + go expect.wrapOutput(serverErr, errWriter, expect.err) + + expect.execInteractions(args.Destination, serverIn, expectCount) - expect.execInteractions(args, serverIn, expectCount) + if ctx.Err() == context.DeadlineExceeded { + // enter for shell prompt if timeout + _, _ = serverIn.Write([]byte("\r")) + } return outReader, errReader } diff --git a/tssh/tokens.go b/tssh/tokens.go index 8016814..682876a 100644 --- a/tssh/tokens.go +++ b/tssh/tokens.go @@ -46,11 +46,12 @@ func expandTokens(str string, args *sshArgs, param *loginParam, tokens string) s state := byte(0) for _, c := range str { if state == 0 { - if c == '%' { + switch c { + case '%': state = '%' - continue + default: + buf.WriteRune(c) } - buf.WriteRune(c) continue } state = 0