diff --git a/tssh/agent.go b/tssh/agent.go index a93dae8..7ef9784 100644 --- a/tssh/agent.go +++ b/tssh/agent.go @@ -75,7 +75,7 @@ func getAgentClient(args *sshArgs) agent.ExtendedAgent { agentClient = agent.NewClient(conn) debug("new ssh agent client [%s] success", addr) - cleanupAfterLogined = append(cleanupAfterLogined, func() { + afterLoginFuncs = append(afterLoginFuncs, func() { conn.Close() agentClient = nil }) diff --git a/tssh/config.go b/tssh/config.go index 533fd10..9b8e97d 100644 --- a/tssh/config.go +++ b/tssh/config.go @@ -177,15 +177,6 @@ func showTsshConfig() { } func initUserConfig(configFile string) error { - cleanupAfterLogined = append(cleanupAfterLogined, func() { - userConfig.config = nil - userConfig.sysConfig = nil - userConfig.exConfig = nil - userConfig.allHosts = nil - userConfig.wildcardPatterns = nil - userConfig = nil - }) - var err error userHomeDir, err = os.UserHomeDir() if err != nil { @@ -358,14 +349,16 @@ func getAllExConfig(alias, key string) []string { func getAllHosts() []*sshHost { userConfig.loadHosts.Do(func() { userConfig.doLoadConfig() - if userConfig.config != nil { userConfig.allHosts = append(userConfig.allHosts, recursiveGetHosts(userConfig.config.Hosts)...) } - if userConfig.sysConfig != nil { userConfig.allHosts = append(userConfig.allHosts, recursiveGetHosts(userConfig.sysConfig.Hosts)...) } + afterLoginFuncs = append(afterLoginFuncs, func() { + userConfig.allHosts = nil + userConfig.wildcardPatterns = nil + }) }) return userConfig.allHosts diff --git a/tssh/login.go b/tssh/login.go index 11c9d3a..ad3533c 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -39,6 +39,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/skeema/knownhosts" @@ -183,8 +184,23 @@ func getLoginParam(args *sshArgs) (*loginParam, error) { return param, nil } +var acceptHostKeys []string +var sshLoginSuccess atomic.Bool + func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool) error { + keyNormalizedLine := knownhosts.Line([]string{host}, key) + for _, acceptKey := range acceptHostKeys { + if acceptKey == keyNormalizedLine { + return nil + } + } + if ask { + if sshLoginSuccess.Load() { + fmt.Fprintf(os.Stderr, "\r\n\033[0;31mThe public key of the remote server has changed after login.\033[0m\r\n") + return fmt.Errorf("host key changed") + } + fingerprint := ssh.FingerprintSHA256(key) fmt.Fprintf(os.Stderr, "The authenticity of host '%s' can't be established.\r\n"+ "%s key fingerprint is %s.\r\n", host, key.Type(), fingerprint) @@ -216,6 +232,8 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool) } } + acceptHostKeys = append(acceptHostKeys, keyNormalizedLine) + writeKnownHost := func() error { file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600) if err != nil { @@ -272,6 +290,9 @@ func getHostKeyCallback(args *sshArgs) (ssh.HostKeyCallback, knownhosts.HostKeyC cb := func(host string, remote net.Addr, key ssh.PublicKey) error { err := kh(host, remote, key) + if err == nil { + return nil + } strictHostKeyChecking := strings.ToLower(getOptionConfig(args, "StrictHostKeyChecking")) if knownhosts.IsHostKeyChanged(err) { path := primaryPath @@ -906,13 +927,11 @@ func sshLogin(args *sshArgs, tty bool) (client *ssh.Client, session *ssh.Session if client != nil { client.Close() } + } else { + sshLoginSuccess.Store(true) } }() - cleanupAfterLogined = append(cleanupAfterLogined, func() { - getDefaultSigners = nil - }) - // ssh login var control bool client, control, err = sshConnect(args, nil, "") diff --git a/tssh/main.go b/tssh/main.go index 0858c84..039a4c5 100644 --- a/tssh/main.go +++ b/tssh/main.go @@ -110,11 +110,11 @@ func cleanupOnExit() { } } -var cleanupAfterLogined []func() +var afterLoginFuncs []func() -func cleanupForGC() { - for i := len(cleanupAfterLogined) - 1; i >= 0; i-- { - cleanupAfterLogined[i]() +func cleanupAfterLogin() { + for i := len(afterLoginFuncs) - 1; i >= 0; i-- { + afterLoginFuncs[i]() } } @@ -276,14 +276,14 @@ func sshStart(args *sshArgs) error { if err != nil { return err } - cleanupForGC() + cleanupAfterLogin() wg.Wait() return nil } // no command if args.NoCommand { - cleanupForGC() + cleanupAfterLogin() _ = client.Wait() return nil } @@ -328,7 +328,7 @@ func sshStart(args *sshArgs) error { } // cleanup and wait for exit - cleanupForGC() + cleanupAfterLogin() _ = session.Wait() if args.Background { _ = client.Wait() diff --git a/tssh/tmgr_iterm2.go b/tssh/tmgr_iterm2.go index cca57d2..1fc943b 100644 --- a/tssh/tmgr_iterm2.go +++ b/tssh/tmgr_iterm2.go @@ -196,7 +196,7 @@ func getIterm2Manager() terminalManager { debug("new iTerm2 app failed: %v", err) return nil } - cleanupAfterLogined = append(cleanupAfterLogined, func() { + afterLoginFuncs = append(afterLoginFuncs, func() { app.Close() }) debug("running in iTerm2")