Skip to content

Commit

Permalink
support ctrl+c to interrupt login
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Sep 16, 2023
1 parent 290dc94 commit 23fc310
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 146 deletions.
2 changes: 1 addition & 1 deletion tssh/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func remoteForward(client *ssh.Client, f *forwardCfg, args *sshArgs) {
fmt.Fprintf(os.Stderr, "remote forward accept failed: %v\r\n", err)
continue
}
local, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
local, err := net.DialTimeout("tcp", localAddr, 10*time.Second)
if err != nil {
fmt.Fprintf(os.Stderr, "remote forward dial [%s] failed: %v\r\n", localAddr, err)
remote.Close()
Expand Down
84 changes: 21 additions & 63 deletions tssh/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ SOFTWARE.
package tssh

import (
"bufio"
"bytes"
"crypto/x509"
"encoding/hex"
Expand All @@ -34,7 +35,6 @@ import (
"net"
"os"
"os/exec"
"os/signal"
"os/user"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -133,7 +133,11 @@ func getLoginParam(args *sshArgs) (*loginParam, error) {
if err != nil {
return nil, fmt.Errorf("get current user failed: %v", err)
}
param.user = currentUser.Username
userName = currentUser.Username
if idx := strings.LastIndexByte(userName, '\\'); idx >= 0 {
userName = userName[idx+1:]
}
param.user = userName
}
}

Expand Down Expand Up @@ -187,37 +191,6 @@ func createKnownHosts(path string) error {
return nil
}

func readLineFromRawIO(stdin *os.File) (string, error) {
defer fmt.Fprintf(os.Stderr, "\r\n")
buffer := new(bytes.Buffer)
buf := make([]byte, 100)
for {
n, err := stdin.Read(buf)
if err != nil {
return "", nil
}
data := buf[:n]
if bytes.ContainsRune(data, '\x03') {
fmt.Fprintf(os.Stderr, "^C")
return "", fmt.Errorf("interrupt")
}
for _, b := range data {
if b == '\r' || b == '\n' {
return string(bytes.TrimSpace(buffer.Bytes())), nil
}
if b == '\x7f' {
if buffer.Len() > 0 {
buffer.Truncate(buffer.Len() - 1)
fmt.Fprintf(os.Stderr, "\b \b")
}
} else if b >= ' ' && b <= '~' {
buffer.WriteByte(b)
fmt.Fprintf(os.Stderr, "%s", string(b))
}
}
}
}

func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error {
fingerprint := ssh.FingerprintSHA256(key)
fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+
Expand All @@ -229,12 +202,14 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error {
}
defer closer()

reader := bufio.NewReader(stdin)
fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ")
for {
input, err := readLineFromRawIO(stdin)
input, err := reader.ReadString('\n')
if err != nil {
return err
}
input = strings.TrimSpace(input)
if input == fingerprint {
break
}
Expand Down Expand Up @@ -410,35 +385,14 @@ func getSigner(dest string, path string) (*sshSigner, error) {
func readSecret(prompt string) (secret []byte, err error) {
fmt.Fprintf(os.Stderr, "%s", prompt)
defer fmt.Fprintf(os.Stderr, "\r\n")
errch := make(chan error, 1)
defer close(errch)

sigch := make(chan os.Signal, 1)
signal.Notify(sigch, os.Interrupt)
go func() {
for range sigch {
errch <- fmt.Errorf("interrupt")
}
}()
defer func() { signal.Stop(sigch); close(sigch) }()
stdin, closer, err := getKeyboardInput()
if err != nil {
return nil, err
}
defer closer()

go func() {
stdin, closer, err := getKeyboardInput()
if err != nil {
errch <- err
return
}
defer closer()
pw, err := term.ReadPassword(int(stdin.Fd()))
if err != nil {
errch <- err
return
}
secret = pw
errch <- nil
}()
err = <-errch
return
return term.ReadPassword(int(stdin.Fd()))
}

func getPasswordAuthMethod(args *sshArgs, host, user string) ssh.AuthMethod {
Expand Down Expand Up @@ -710,7 +664,7 @@ func dialWithTimeout(client *ssh.Client, network, addr string) (conn net.Conn, e
done <- struct{}{}
}()
select {
case <-time.After(3 * time.Second):
case <-time.After(10 * time.Second):
err = fmt.Errorf("dial [%s] timeout", addr)
case <-done:
}
Expand Down Expand Up @@ -755,7 +709,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, e
config := &ssh.ClientConfig{
User: param.user,
Auth: authMethods,
Timeout: 3 * time.Second,
Timeout: 10 * time.Second,
HostKeyCallback: cb,
HostKeyAlgorithms: kh.HostKeyAlgorithms(param.addr),
BannerCallback: func(banner string) error {
Expand Down Expand Up @@ -878,6 +832,10 @@ func wrapStdIO(serverIn io.WriteCloser, serverOut io.Reader, tty bool) {
}
}
if err == io.EOF {
if win && tty {
_, _ = writer.Write([]byte{0x1A}) // ctrl + z
continue
}
break
}
if err != nil {
Expand Down
74 changes: 42 additions & 32 deletions tssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,21 @@ func background(args *sshArgs, dest string) (bool, error) {
}

var onExitFuncs []func()

func cleanupOnExit() {
for i := len(onExitFuncs) - 1; i >= 0; i-- {
onExitFuncs[i]()
}
}

var cleanupAfterLogined []func()

func cleanupForGC() {
for i := len(cleanupAfterLogined) - 1; i >= 0; i-- {
cleanupAfterLogined[i]()
}
}

func parseRemoteCommand(args *sshArgs) (string, error) {
command := args.Option.get("RemoteCommand")
if args.Command != "" && command != "" && strings.ToLower(command) != "none" {
Expand Down Expand Up @@ -147,11 +160,7 @@ func TsshMain() int {
}

// cleanup on exit
defer func() {
for i := len(onExitFuncs) - 1; i >= 0; i-- {
onExitFuncs[i]()
}
}()
defer cleanupOnExit()

// print message after stdin reset
var err error
Expand All @@ -163,17 +172,14 @@ func TsshMain() int {

// init user config
if err = initUserConfig(args.ConfigFile); err != nil {
return -1
return 1
}

// setup terminal
var mode *terminalMode
// setup virtual terminal on Windows
if isTerminal {
mode, err = setupTerminalMode()
if err != nil {
return 1
if err = setupVirtualTerminal(); err != nil {
return 2
}
defer resetTerminalMode(mode)
}

// choose ssh alias
Expand All @@ -182,7 +188,7 @@ func TsshMain() int {
if args.Destination == "" {
if !isTerminal {
parser.WriteHelp(os.Stderr)
return 2
return 3
}
dest, quit, err = chooseAlias("")
} else {
Expand All @@ -193,15 +199,15 @@ func TsshMain() int {
return 0
}
if err != nil {
return 3
return 4
}

// run as background
if args.Background {
var parent bool
parent, err = background(&args, dest)
if err != nil {
return 4
return 5
}
if parent {
return 0
Expand All @@ -212,49 +218,43 @@ func TsshMain() int {
// parse cmd and tty
command, tty, err := parseCmdAndTTY(&args)
if err != nil {
return 5
return 6
}

// ssh login
client, session, err := sshLogin(&args, tty)
if err != nil {
return 6
return 7
}
defer client.Close()
if session != nil {
defer session.Close()
}

// reset terminal if no login tty
if mode != nil && (!tty || args.StdioForward != "" || args.NoCommand) {
resetTerminalMode(mode)
}

// stdio forward
if args.StdioForward != "" {
var wg *sync.WaitGroup
wg, err = stdioForward(client, args.StdioForward)
if err != nil {
return 7
return 8
}
cleanupForGC()
wg.Wait()
return 0
}

// ssh forward
if err = sshForward(client, &args); err != nil {
return 8
return 9
}

// cleanup for GC
for i := len(cleanupAfterLogined) - 1; i >= 0; i-- {
cleanupAfterLogined[i]()
}
cleanupForGC()

// no command
if args.NoCommand {
if client.Wait() != nil {
return 9
return 10
}
return 0
}
Expand All @@ -263,21 +263,31 @@ func TsshMain() int {
if command != "" {
if err = session.Start(command); err != nil {
err = fmt.Errorf("start command [%s] failed: %v", command, err)
return 10
return 11
}
} else {
if err = session.Shell(); err != nil {
err = fmt.Errorf("start shell failed: %v", err)
return 11
return 12
}
}

// make stdin raw
if isTerminal && tty {
var state *stdinState
state, err = makeStdinRaw()
if err != nil {
return 13
}
defer resetStdin(state)
}

// wait for exit
if session.Wait() != nil {
return 12
return 14
}
if args.Background && client.Wait() != nil {
return 13
return 15
}
return 0
}
3 changes: 3 additions & 0 deletions tssh/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ func (p *sshPrompt) userConfirm(buf []byte) bool {
func (p *sshPrompt) wrapStdin() {
defer p.selector.Stdin.Close()
defer p.pipeOut.Close()
if state, _ := makeStdinRaw(); state != nil {
defer resetStdin(state)
}
buffer := make([]byte, 100)
for {
n, err := os.Stdin.Read(buffer)
Expand Down
Loading

0 comments on commit 23fc310

Please sign in to comment.