Skip to content

Commit

Permalink
improve known hosts #77
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Jan 27, 2024
1 parent 7cc598d commit f016236
Showing 1 changed file with 76 additions and 34 deletions.
110 changes: 76 additions & 34 deletions tssh/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,32 @@ func getLoginParam(args *sshArgs) (*loginParam, error) {
var acceptHostKeys []string
var sshLoginSuccess atomic.Bool

func ensureNewline(file *os.File) error {
if _, err := file.Seek(-1, io.SeekEnd); err != nil {
return nil
}
buf := make([]byte, 1)
if n, err := file.Read(buf); err != nil || n != 1 || buf[0] == '\n' {
return nil
}
if _, err := file.Write([]byte("\n")); err != nil {
return err
}
return nil
}

func writeKnownHost(path, host string, remote net.Addr, key ssh.PublicKey) error {
file, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
if err != nil {
return err
}
defer file.Close()
if err := ensureNewline(file); err != nil {
return err
}
return knownhosts.WriteKnownHost(file, host, remote, key)
}

func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool) error {
keyNormalizedLine := knownhosts.Line([]string{host}, key)
for _, acceptKey := range acceptHostKeys {
Expand Down Expand Up @@ -236,16 +262,7 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool)

acceptHostKeys = append(acceptHostKeys, keyNormalizedLine)

writeKnownHost := func() error {
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
return err
}
defer file.Close()
return knownhosts.WriteKnownHost(file, host, remote, key)
}

if err := writeKnownHost(); err != nil {
if err := writeKnownHost(path, host, remote, key); err != nil {
warning("Failed to add the host to the list of known hosts (%s): %v", path, err)
return nil
}
Expand All @@ -254,35 +271,51 @@ func addHostKey(path, host string, remote net.Addr, key ssh.PublicKey, ask bool)
return nil
}

func getHostKeyCallback(args *sshArgs) (ssh.HostKeyCallback, knownhosts.HostKeyCallback, error) {
func getHostKeyCallback(args *sshArgs, param *loginParam) (ssh.HostKeyCallback, knownhosts.HostKeyCallback, error) {
primaryPath := ""
var files []string
userFile := getOptionConfig(args, "UserKnownHostsFile")
if userFile != "" && strings.ToLower(userFile) != "none" {
for _, path := range strings.Fields(userFile) {
path = resolveHomeDir(path)
if primaryPath == "" {
primaryPath = path
}
if isFileExist(path) {
files = append(files, path)
debug("add UserKnownHostsFile: %s", path)
} else {
debug("UserKnownHostsFile [%s] does not exist", path)
}
addKnownHostsFiles := func(key string, user bool) error {
knownHostsFiles := getOptionConfig(args, key)
if knownHostsFiles == "" || user && strings.ToLower(knownHostsFiles) == "none" {
debug("%s is empty or none", key)
return nil
}
}
globalFile := getOptionConfig(args, "GlobalKnownHostsFile")
if globalFile != "" {
for _, path := range strings.Fields(globalFile) {
path = resolveHomeDir(path)
if isFileExist(path) {
files = append(files, path)
debug("add GlobalKnownHostsFile: %s", path)
for _, path := range strings.Fields(knownHostsFiles) {
var resolvedPath string
if user {
expandedPath, err := expandTokens(path, args, param, "%CdhikLlnpru")
if err != nil {
return err
}
resolvedPath = resolveHomeDir(expandedPath)
if primaryPath == "" {
primaryPath = resolvedPath
}
} else {
debug("GlobalKnownHostsFile [%s] does not exist", path)
resolvedPath = path
}
if !isFileExist(resolvedPath) {
debug("%s [%s] does not exist", key, resolvedPath)
continue
}
if !canReadFile(resolvedPath) {
if user {
warning("%s [%s] can't be read", key, resolvedPath)
} else {
debug("%s [%s] can't be read", key, resolvedPath)
}
continue
}
debug("add %s: %s", key, resolvedPath)
files = append(files, resolvedPath)
}
return nil
}
if err := addKnownHostsFiles("UserKnownHostsFile", true); err != nil {
return nil, nil, err
}
if err := addKnownHostsFiles("GlobalKnownHostsFile", false); err != nil {
return nil, nil, err
}

kh, err := knownhosts.New(files...)
Expand Down Expand Up @@ -420,6 +453,15 @@ func isFileExist(path string) bool {
return true
}

func canReadFile(path string) bool {
file, err := os.Open(path)
if err != nil {
return false
}
file.Close()
return true
}

func getSigner(dest string, path string) *sshSigner {
path = resolveHomeDir(path)
privateKey, err := os.ReadFile(path)
Expand Down Expand Up @@ -862,7 +904,7 @@ func sshConnect(args *sshArgs, client *ssh.Client, proxy string) (*ssh.Client, *
}

authMethods := getAuthMethods(args, param.host, param.user)
cb, kh, err := getHostKeyCallback(args)
cb, kh, err := getHostKeyCallback(args, param)
if err != nil {
return nil, param, false, err
}
Expand Down

0 comments on commit f016236

Please sign in to comment.