Skip to content

Commit

Permalink
support expect feature
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Dec 10, 2023
1 parent a44d060 commit 3b2d6c9
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 19 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,22 @@ _`~/` 代表 HOME 目录。在 Windows 中,请将下文的 `~/` 替换成 `C:\
#!! GroupLabels label4 group5
```

## 自动交互

- 支持类似 `expect` 的自动交互功能,可以在登录服务器之后,自动匹配服务器的输出,然后自动输入。

```
Host auto
#!! ExpectCount 2 # 配置自动交互的次数,默认是 0 即无自动交互
#!! ExpectTimeout 30 # 配置自动交互的超时时间(单位:秒),默认是 30 秒
#!! ExpectPattern1 *password # 配置第一个自动交互的匹配表达式
# 配置第一个自动输入(密文),填 tssh --enc-secret 编码后的字符串,会自动发送 \r 回车
#!! ExpectSendPass1 d7983b4a8ac204bd073ed04741913befd4fbf813ad405d7404cb7d779536f8b87e71106d7780b2
#!! ExpectPattern2 $ # 配置第二个自动交互的匹配表达式
#!! ExpectSendText2 echo tssh expect\r # 配置第二个自动输入(明文),需要指定 \r 才会发送回车
# 以上 ExpectSendPass? 和 ExpectSendText? 只要二选一即可,若都配置则 ExpectSendPass? 的优先级更高
```

## 记住密码

- 为了兼容标准 ssh ,密码可以单独配置在 `~/.ssh/password` 中,也可以在 `~/.ssh/config` 中加上 `#!!` 前缀。
Expand Down
218 changes: 218 additions & 0 deletions tssh/expect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
/*
MIT License
Copyright (c) 2023 Lonny Wong <[email protected]>
Copyright (c) 2023 [Contributors](https://github.com/trzsz/trzsz-ssh/graphs/contributors)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

package tssh

import (
"context"
"fmt"
"io"
"regexp"
"strconv"
"strings"
"time"
)

const kDefaultExpectTimeout = 30

func decodeExpectText(text string) string {
var buf strings.Builder
state := byte(0)
for _, c := range text {
if state == 0 {
if c == '\\' {
state = '\\'
continue
}
buf.WriteRune(c)
continue
}
state = 0
switch c {
case '\\':
buf.WriteRune('\\')
case 'r':
buf.WriteRune('\r')
default:
warning("token [\\%c] in [%s] is not supported yet", c, text)
buf.WriteRune('\\')
buf.WriteRune(c)
}
}
if state != 0 {
warning("[%s] ends with \\ is invalid", text)
buf.WriteRune('\\')
}
return buf.String()
}

type sshExpect struct {
outputChan chan []byte
outputBuffer strings.Builder
expectContext context.Context
}

func (e *sshExpect) wrapOutput(reader io.Reader, writer io.Writer) {
for {
buffer := make([]byte, 32*1024)
n, err := reader.Read(buffer)
if n > 0 {
buf := buffer[:n]
if e.expectContext.Err() != nil {
if err := writeAll(writer, buf); err != nil {
warning("expect wrap output write failed: %v", err)
}
break
}
e.outputChan <- buf
}
if err == io.EOF {
return
}
if err != nil {
warning("expect wrap output read failed: %v", err)
return
}
}
if _, err := io.Copy(writer, reader); err != nil && err != io.EOF {
warning("expect wrap output failed: %v", err)
}
}

func (e *sshExpect) waitForPattern(pattern string) error {
expr := strings.ReplaceAll(pattern, "*", ".*")
re, err := regexp.Compile(expr)
if err != nil {
warning("compile expect expr [%s] failed: %v", expr, err)
return err
}
e.outputBuffer.Reset()
for {
select {
case <-e.expectContext.Done():
warning("expect timeout")
return e.expectContext.Err()
case buf := <-e.outputChan:
output := string(buf)
debug("expect output: %s", strconv.QuoteToASCII(output))
e.outputBuffer.WriteString(output)
}
if re.MatchString(e.outputBuffer.String()) {
debug("expect match: %s", pattern)
return nil
}
}
}

func (e *sshExpect) execInteractions(args *sshArgs, writer io.Writer, expectCount uint32) {
for i := uint32(1); i <= expectCount; i++ {
pattern := getExOptionConfig(args, fmt.Sprintf("ExpectPattern%d", i))
debug("expect pattern %d: %s", i, pattern)
if pattern != "" {
if err := e.waitForPattern(pattern); err != nil {
return
}
}
if e.expectContext.Err() != nil {
return
}
var input string
pass := getExOptionConfig(args, fmt.Sprintf("ExpectSendPass%d", i))
if pass != "" {
secret, err := decodeSecret(pass)
if err != nil {
warning("decode secret [%s] failed: %v", pass, err)
return
}
debug("expect send %d: %s", i, strings.Repeat("*", len(secret)))
input = secret + "\r"
} else {
text := getExOptionConfig(args, fmt.Sprintf("ExpectSendText%d", i))
if text == "" {
continue
}
debug("expect send %d: %s", i, text)
input = decodeExpectText(text)
}
if err := writeAll(writer, []byte(input)); err != nil {
warning("expect send input failed: %v", err)
return
}
}
}

func getExpectCount(args *sshArgs) uint32 {
expectCount := getExOptionConfig(args, "ExpectCount")
if expectCount == "" {
return 0
}
count, err := strconv.ParseUint(expectCount, 10, 32)
if err != nil {
warning("Invalid ExpectCount [%s]: %v", expectCount, err)
return 0
}
return uint32(count)
}

func getExpectTimeout(args *sshArgs) uint32 {
expectCount := getExOptionConfig(args, "ExpectTimeout")
if expectCount == "" {
return kDefaultExpectTimeout
}
count, err := strconv.ParseUint(expectCount, 10, 32)
if err != nil {
warning("Invalid ExpectTimeout [%s]: %v", expectCount, err)
return kDefaultExpectTimeout
}
return uint32(count)
}

func execExpectInteractions(args *sshArgs, serverIn io.Writer,
serverOut io.Reader, serverErr io.Reader) (io.Reader, io.Reader) {
expectCount := getExpectCount(args)
if expectCount <= 0 {
return serverOut, serverErr
}

outReader, outWriter := io.Pipe()
errReader, errWriter := io.Pipe()

var ctx context.Context
var cancel context.CancelFunc
if expectTimeout := getExpectTimeout(args); expectTimeout > 0 {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(expectTimeout)*time.Second)
} else {
ctx, cancel = context.WithCancel(context.Background())
}
defer cancel()

expect := &sshExpect{outputChan: make(chan []byte, 1), expectContext: ctx}
go expect.wrapOutput(serverOut, outWriter)
go expect.wrapOutput(serverErr, errWriter)

expect.execInteractions(args, serverIn, expectCount)

return outReader, errReader
}
9 changes: 7 additions & 2 deletions tssh/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,8 @@ func sshAgentForward(args *sshArgs, client *ssh.Client, session *ssh.Session) {
debug("request ssh agent forwarding success")
}

func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session, serverIn io.WriteCloser, serverOut io.Reader, err error) {
func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session,
serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, err error) {
defer func() {
if err != nil {
if session != nil {
Expand Down Expand Up @@ -960,7 +961,11 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session
err = fmt.Errorf("stdout pipe failed: %v", err)
return
}
session.Stderr = os.Stderr
serverErr, err = session.StderrPipe()
if err != nil {
err = fmt.Errorf("stderr pipe failed: %v", err)
return
}

// ssh agent forward
if !control {
Expand Down
7 changes: 5 additions & 2 deletions tssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func sshStart(args *sshArgs) error {
}

// ssh login
client, session, serverIn, serverOut, err := sshLogin(args, tty)
client, session, serverIn, serverOut, serverErr, err := sshLogin(args, tty)
if err != nil {
return err
}
Expand Down Expand Up @@ -302,6 +302,9 @@ func sshStart(args *sshArgs) error {
}
}

// execute expect interactions if necessary
serverOut, serverErr = execExpectInteractions(args, serverIn, serverOut, serverErr)

// make stdin raw
if isTerminal && tty {
state, err := makeStdinRaw()
Expand All @@ -312,7 +315,7 @@ func sshStart(args *sshArgs) error {
}

// enable trzsz
if err := enableTrzsz(args, client, session, serverIn, serverOut, tty); err != nil {
if err := enableTrzsz(args, client, session, serverIn, serverOut, serverErr, tty); err != nil {
return err
}

Expand Down
1 change: 1 addition & 0 deletions tssh/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func expandTokens(str string, args *sshArgs, param *loginParam, tokens string) s
}
if state != 0 {
warning("[%s] ends with %% is invalid", str)
buf.WriteRune('%')
}
return buf.String()
}
2 changes: 1 addition & 1 deletion tssh/tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@ func TestExpandTokens(t *testing.T) {

assertControlPath("%j", "%j", "token [%j] in [%j] is not supported")
assertControlPath("p_%h_%d", "p_127.0.0.1_%d", "token [%d] in [p_%h_%d] is not supported yet")
assertControlPath("h%", "h", "[h%] ends with % is invalid")
assertControlPath("h%", "h%", "[h%] ends with % is invalid")
}
46 changes: 32 additions & 14 deletions tssh/trzsz.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,20 @@ import (
"golang.org/x/crypto/ssh"
)

func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) {
func writeAll(dst io.Writer, data []byte) error {
m := 0
l := len(data)
for m < l {
n, err := dst.Write(data[m:])
if err != nil {
return err
}
m += n
}
return nil
}

func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, tty bool) {
win := runtime.GOOS == "windows"
forwardIO := func(reader io.Reader, writer io.WriteCloser, oldVal, newVal []byte) {
defer writer.Close()
Expand All @@ -51,14 +64,9 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) {
if win && !tty {
buf = bytes.ReplaceAll(buf, oldVal, newVal)
}
w := 0
for w < len(buf) {
n, err := writer.Write(buf[w:])
if err != nil {
warning("wrap stdio write failed: %v", err)
return
}
w += n
if err := writeAll(writer, buf); err != nil {
warning("wrap stdio write failed: %v", err)
return
}
}
if err == io.EOF {
Expand All @@ -74,26 +82,36 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) {
}
}
}
go forwardIO(os.Stdin, serverIn, []byte("\r\n"), []byte("\n"))
go forwardIO(serverOut, os.Stdout, []byte("\n"), []byte("\r\n"))
if serverIn != nil {
go forwardIO(os.Stdin, serverIn, []byte("\r\n"), []byte("\n"))
}
if serverOut != nil {
go forwardIO(serverOut, os.Stdout, []byte("\n"), []byte("\r\n"))
}
if serverErr != nil {
go forwardIO(serverErr, os.Stderr, []byte("\n"), []byte("\r\n"))
}
}

func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session, serverIn io.WriteCloser, serverOut io.Reader, tty bool) error {
func enableTrzsz(args *sshArgs, client *ssh.Client, session *ssh.Session,
serverIn io.WriteCloser, serverOut io.Reader, serverErr io.Reader, tty bool) error {
// not terminal or not tty
if !isTerminal || !tty {
wrapStdIO(serverIn, serverOut, tty)
wrapStdIO(serverIn, serverOut, serverErr, tty)
return nil
}

// disable trzsz ( trz / tsz )
if strings.ToLower(getExOptionConfig(args, "EnableTrzsz")) == "no" {
wrapStdIO(serverIn, serverOut, tty)
wrapStdIO(serverIn, serverOut, serverErr, tty)
onTerminalResize(func(width, height int) { _ = session.WindowChange(height, width) })
return nil
}

// support trzsz ( trz / tsz )

wrapStdIO(nil, nil, serverErr, tty)

trzsz.SetAffectedByWindows(false)

if args.Relay || isNoGUI() {
Expand Down

0 comments on commit 3b2d6c9

Please sign in to comment.