From 23fc3106d0208c75a50e44b42270704e4a319e39 Mon Sep 17 00:00:00 2001 From: Lonny Wong <lonnywong@qq.com> Date: Sat, 16 Sep 2023 21:28:37 +0800 Subject: [PATCH] support ctrl+c to interrupt login --- tssh/forward.go | 2 +- tssh/login.go | 84 +++++++++++--------------------------------- tssh/main.go | 74 +++++++++++++++++++++----------------- tssh/prompt.go | 3 ++ tssh/term_unix.go | 29 +++++++-------- tssh/term_windows.go | 71 +++++++++++++++++++------------------ 6 files changed, 117 insertions(+), 146 deletions(-) diff --git a/tssh/forward.go b/tssh/forward.go index ddf8304..27563b4 100644 --- a/tssh/forward.go +++ b/tssh/forward.go @@ -366,7 +366,7 @@ func remoteForward(client *ssh.Client, f *forwardCfg, args *sshArgs) { fmt.Fprintf(os.Stderr, "remote forward accept failed: %v\r\n", err) continue } - local, err := net.DialTimeout("tcp", localAddr, 3*time.Second) + local, err := net.DialTimeout("tcp", localAddr, 10*time.Second) if err != nil { fmt.Fprintf(os.Stderr, "remote forward dial [%s] failed: %v\r\n", localAddr, err) remote.Close() diff --git a/tssh/login.go b/tssh/login.go index 7809e94..d56fce5 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -26,6 +26,7 @@ SOFTWARE. package tssh import ( + "bufio" "bytes" "crypto/x509" "encoding/hex" @@ -34,7 +35,6 @@ import ( "net" "os" "os/exec" - "os/signal" "os/user" "path/filepath" "runtime" @@ -133,7 +133,11 @@ func getLoginParam(args *sshArgs) (*loginParam, error) { if err != nil { return nil, fmt.Errorf("get current user failed: %v", err) } - param.user = currentUser.Username + userName = currentUser.Username + if idx := strings.LastIndexByte(userName, '\\'); idx >= 0 { + userName = userName[idx+1:] + } + param.user = userName } } @@ -187,37 +191,6 @@ func createKnownHosts(path string) error { return nil } -func readLineFromRawIO(stdin *os.File) (string, error) { - defer fmt.Fprintf(os.Stderr, "\r\n") - buffer := new(bytes.Buffer) - buf := make([]byte, 100) - for { - n, err := stdin.Read(buf) - if err != nil { - return "", nil - } - data := buf[:n] - if bytes.ContainsRune(data, '\x03') { - fmt.Fprintf(os.Stderr, "^C") - return "", fmt.Errorf("interrupt") - } - for _, b := range data { - if b == '\r' || b == '\n' { - return string(bytes.TrimSpace(buffer.Bytes())), nil - } - if b == '\x7f' { - if buffer.Len() > 0 { - buffer.Truncate(buffer.Len() - 1) - fmt.Fprintf(os.Stderr, "\b \b") - } - } else if b >= ' ' && b <= '~' { - buffer.WriteByte(b) - fmt.Fprintf(os.Stderr, "%s", string(b)) - } - } - } -} - func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error { fingerprint := ssh.FingerprintSHA256(key) fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+ @@ -229,12 +202,14 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error { } defer closer() + reader := bufio.NewReader(stdin) fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ") for { - input, err := readLineFromRawIO(stdin) + input, err := reader.ReadString('\n') if err != nil { return err } + input = strings.TrimSpace(input) if input == fingerprint { break } @@ -410,35 +385,14 @@ func getSigner(dest string, path string) (*sshSigner, error) { func readSecret(prompt string) (secret []byte, err error) { fmt.Fprintf(os.Stderr, "%s", prompt) defer fmt.Fprintf(os.Stderr, "\r\n") - errch := make(chan error, 1) - defer close(errch) - sigch := make(chan os.Signal, 1) - signal.Notify(sigch, os.Interrupt) - go func() { - for range sigch { - errch <- fmt.Errorf("interrupt") - } - }() - defer func() { signal.Stop(sigch); close(sigch) }() + stdin, closer, err := getKeyboardInput() + if err != nil { + return nil, err + } + defer closer() - go func() { - stdin, closer, err := getKeyboardInput() - if err != nil { - errch <- err - return - } - defer closer() - pw, err := term.ReadPassword(int(stdin.Fd())) - if err != nil { - errch <- err - return - } - secret = pw - errch <- nil - }() - err = <-errch - return + return term.ReadPassword(int(stdin.Fd())) } func getPasswordAuthMethod(args *sshArgs, host, user string) ssh.AuthMethod { @@ -710,7 +664,7 @@ func dialWithTimeout(client *ssh.Client, network, addr string) (conn net.Conn, e done <- struct{}{} }() select { - case <-time.After(3 * time.Second): + case <-time.After(10 * time.Second): err = fmt.Errorf("dial [%s] timeout", addr) case <-done: } @@ -755,7 +709,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e config := &ssh.ClientConfig{ User: param.user, Auth: authMethods, - Timeout: 3 * time.Second, + Timeout: 10 * time.Second, HostKeyCallback: cb, HostKeyAlgorithms: kh.HostKeyAlgorithms(param.addr), BannerCallback: func(banner string) error { @@ -878,6 +832,10 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) { } } if err == io.EOF { + if win && tty { + _, _ = writer.Write([]byte{0x1A}) // ctrl + z + continue + } break } if err != nil { diff --git a/tssh/main.go b/tssh/main.go index 7a2e202..40d711a 100644 --- a/tssh/main.go +++ b/tssh/main.go @@ -73,8 +73,21 @@ func background(args *sshArgs, dest string) (bool, error) { } var onExitFuncs []func() + +func cleanupOnExit() { + for i := len(onExitFuncs) - 1; i >= 0; i-- { + onExitFuncs[i]() + } +} + var cleanupAfterLogined []func() +func cleanupForGC() { + for i := len(cleanupAfterLogined) - 1; i >= 0; i-- { + cleanupAfterLogined[i]() + } +} + func parseRemoteCommand(args *sshArgs) (string, error) { command := args.Option.get("RemoteCommand") if args.Command != "" && command != "" && strings.ToLower(command) != "none" { @@ -147,11 +160,7 @@ func TsshMain() int { } // cleanup on exit - defer func() { - for i := len(onExitFuncs) - 1; i >= 0; i-- { - onExitFuncs[i]() - } - }() + defer cleanupOnExit() // print message after stdin reset var err error @@ -163,17 +172,14 @@ func TsshMain() int { // init user config if err = initUserConfig(args.ConfigFile); err != nil { - return -1 + return 1 } - // setup terminal - var mode *terminalMode + // setup virtual terminal on Windows if isTerminal { - mode, err = setupTerminalMode() - if err != nil { - return 1 + if err = setupVirtualTerminal(); err != nil { + return 2 } - defer resetTerminalMode(mode) } // choose ssh alias @@ -182,7 +188,7 @@ func TsshMain() int { if args.Destination == "" { if !isTerminal { parser.WriteHelp(os.Stderr) - return 2 + return 3 } dest, quit, err = chooseAlias("") } else { @@ -193,7 +199,7 @@ func TsshMain() int { return 0 } if err != nil { - return 3 + return 4 } // run as background @@ -201,7 +207,7 @@ func TsshMain() int { var parent bool parent, err = background(&args, dest) if err != nil { - return 4 + return 5 } if parent { return 0 @@ -212,49 +218,43 @@ func TsshMain() int { // parse cmd and tty command, tty, err := parseCmdAndTTY(&args) if err != nil { - return 5 + return 6 } // ssh login client, session, err := sshLogin(&args, tty) if err != nil { - return 6 + return 7 } defer client.Close() if session != nil { defer session.Close() } - // reset terminal if no login tty - if mode != nil && (!tty || args.StdioForward != "" || args.NoCommand) { - resetTerminalMode(mode) - } - // stdio forward if args.StdioForward != "" { var wg *sync.WaitGroup wg, err = stdioForward(client, args.StdioForward) if err != nil { - return 7 + return 8 } + cleanupForGC() wg.Wait() return 0 } // ssh forward if err = sshForward(client, &args); err != nil { - return 8 + return 9 } // cleanup for GC - for i := len(cleanupAfterLogined) - 1; i >= 0; i-- { - cleanupAfterLogined[i]() - } + cleanupForGC() // no command if args.NoCommand { if client.Wait() != nil { - return 9 + return 10 } return 0 } @@ -263,21 +263,31 @@ func TsshMain() int { if command != "" { if err = session.Start(command); err != nil { err = fmt.Errorf("start command [%s] failed: %v", command, err) - return 10 + return 11 } } else { if err = session.Shell(); err != nil { err = fmt.Errorf("start shell failed: %v", err) - return 11 + return 12 + } + } + + // make stdin raw + if isTerminal && tty { + var state *stdinState + state, err = makeStdinRaw() + if err != nil { + return 13 } + defer resetStdin(state) } // wait for exit if session.Wait() != nil { - return 12 + return 14 } if args.Background && client.Wait() != nil { - return 13 + return 15 } return 0 } diff --git a/tssh/prompt.go b/tssh/prompt.go index eb4793b..f2b484f 100644 --- a/tssh/prompt.go +++ b/tssh/prompt.go @@ -474,6 +474,9 @@ func (p *sshPrompt) userConfirm(buf []byte) bool { func (p *sshPrompt) wrapStdin() { defer p.selector.Stdin.Close() defer p.pipeOut.Close() + if state, _ := makeStdinRaw(); state != nil { + defer resetStdin(state) + } buffer := make([]byte, 100) for { n, err := os.Stdin.Read(buffer) diff --git a/tssh/term_unix.go b/tssh/term_unix.go index 722c88c..be00538 100644 --- a/tssh/term_unix.go +++ b/tssh/term_unix.go @@ -39,22 +39,26 @@ import ( "golang.org/x/term" ) -type terminalMode struct { +type stdinState struct { state *term.State } -func setupTerminalMode() (*terminalMode, error) { +func setupVirtualTerminal() error { + return nil +} + +func makeStdinRaw() (*stdinState, error) { state, err := term.MakeRaw(int(os.Stdin.Fd())) if err != nil { return nil, fmt.Errorf("terminal make raw failed: %v", err) } - return &terminalMode{state}, nil + return &stdinState{state}, nil } -func resetTerminalMode(tm *terminalMode) { - if tm.state != nil { - _ = term.Restore(int(os.Stdin.Fd()), tm.state) - tm.state = nil +func resetStdin(s *stdinState) { + if s.state != nil { + _ = term.Restore(int(os.Stdin.Fd()), s.state) + s.state = nil } } @@ -84,19 +88,12 @@ func getKeyboardInput() (*os.File, func(), error) { return os.Stdin, func() {}, nil } - path := "/dev/tty" - file, err := os.Open(path) + file, err := os.Open("/dev/tty") if err != nil { return nil, nil, err } - state, err := term.MakeRaw(int(file.Fd())) - if err != nil { - _ = file.Close() - return nil, nil, fmt.Errorf("%s make raw failed: %v", path, err) - } - - return file, func() { _ = term.Restore(int(file.Fd()), state); _ = file.Close() }, nil + return file, func() { _ = file.Close() }, nil } func isSshTmuxEnv() bool { diff --git a/tssh/term_windows.go b/tssh/term_windows.go index 26e2bdf..5b27136 100644 --- a/tssh/term_windows.go +++ b/tssh/term_windows.go @@ -38,8 +38,9 @@ import ( "golang.org/x/term" ) -type terminalMode struct { - state *term.State +type stdinState struct { + state *term.State + settings *string } const CP_UTF8 uint32 = 65001 @@ -154,10 +155,14 @@ func sttySize() (int, int, error) { return cols, rows, nil } -func setupTerminalMode() (*terminalMode, error) { +func setupVirtualTerminal() error { // enable virtual terminal - if err := enableVirtualTerminal(); err != nil && !sttyExecutable() { - return nil, fmt.Errorf("enable virtual terminal failed: %v", err) + if err := enableVirtualTerminal(); err != nil { + if !sttyExecutable() { + return fmt.Errorf("enable virtual terminal failed: %v", err) + } + promptCursorIcon = ">>" + promptSelectedIcon = "++" } // set code page to UTF8 @@ -170,32 +175,36 @@ func setupTerminalMode() (*terminalMode, error) { setConsoleOutputCP(outCP) }) + return nil +} + +func makeStdinRaw() (*stdinState, error) { state, err := term.MakeRaw(int(os.Stdin.Fd())) - if err != nil { - if !sttyExecutable() { - return nil, fmt.Errorf("terminal make raw failed: %v", err) - } - settings, err := sttySettings() - if err != nil { - return nil, fmt.Errorf("get stty settings failed: %v", err) - } - onExitFuncs = append(onExitFuncs, func() { - sttyReset(settings) - }) - if err := sttyMakeRaw(); err != nil { - return nil, fmt.Errorf("stty make raw failed: %v", err) - } - promptCursorIcon = ">>" - promptSelectedIcon = "++" + if err == nil { + return &stdinState{state, nil}, nil } - return &terminalMode{state}, nil + if !sttyExecutable() { + return nil, fmt.Errorf("terminal make raw failed: %v", err) + } + settings, err := sttySettings() + if err != nil { + return nil, fmt.Errorf("get stty settings failed: %v", err) + } + if err := sttyMakeRaw(); err != nil { + return nil, fmt.Errorf("stty make raw failed: %v", err) + } + return &stdinState{nil, &settings}, nil } -func resetTerminalMode(tm *terminalMode) { - if tm.state != nil { - _ = term.Restore(int(os.Stdin.Fd()), tm.state) - tm.state = nil +func resetStdin(s *stdinState) { + if s.state != nil { + _ = term.Restore(int(os.Stdin.Fd()), s.state) + s.state = nil + } + if s.settings != nil { + sttyReset(*s.settings) + s.settings = nil } } @@ -221,7 +230,7 @@ func onTerminalResize(setTerminalSize func(int, int)) { go func() { columns, rows, _ := getTerminalSize() for { - time.Sleep(1 * time.Second) + time.Sleep(time.Second) width, height, err := getTerminalSize() if err != nil { continue @@ -251,11 +260,5 @@ func getKeyboardInput() (*os.File, func(), error) { } file := os.NewFile(uintptr(handle), "CONIN$") - state, err := term.MakeRaw(int(file.Fd())) - if err != nil { - _ = file.Close() - return nil, nil, fmt.Errorf("CONIN$ make raw failed: %v", err) - } - - return file, func() { _ = term.Restore(int(file.Fd()), state); _ = file.Close() }, nil + return file, func() { _ = file.Close() }, nil }