Skip to content

Commit

Permalink
support more options
Browse files Browse the repository at this point in the history
LogLevel
StrictHostKeyChecking
  • Loading branch information
lonnywong committed Nov 25, 2023
1 parent 3013f1b commit d9c3d96
Showing 1 changed file with 78 additions and 29 deletions.
107 changes: 78 additions & 29 deletions tssh/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ import (
"golang.org/x/term"
)

var enableDebugLogging bool
var enableDebugLogging bool = false
var envbleWarningLogging bool = true

func debug(format string, a ...any) {
if !enableDebugLogging {
Expand All @@ -57,6 +58,9 @@ func debug(format string, a ...any) {
}

var warning = func(format string, a ...any) {
if !envbleWarningLogging {
return
}
fmt.Fprintf(os.Stderr, fmt.Sprintf("\033[0;33mWarning: %s\033[0m\r\n", format), a...)
}

Expand Down Expand Up @@ -179,35 +183,37 @@ func getLoginParam(args *sshArgs) (*loginParam, error) {
return param, nil
}

func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error {
fingerprint := ssh.FingerprintSHA256(key)
fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+
"%s key fingerprint is %s.\r\n", host, key.Type(), fingerprint)
func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool) error {
if ask {
fingerprint := ssh.FingerprintSHA256(key)
fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+
"%s key fingerprint is %s.\r\n", host, key.Type(), fingerprint)

stdin, closer, err := getKeyboardInput()
if err != nil {
return err
}
defer closer()

reader := bufio.NewReader(stdin)
fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ")
for {
input, err := reader.ReadString('\n')
stdin, closer, err := getKeyboardInput()
if err != nil {
return err
}
input = strings.TrimSpace(input)
if input == fingerprint {
break
}
input = strings.ToLower(input)
if input == "yes" {
break
} else if input == "no" {
return fmt.Errorf("host key not trusted")
defer closer()

reader := bufio.NewReader(stdin)
fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ")
for {
input, err := reader.ReadString('\n')
if err != nil {
return err
}
input = strings.TrimSpace(input)
if input == fingerprint {
break
}
input = strings.ToLower(input)
if input == "yes" {
break
} else if input == "no" {
return fmt.Errorf("host key not trusted")
}
fmt.Fprintf(os.Stderr, "Please type 'yes', 'no' or the fingerprint: ")
}
fmt.Fprintf(os.Stderr, "Please type 'yes', 'no' or the fingerprint: ")
}

writeKnownHost := func() error {
Expand All @@ -219,7 +225,7 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error {
return knownhosts.WriteKnownHost(file, host, remote, key)
}

if err = writeKnownHost(); err != nil {
if err := writeKnownHost(); err != nil {
warning("Failed to add the host to the list of known hosts (%s): %v", path, err)
return nil
}
Expand Down Expand Up @@ -266,6 +272,7 @@ func getHostKeyCallback(args *sshArgs) (ssh.HostKeyCallback, knownhosts.HostKeyC

cb := func(host string, remote net.Addr, key ssh.PublicKey) error {
err := kh(host, remote, key)
strictHostKeyChecking := strings.ToLower(getOptionConfig(args, "StrictHostKeyChecking"))
if knownhosts.IsHostKeyChanged(err) {
path := primaryPath
if path == "" {
Expand All @@ -282,11 +289,22 @@ func getHostKeyCallback(args *sshArgs) (ssh.HostKeyCallback, knownhosts.HostKeyC
"Please contact your system administrator.\r\n"+
"Add correct host key in %s to get rid of this message.\r\n",
key.Type(), ssh.FingerprintSHA256(key), path)
return err
} else if knownhosts.IsHostUnknown(err) && primaryPath != "" {
return addHostKey(primaryPath, host, remote, key)
ask := true
switch strictHostKeyChecking {
case "yes":
return err
case "accept-new", "no", "off":
ask = false
}
return addHostKey(primaryPath, host, remote, key, ask)
}
switch strictHostKeyChecking {
case "no", "off":
return nil
default:
return err
}
return err
}

return cb, kh, err
Expand Down Expand Up @@ -700,12 +718,43 @@ func (c *connWithTimeout) Read(b []byte) (n int, err error) {
return
}

func setupLogLevel(args *sshArgs) func() {
previousDebug := enableDebugLogging
previousWarning := envbleWarningLogging
reset := func() {
enableDebugLogging = previousDebug
envbleWarningLogging = previousWarning
}
if args.Debug {
enableDebugLogging = true
envbleWarningLogging = true
return reset
}
switch strings.ToLower(getOptionConfig(args, "LogLevel")) {
case "quiet", "fatal", "error":
enableDebugLogging = false
envbleWarningLogging = false
case "debug", "debug1", "debug2", "debug3":
enableDebugLogging = true
envbleWarningLogging = true
case "info", "verbose":
fallthrough
default:
enableDebugLogging = false
envbleWarningLogging = true
}
return reset
}

func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, bool, error) {
param, err := getLoginParam(args)
if err != nil {
return nil, false, err
}

resetLogLevel := setupLogLevel(args)
defer resetLogLevel()

if client := connectViaControl(args, param); client != nil {
return client, true, nil
}
Expand Down

0 comments on commit d9c3d96

Please sign in to comment.