From 23fc3106d0208c75a50e44b42270704e4a319e39 Mon Sep 17 00:00:00 2001
From: Lonny Wong <lonnywong@qq.com>
Date: Sat, 16 Sep 2023 21:28:37 +0800
Subject: [PATCH] support ctrl+c to interrupt login

---
 tssh/forward.go      |  2 +-
 tssh/login.go        | 84 +++++++++++---------------------------------
 tssh/main.go         | 74 +++++++++++++++++++++-----------------
 tssh/prompt.go       |  3 ++
 tssh/term_unix.go    | 29 +++++++--------
 tssh/term_windows.go | 71 +++++++++++++++++++------------------
 6 files changed, 117 insertions(+), 146 deletions(-)

diff --git a/tssh/forward.go b/tssh/forward.go
index ddf8304..27563b4 100644
--- a/tssh/forward.go
+++ b/tssh/forward.go
@@ -366,7 +366,7 @@ func remoteForward(client *ssh.Client, f *forwardCfg, args *sshArgs) {
 					fmt.Fprintf(os.Stderr, "remote forward accept failed: %v\r\n", err)
 					continue
 				}
-				local, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
+				local, err := net.DialTimeout("tcp", localAddr, 10*time.Second)
 				if err != nil {
 					fmt.Fprintf(os.Stderr, "remote forward dial [%s] failed: %v\r\n", localAddr, err)
 					remote.Close()
diff --git a/tssh/login.go b/tssh/login.go
index 7809e94..d56fce5 100644
--- a/tssh/login.go
+++ b/tssh/login.go
@@ -26,6 +26,7 @@ SOFTWARE.
 package tssh
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/x509"
 	"encoding/hex"
@@ -34,7 +35,6 @@ import (
 	"net"
 	"os"
 	"os/exec"
-	"os/signal"
 	"os/user"
 	"path/filepath"
 	"runtime"
@@ -133,7 +133,11 @@ func getLoginParam(args *sshArgs) (*loginParam, error) {
 			if err != nil {
 				return nil, fmt.Errorf("get current user failed: %v", err)
 			}
-			param.user = currentUser.Username
+			userName = currentUser.Username
+			if idx := strings.LastIndexByte(userName, '\\'); idx >= 0 {
+				userName = userName[idx+1:]
+			}
+			param.user = userName
 		}
 	}
 
@@ -187,37 +191,6 @@ func createKnownHosts(path string) error {
 	return nil
 }
 
-func readLineFromRawIO(stdin *os.File) (string, error) {
-	defer fmt.Fprintf(os.Stderr, "\r\n")
-	buffer := new(bytes.Buffer)
-	buf := make([]byte, 100)
-	for {
-		n, err := stdin.Read(buf)
-		if err != nil {
-			return "", nil
-		}
-		data := buf[:n]
-		if bytes.ContainsRune(data, '\x03') {
-			fmt.Fprintf(os.Stderr, "^C")
-			return "", fmt.Errorf("interrupt")
-		}
-		for _, b := range data {
-			if b == '\r' || b == '\n' {
-				return string(bytes.TrimSpace(buffer.Bytes())), nil
-			}
-			if b == '\x7f' {
-				if buffer.Len() > 0 {
-					buffer.Truncate(buffer.Len() - 1)
-					fmt.Fprintf(os.Stderr, "\b \b")
-				}
-			} else if b >= ' ' && b <= '~' {
-				buffer.WriteByte(b)
-				fmt.Fprintf(os.Stderr, "%s", string(b))
-			}
-		}
-	}
-}
-
 func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error {
 	fingerprint := ssh.FingerprintSHA256(key)
 	fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+
@@ -229,12 +202,14 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error {
 	}
 	defer closer()
 
+	reader := bufio.NewReader(stdin)
 	fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ")
 	for {
-		input, err := readLineFromRawIO(stdin)
+		input, err := reader.ReadString('\n')
 		if err != nil {
 			return err
 		}
+		input = strings.TrimSpace(input)
 		if input == fingerprint {
 			break
 		}
@@ -410,35 +385,14 @@ func getSigner(dest string, path string) (*sshSigner, error) {
 func readSecret(prompt string) (secret []byte, err error) {
 	fmt.Fprintf(os.Stderr, "%s", prompt)
 	defer fmt.Fprintf(os.Stderr, "\r\n")
-	errch := make(chan error, 1)
-	defer close(errch)
 
-	sigch := make(chan os.Signal, 1)
-	signal.Notify(sigch, os.Interrupt)
-	go func() {
-		for range sigch {
-			errch <- fmt.Errorf("interrupt")
-		}
-	}()
-	defer func() { signal.Stop(sigch); close(sigch) }()
+	stdin, closer, err := getKeyboardInput()
+	if err != nil {
+		return nil, err
+	}
+	defer closer()
 
-	go func() {
-		stdin, closer, err := getKeyboardInput()
-		if err != nil {
-			errch <- err
-			return
-		}
-		defer closer()
-		pw, err := term.ReadPassword(int(stdin.Fd()))
-		if err != nil {
-			errch <- err
-			return
-		}
-		secret = pw
-		errch <- nil
-	}()
-	err = <-errch
-	return
+	return term.ReadPassword(int(stdin.Fd()))
 }
 
 func getPasswordAuthMethod(args *sshArgs, host, user string) ssh.AuthMethod {
@@ -710,7 +664,7 @@ func dialWithTimeout(client *ssh.Client, network, addr string) (conn net.Conn, e
 		done <- struct{}{}
 	}()
 	select {
-	case <-time.After(3 * time.Second):
+	case <-time.After(10 * time.Second):
 		err = fmt.Errorf("dial [%s] timeout", addr)
 	case <-done:
 	}
@@ -755,7 +709,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e
 	config := &ssh.ClientConfig{
 		User:              param.user,
 		Auth:              authMethods,
-		Timeout:           3 * time.Second,
+		Timeout:           10 * time.Second,
 		HostKeyCallback:   cb,
 		HostKeyAlgorithms: kh.HostKeyAlgorithms(param.addr),
 		BannerCallback: func(banner string) error {
@@ -878,6 +832,10 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) {
 				}
 			}
 			if err == io.EOF {
+				if win && tty {
+					_, _ = writer.Write([]byte{0x1A}) // ctrl + z
+					continue
+				}
 				break
 			}
 			if err != nil {
diff --git a/tssh/main.go b/tssh/main.go
index 7a2e202..40d711a 100644
--- a/tssh/main.go
+++ b/tssh/main.go
@@ -73,8 +73,21 @@ func background(args *sshArgs, dest string) (bool, error) {
 }
 
 var onExitFuncs []func()
+
+func cleanupOnExit() {
+	for i := len(onExitFuncs) - 1; i >= 0; i-- {
+		onExitFuncs[i]()
+	}
+}
+
 var cleanupAfterLogined []func()
 
+func cleanupForGC() {
+	for i := len(cleanupAfterLogined) - 1; i >= 0; i-- {
+		cleanupAfterLogined[i]()
+	}
+}
+
 func parseRemoteCommand(args *sshArgs) (string, error) {
 	command := args.Option.get("RemoteCommand")
 	if args.Command != "" && command != "" && strings.ToLower(command) != "none" {
@@ -147,11 +160,7 @@ func TsshMain() int {
 	}
 
 	// cleanup on exit
-	defer func() {
-		for i := len(onExitFuncs) - 1; i >= 0; i-- {
-			onExitFuncs[i]()
-		}
-	}()
+	defer cleanupOnExit()
 
 	// print message after stdin reset
 	var err error
@@ -163,17 +172,14 @@ func TsshMain() int {
 
 	// init user config
 	if err = initUserConfig(args.ConfigFile); err != nil {
-		return -1
+		return 1
 	}
 
-	// setup terminal
-	var mode *terminalMode
+	// setup virtual terminal on Windows
 	if isTerminal {
-		mode, err = setupTerminalMode()
-		if err != nil {
-			return 1
+		if err = setupVirtualTerminal(); err != nil {
+			return 2
 		}
-		defer resetTerminalMode(mode)
 	}
 
 	// choose ssh alias
@@ -182,7 +188,7 @@ func TsshMain() int {
 	if args.Destination == "" {
 		if !isTerminal {
 			parser.WriteHelp(os.Stderr)
-			return 2
+			return 3
 		}
 		dest, quit, err = chooseAlias("")
 	} else {
@@ -193,7 +199,7 @@ func TsshMain() int {
 		return 0
 	}
 	if err != nil {
-		return 3
+		return 4
 	}
 
 	// run as background
@@ -201,7 +207,7 @@ func TsshMain() int {
 		var parent bool
 		parent, err = background(&args, dest)
 		if err != nil {
-			return 4
+			return 5
 		}
 		if parent {
 			return 0
@@ -212,49 +218,43 @@ func TsshMain() int {
 	// parse cmd and tty
 	command, tty, err := parseCmdAndTTY(&args)
 	if err != nil {
-		return 5
+		return 6
 	}
 
 	// ssh login
 	client, session, err := sshLogin(&args, tty)
 	if err != nil {
-		return 6
+		return 7
 	}
 	defer client.Close()
 	if session != nil {
 		defer session.Close()
 	}
 
-	// reset terminal if no login tty
-	if mode != nil && (!tty || args.StdioForward != "" || args.NoCommand) {
-		resetTerminalMode(mode)
-	}
-
 	// stdio forward
 	if args.StdioForward != "" {
 		var wg *sync.WaitGroup
 		wg, err = stdioForward(client, args.StdioForward)
 		if err != nil {
-			return 7
+			return 8
 		}
+		cleanupForGC()
 		wg.Wait()
 		return 0
 	}
 
 	// ssh forward
 	if err = sshForward(client, &args); err != nil {
-		return 8
+		return 9
 	}
 
 	// cleanup for GC
-	for i := len(cleanupAfterLogined) - 1; i >= 0; i-- {
-		cleanupAfterLogined[i]()
-	}
+	cleanupForGC()
 
 	// no command
 	if args.NoCommand {
 		if client.Wait() != nil {
-			return 9
+			return 10
 		}
 		return 0
 	}
@@ -263,21 +263,31 @@ func TsshMain() int {
 	if command != "" {
 		if err = session.Start(command); err != nil {
 			err = fmt.Errorf("start command [%s] failed: %v", command, err)
-			return 10
+			return 11
 		}
 	} else {
 		if err = session.Shell(); err != nil {
 			err = fmt.Errorf("start shell failed: %v", err)
-			return 11
+			return 12
+		}
+	}
+
+	// make stdin raw
+	if isTerminal && tty {
+		var state *stdinState
+		state, err = makeStdinRaw()
+		if err != nil {
+			return 13
 		}
+		defer resetStdin(state)
 	}
 
 	// wait for exit
 	if session.Wait() != nil {
-		return 12
+		return 14
 	}
 	if args.Background && client.Wait() != nil {
-		return 13
+		return 15
 	}
 	return 0
 }
diff --git a/tssh/prompt.go b/tssh/prompt.go
index eb4793b..f2b484f 100644
--- a/tssh/prompt.go
+++ b/tssh/prompt.go
@@ -474,6 +474,9 @@ func (p *sshPrompt) userConfirm(buf []byte) bool {
 func (p *sshPrompt) wrapStdin() {
 	defer p.selector.Stdin.Close()
 	defer p.pipeOut.Close()
+	if state, _ := makeStdinRaw(); state != nil {
+		defer resetStdin(state)
+	}
 	buffer := make([]byte, 100)
 	for {
 		n, err := os.Stdin.Read(buffer)
diff --git a/tssh/term_unix.go b/tssh/term_unix.go
index 722c88c..be00538 100644
--- a/tssh/term_unix.go
+++ b/tssh/term_unix.go
@@ -39,22 +39,26 @@ import (
 	"golang.org/x/term"
 )
 
-type terminalMode struct {
+type stdinState struct {
 	state *term.State
 }
 
-func setupTerminalMode() (*terminalMode, error) {
+func setupVirtualTerminal() error {
+	return nil
+}
+
+func makeStdinRaw() (*stdinState, error) {
 	state, err := term.MakeRaw(int(os.Stdin.Fd()))
 	if err != nil {
 		return nil, fmt.Errorf("terminal make raw failed: %v", err)
 	}
-	return &terminalMode{state}, nil
+	return &stdinState{state}, nil
 }
 
-func resetTerminalMode(tm *terminalMode) {
-	if tm.state != nil {
-		_ = term.Restore(int(os.Stdin.Fd()), tm.state)
-		tm.state = nil
+func resetStdin(s *stdinState) {
+	if s.state != nil {
+		_ = term.Restore(int(os.Stdin.Fd()), s.state)
+		s.state = nil
 	}
 }
 
@@ -84,19 +88,12 @@ func getKeyboardInput() (*os.File, func(), error) {
 		return os.Stdin, func() {}, nil
 	}
 
-	path := "/dev/tty"
-	file, err := os.Open(path)
+	file, err := os.Open("/dev/tty")
 	if err != nil {
 		return nil, nil, err
 	}
 
-	state, err := term.MakeRaw(int(file.Fd()))
-	if err != nil {
-		_ = file.Close()
-		return nil, nil, fmt.Errorf("%s make raw failed: %v", path, err)
-	}
-
-	return file, func() { _ = term.Restore(int(file.Fd()), state); _ = file.Close() }, nil
+	return file, func() { _ = file.Close() }, nil
 }
 
 func isSshTmuxEnv() bool {
diff --git a/tssh/term_windows.go b/tssh/term_windows.go
index 26e2bdf..5b27136 100644
--- a/tssh/term_windows.go
+++ b/tssh/term_windows.go
@@ -38,8 +38,9 @@ import (
 	"golang.org/x/term"
 )
 
-type terminalMode struct {
-	state *term.State
+type stdinState struct {
+	state    *term.State
+	settings *string
 }
 
 const CP_UTF8 uint32 = 65001
@@ -154,10 +155,14 @@ func sttySize() (int, int, error) {
 	return cols, rows, nil
 }
 
-func setupTerminalMode() (*terminalMode, error) {
+func setupVirtualTerminal() error {
 	// enable virtual terminal
-	if err := enableVirtualTerminal(); err != nil && !sttyExecutable() {
-		return nil, fmt.Errorf("enable virtual terminal failed: %v", err)
+	if err := enableVirtualTerminal(); err != nil {
+		if !sttyExecutable() {
+			return fmt.Errorf("enable virtual terminal failed: %v", err)
+		}
+		promptCursorIcon = ">>"
+		promptSelectedIcon = "++"
 	}
 
 	// set code page to UTF8
@@ -170,32 +175,36 @@ func setupTerminalMode() (*terminalMode, error) {
 		setConsoleOutputCP(outCP)
 	})
 
+	return nil
+}
+
+func makeStdinRaw() (*stdinState, error) {
 	state, err := term.MakeRaw(int(os.Stdin.Fd()))
-	if err != nil {
-		if !sttyExecutable() {
-			return nil, fmt.Errorf("terminal make raw failed: %v", err)
-		}
-		settings, err := sttySettings()
-		if err != nil {
-			return nil, fmt.Errorf("get stty settings failed: %v", err)
-		}
-		onExitFuncs = append(onExitFuncs, func() {
-			sttyReset(settings)
-		})
-		if err := sttyMakeRaw(); err != nil {
-			return nil, fmt.Errorf("stty make raw failed: %v", err)
-		}
-		promptCursorIcon = ">>"
-		promptSelectedIcon = "++"
+	if err == nil {
+		return &stdinState{state, nil}, nil
 	}
 
-	return &terminalMode{state}, nil
+	if !sttyExecutable() {
+		return nil, fmt.Errorf("terminal make raw failed: %v", err)
+	}
+	settings, err := sttySettings()
+	if err != nil {
+		return nil, fmt.Errorf("get stty settings failed: %v", err)
+	}
+	if err := sttyMakeRaw(); err != nil {
+		return nil, fmt.Errorf("stty make raw failed: %v", err)
+	}
+	return &stdinState{nil, &settings}, nil
 }
 
-func resetTerminalMode(tm *terminalMode) {
-	if tm.state != nil {
-		_ = term.Restore(int(os.Stdin.Fd()), tm.state)
-		tm.state = nil
+func resetStdin(s *stdinState) {
+	if s.state != nil {
+		_ = term.Restore(int(os.Stdin.Fd()), s.state)
+		s.state = nil
+	}
+	if s.settings != nil {
+		sttyReset(*s.settings)
+		s.settings = nil
 	}
 }
 
@@ -221,7 +230,7 @@ func onTerminalResize(setTerminalSize func(int, int)) {
 	go func() {
 		columns, rows, _ := getTerminalSize()
 		for {
-			time.Sleep(1 * time.Second)
+			time.Sleep(time.Second)
 			width, height, err := getTerminalSize()
 			if err != nil {
 				continue
@@ -251,11 +260,5 @@ func getKeyboardInput() (*os.File, func(), error) {
 	}
 	file := os.NewFile(uintptr(handle), "CONIN$")
 
-	state, err := term.MakeRaw(int(file.Fd()))
-	if err != nil {
-		_ = file.Close()
-		return nil, nil, fmt.Errorf("CONIN$ make raw failed: %v", err)
-	}
-
-	return file, func() { _ = term.Restore(int(file.Fd()), state); _ = file.Close() }, nil
+	return file, func() { _ = file.Close() }, nil
 }