diff --git a/README.md b/README.md index 1ab6fc5..d9f5d31 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,22 @@ _`~/` 代表 HOME 目录。在 Windows 中,请将下文的 `~/` 替换成 `C:\ #!! GroupLabels label4 group5 ``` +## 自动交互 + +- 支持类似 `expect` 的自动交互功能,可以在登录服务器之后,自动匹配服务器的输出,然后自动输入。 + + ``` + Host auto + #!! ExpectCount 2 # 配置自动交互的次数,默认是 0 即无自动交互 + #!! ExpectTimeout 30 # 配置自动交互的超时时间(单位:秒),默认是 30 秒 + #!! ExpectPattern1 *password # 配置第一个自动交互的匹配表达式 + # 配置第一个自动输入(密文),填 tssh --enc-secret 编码后的字符串,会自动发送 \r 回车 + #!! ExpectSendPass1 d7983b4a8ac204bd073ed04741913befd4fbf813ad405d7404cb7d779536f8b87e71106d7780b2 + #!! ExpectPattern2 $ # 配置第二个自动交互的匹配表达式 + #!! ExpectSendText2 echo tssh expect\r # 配置第二个自动输入(明文),需要指定 \r 才会发送回车 + # 以上 ExpectSendPass? 和 ExpectSendText? 只要二选一即可,若都配置则 ExpectSendPass? 的优先级更高 + ``` + ## 记住密码 - 为了兼容标准 ssh ,密码可以单独配置在 `~/.ssh/password` 中,也可以在 `~/.ssh/config` 中加上 `#!!` 前缀。 diff --git a/tssh/expect.go b/tssh/expect.go new file mode 100644 index 0000000..ca9f082 --- /dev/null +++ b/tssh/expect.go @@ -0,0 +1,218 @@ +/* +MIT License + +Copyright (c) 2023 Lonny Wong +Copyright (c) 2023 [Contributors](https://github.com/trzsz/trzsz-ssh/graphs/contributors) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +package tssh + +import ( + "context" + "fmt" + "io" + "regexp" + "strconv" + "strings" + "time" +) + +const kDefaultExpectTimeout = 30 + +func decodeExpectText(text string) string { + var buf strings.Builder + state := byte(0) + for _, c := range text { + if state == 0 { + if c == '\\' { + state = '\\' + continue + } + buf.WriteRune(c) + continue + } + state = 0 + switch c { + case '\\': + buf.WriteRune('\\') + case 'r': + buf.WriteRune('\r') + default: + warning("token [\\%c] in [%s] is not supported yet", c, text) + buf.WriteRune('\\') + buf.WriteRune(c) + } + } + if state != 0 { + warning("[%s] ends with \\ is invalid", text) + buf.WriteRune('\\') + } + return buf.String() +} + +type sshExpect struct { + outputChan chan []byte + outputBuffer strings.Builder + expectContext context.Context +} + +func (e *sshExpect) wrapOutput(reader io.Reader, writer io.Writer) { + for { + buffer := make([]byte, 32*1024) + n, err := reader.Read(buffer) + if n > 0 { + buf := buffer[:n] + if e.expectContext.Err() != nil { + if err := writeAll(writer, buf); err != nil { + warning("expect wrap output write failed: %v", err) + } + break + } + e.outputChan <- buf + } + if err == io.EOF { + return + } + if err != nil { + warning("expect wrap output read failed: %v", err) + return + } + } + if _, err := io.Copy(writer, reader); err != nil && err != io.EOF { + warning("expect wrap output failed: %v", err) + } +} + +func (e *sshExpect) waitForPattern(pattern string) error { + expr := strings.ReplaceAll(pattern, "*", ".*") + re, err := regexp.Compile(expr) + if err != nil { + warning("compile expect expr [%s] failed: %v", expr, err) + return err + } + e.outputBuffer.Reset() + for { + select { + case <-e.expectContext.Done(): + warning("expect timeout") + return e.expectContext.Err() + case buf := <-e.outputChan: + output := string(buf) + debug("expect output: %s", strconv.QuoteToASCII(output)) + e.outputBuffer.WriteString(output) + } + if re.MatchString(e.outputBuffer.String()) { + debug("expect match: %s", pattern) + return nil + } + } +} + +func (e *sshExpect) execInteractions(args *sshArgs, writer io.Writer, expectCount uint32) { + for i := uint32(1); i <= expectCount; i++ { + pattern := getExOptionConfig(args, fmt.Sprintf("ExpectPattern%d", i)) + debug("expect pattern %d: %s", i, pattern) + if pattern != "" { + if err := e.waitForPattern(pattern); err != nil { + return + } + } + if e.expectContext.Err() != nil { + return + } + var input string + pass := getExOptionConfig(args, fmt.Sprintf("ExpectSendPass%d", i)) + if pass != "" { + secret, err := decodeSecret(pass) + if err != nil { + warning("decode secret [%s] failed: %v", pass, err) + return + } + debug("expect send %d: %s", i, strings.Repeat("*", len(secret))) + input = secret + "\r" + } else { + text := getExOptionConfig(args, fmt.Sprintf("ExpectSendText%d", i)) + if text == "" { + continue + } + debug("expect send %d: %s", i, text) + input = decodeExpectText(text) + } + if err := writeAll(writer, []byte(input)); err != nil { + warning("expect send input failed: %v", err) + return + } + } +} + +func getExpectCount(args *sshArgs) uint32 { + expectCount := getExOptionConfig(args, "ExpectCount") + if expectCount == "" { + return 0 + } + count, err := strconv.ParseUint(expectCount, 10, 32) + if err != nil { + warning("Invalid ExpectCount [%s]: %v", expectCount, err) + return 0 + } + return uint32(count) +} + +func getExpectTimeout(args *sshArgs) uint32 { + expectCount := getExOptionConfig(args, "ExpectTimeout") + if expectCount == "" { + return kDefaultExpectTimeout + } + count, err := strconv.ParseUint(expectCount, 10, 32) + if err != nil { + warning("Invalid ExpectTimeout [%s]: %v", expectCount, err) + return kDefaultExpectTimeout + } + return uint32(count) +} + +func execExpectInteractions(args *sshArgs, serverIn io.Writer, + serverOut io.Reader, serverErr io.Reader) (io.Reader, io.Reader) { + expectCount := getExpectCount(args) + if expectCount <= 0 { + return serverOut, serverErr + } + + outReader, outWriter := io.Pipe() + errReader, errWriter := io.Pipe() + + var ctx context.Context + var cancel context.CancelFunc + if expectTimeout := getExpectTimeout(args); expectTimeout > 0 { + ctx, cancel = context.WithTimeout(context.Background(), time.Duration(expectTimeout)*time.Second) + } else { + ctx, cancel = context.WithCancel(context.Background()) + } + defer cancel() + + expect := &sshExpect{outputChan: make(chan []byte, 1), expectContext: ctx} + go expect.wrapOutput(serverOut, outWriter) + go expect.wrapOutput(serverErr, errWriter) + + expect.execInteractions(args, serverIn, expectCount) + + return outReader, errReader +} diff --git a/tssh/login.go b/tssh/login.go index 1fd1c70..84e631f 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -892,7 +892,8 @@ func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) { debug("request ssh agent forwarding success") } -func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session, serverIn io.WriteCloser, serverOut io.Reader, err error) { +func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session, + serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, err error) { defer func() { if err != nil { if session != nil { @@ -960,7 +961,11 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session err = fmt.Errorf("stdout pipe failed: %v", err) return } - session.Stderr = os.Stderr + serverErr, err = session.StderrPipe() + if err != nil { + err = fmt.Errorf("stderr pipe failed: %v", err) + return + } // ssh agent forward if !control { diff --git a/tssh/main.go b/tssh/main.go index 55c296e..1c4a722 100644 --- a/tssh/main.go +++ b/tssh/main.go @@ -260,7 +260,7 @@ func sshStart(args *sshArgs) error { } // ssh login - client, session, serverIn, serverOut, err := sshLogin(args, tty) + client, session, serverIn, serverOut, serverErr, err := sshLogin(args, tty) if err != nil { return err } @@ -302,6 +302,9 @@ func sshStart(args *sshArgs) error { } } + // execute expect interactions if necessary + serverOut, serverErr = execExpectInteractions(args, serverIn, serverOut, serverErr) + // make stdin raw if isTerminal && tty { state, err := makeStdinRaw() @@ -312,7 +315,7 @@ func sshStart(args *sshArgs) error { } // enable trzsz - if err := enableTrzsz(args, client, session, serverIn, serverOut, tty); err != nil { + if err := enableTrzsz(args, client, session, serverIn, serverOut, serverErr, tty); err != nil { return err } diff --git a/tssh/tokens.go b/tssh/tokens.go index 37d3401..8016814 100644 --- a/tssh/tokens.go +++ b/tssh/tokens.go @@ -90,6 +90,7 @@ func expandTokens(str string, args *sshArgs, param *loginParam, tokens string) s } if state != 0 { warning("[%s] ends with %% is invalid", str) + buf.WriteRune('%') } return buf.String() } diff --git a/tssh/tokens_test.go b/tssh/tokens_test.go index 1432511..8774188 100644 --- a/tssh/tokens_test.go +++ b/tssh/tokens_test.go @@ -89,5 +89,5 @@ func TestExpandTokens(t *testing.T) { assertControlPath("%j", "%j", "token [%j] in [%j] is not supported") assertControlPath("p_%h_%d", "p_127.0.0.1_%d", "token [%d] in [p_%h_%d] is not supported yet") - assertControlPath("h%", "h", "[h%] ends with % is invalid") + assertControlPath("h%", "h%", "[h%] ends with % is invalid") } diff --git a/tssh/trzsz.go b/tssh/trzsz.go index 08ec147..c5014d8 100644 --- a/tssh/trzsz.go +++ b/tssh/trzsz.go @@ -39,7 +39,20 @@ import ( "golang.org/x/crypto/ssh" ) -func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) { +func writeAll(dst io.Writer, data []byte) error { + m := 0 + l := len(data) + for m < l { + n, err := dst.Write(data[m:]) + if err != nil { + return err + } + m += n + } + return nil +} + +func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, tty bool) { win := runtime.GOOS == "windows" forwardIO := func(reader io.Reader, writer io.WriteCloser, oldVal, newVal []byte) { defer writer.Close() @@ -51,14 +64,9 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) { if win && !tty { buf = bytes.ReplaceAll(buf, oldVal, newVal) } - w := 0 - for w < len(buf) { - n, err := writer.Write(buf[w:]) - if err != nil { - warning("wrap stdio write failed: %v", err) - return - } - w += n + if err := writeAll(writer, buf); err != nil { + warning("wrap stdio write failed: %v", err) + return } } if err == io.EOF { @@ -74,26 +82,36 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) { } } } - go forwardIO(os.Stdin, serverIn, []byte("\r\n"), []byte("\n")) - go forwardIO(serverOut, os.Stdout, []byte("\n"), []byte("\r\n")) + if serverIn != nil { + go forwardIO(os.Stdin, serverIn, []byte("\r\n"), []byte("\n")) + } + if serverOut != nil { + go forwardIO(serverOut, os.Stdout, []byte("\n"), []byte("\r\n")) + } + if serverErr != nil { + go forwardIO(serverErr, os.Stderr, []byte("\n"), []byte("\r\n")) + } } -func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, serverIn io.WriteCloser, serverOut io.Reader, tty bool) error { +func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, + serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, tty bool) error { // not terminal or not tty if !isTerminal || !tty { - wrapStdIO(serverIn, serverOut, tty) + wrapStdIO(serverIn, serverOut, serverErr, tty) return nil } // disable trzsz ( trz / tsz ) if strings.ToLower(getExOptionConfig(args, "EnableTrzsz")) == "no" { - wrapStdIO(serverIn, serverOut, tty) + wrapStdIO(serverIn, serverOut, serverErr, tty) onTerminalResize(func(width, height int) { _ = session.WindowChange(height, width) }) return nil } // support trzsz ( trz / tsz ) + wrapStdIO(nil, nil, serverErr, tty) + trzsz.SetAffectedByWindows(false) if args.Relay || isNoGUI() {