diff --git a/tssh/login.go b/tssh/login.go index 32fddfd..1fd1c70 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -47,7 +47,8 @@ import ( "golang.org/x/term" ) -var enableDebugLogging bool +var enableDebugLogging bool = false +var envbleWarningLogging bool = true func debug(format string, a ...any) { if !enableDebugLogging { @@ -57,6 +58,9 @@ func debug(format string, a ...any) { } var warning = func(format string, a ...any) { + if !envbleWarningLogging { + return + } fmt.Fprintf(os.Stderr, fmt.Sprintf("\033[0;33mWarning: %s\033[0m\r\n", format), a...) } @@ -179,35 +183,37 @@ func getLoginParam(args *sshArgs) (*loginParam, error) { return param, nil } -func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error { - fingerprint := ssh.FingerprintSHA256(key) - fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+ - "%s key fingerprint is %s.\r\n", host, key.Type(), fingerprint) +func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool) error { + if ask { + fingerprint := ssh.FingerprintSHA256(key) + fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+ + "%s key fingerprint is %s.\r\n", host, key.Type(), fingerprint) - stdin, closer, err := getKeyboardInput() - if err != nil { - return err - } - defer closer() - - reader := bufio.NewReader(stdin) - fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ") - for { - input, err := reader.ReadString('\n') + stdin, closer, err := getKeyboardInput() if err != nil { return err } - input = strings.TrimSpace(input) - if input == fingerprint { - break - } - input = strings.ToLower(input) - if input == "yes" { - break - } else if input == "no" { - return fmt.Errorf("host key not trusted") + defer closer() + + reader := bufio.NewReader(stdin) + fmt.Fprintf(os.Stderr, "Are you sure you want to continue connecting (yes/no/[fingerprint])? ") + for { + input, err := reader.ReadString('\n') + if err != nil { + return err + } + input = strings.TrimSpace(input) + if input == fingerprint { + break + } + input = strings.ToLower(input) + if input == "yes" { + break + } else if input == "no" { + return fmt.Errorf("host key not trusted") + } + fmt.Fprintf(os.Stderr, "Please type 'yes', 'no' or the fingerprint: ") } - fmt.Fprintf(os.Stderr, "Please type 'yes', 'no' or the fingerprint: ") } writeKnownHost := func() error { @@ -219,7 +225,7 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey) error { return knownhosts.WriteKnownHost(file, host, remote, key) } - if err = writeKnownHost(); err != nil { + if err := writeKnownHost(); err != nil { warning("Failed to add the host to the list of known hosts (%s): %v", path, err) return nil } @@ -266,6 +272,7 @@ func getHostKeyCallback(args *sshArgs) (ssh.HostKeyCallback, knownhosts.HostKeyC cb := func(host string, remote net.Addr, key ssh.PublicKey) error { err := kh(host, remote, key) + strictHostKeyChecking := strings.ToLower(getOptionConfig(args, "StrictHostKeyChecking")) if knownhosts.IsHostKeyChanged(err) { path := primaryPath if path == "" { @@ -282,11 +289,22 @@ func getHostKeyCallback(args *sshArgs) (ssh.HostKeyCallback, knownhosts.HostKeyC "Please contact your system administrator.\r\n"+ "Add correct host key in %s to get rid of this message.\r\n", key.Type(), ssh.FingerprintSHA256(key), path) - return err } else if knownhosts.IsHostUnknown(err) && primaryPath != "" { - return addHostKey(primaryPath, host, remote, key) + ask := true + switch strictHostKeyChecking { + case "yes": + return err + case "accept-new", "no", "off": + ask = false + } + return addHostKey(primaryPath, host, remote, key, ask) + } + switch strictHostKeyChecking { + case "no", "off": + return nil + default: + return err } - return err } return cb, kh, err @@ -700,12 +718,43 @@ func (c *connWithTimeout) Read(b []byte) (n int, err error) { return } +func setupLogLevel(args *sshArgs) func() { + previousDebug := enableDebugLogging + previousWarning := envbleWarningLogging + reset := func() { + enableDebugLogging = previousDebug + envbleWarningLogging = previousWarning + } + if args.Debug { + enableDebugLogging = true + envbleWarningLogging = true + return reset + } + switch strings.ToLower(getOptionConfig(args, "LogLevel")) { + case "quiet", "fatal", "error": + enableDebugLogging = false + envbleWarningLogging = false + case "debug", "debug1", "debug2", "debug3": + enableDebugLogging = true + envbleWarningLogging = true + case "info", "verbose": + fallthrough + default: + enableDebugLogging = false + envbleWarningLogging = true + } + return reset +} + func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, bool, error) { param, err := getLoginParam(args) if err != nil { return nil, false, err } + resetLogLevel := setupLogLevel(args) + defer resetLogLevel() + if client := connectViaControl(args, param); client != nil { return client, true, nil }