diff --git a/go.mod b/go.mod index 3b7b11a9..d4ba01ce 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/go-logr/logr v1.4.1 github.com/google/go-cmp v0.6.0 github.com/google/go-github/v61 v61.0.0 + github.com/kevinburke/ssh_config v1.2.1-0.20231022042432-1d09c0b50564 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.9.0 k8s.io/api v0.29.1 @@ -65,7 +66,6 @@ require ( github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect @@ -119,4 +119,4 @@ require ( k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect -) \ No newline at end of file +) diff --git a/go.sum b/go.sum index 73c2398c..3092696a 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= -github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= +github.com/kevinburke/ssh_config v1.2.1-0.20231022042432-1d09c0b50564 h1:5RWThNvilNZUvijb0BuXorNxnYuHtKWb8eNZfTgZukU= +github.com/kevinburke/ssh_config v1.2.1-0.20231022042432-1d09c0b50564/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= diff --git a/pkg/util/git_repository.go b/pkg/util/git_repository.go index f5282ccf..300a4999 100644 --- a/pkg/util/git_repository.go +++ b/pkg/util/git_repository.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "io" + "os" + "os/user" "path/filepath" "strings" "sync" @@ -17,7 +19,10 @@ import ( "github.com/go-git/go-billy/v5/memfs" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" + "github.com/go-git/go-git/v5/plumbing/transport" + "github.com/go-git/go-git/v5/plumbing/transport/ssh" "github.com/go-git/go-git/v5/storage/memory" + "github.com/kevinburke/ssh_config" ) type RepoMap struct { @@ -142,26 +147,57 @@ func CloneRemoteRepoToDir(ctx context.Context, remote v1alpha1.RemoteRepositoryS repo, err := git.PlainOpen(dir) if err != nil { if errors.Is(err, git.ErrRepositoryNotExists) { + ep, eErr := transport.NewEndpoint(remote.Url) + if eErr != nil { + return nil, nil, fmt.Errorf("reading endpoint %s: %w", remote.Url, eErr) + } + + var auth transport.AuthMethod + if ep.Protocol == "ssh" { + a, aErr := ssh.DefaultAuthBuilder(ep.User) + if aErr != nil { + // go-git default auth relies on ssh agent. if not available, get from ~/.ssh/config. + if strings.Contains(aErr.Error(), "SSH agent requested but SSH_AUTH_SOCK not-specified") { + sshConfigPath, sErr := getSSHConfigAbsPath() + if sErr != nil { + return nil, nil, fmt.Errorf("getting ssh config file: %w", sErr) + } + + au, sErr := getSSHKeyAuth(sshConfigPath, ep.Host, ep.User) + if sErr != nil { + return nil, nil, fmt.Errorf("ssh key auth: %w", sErr) + } + + auth = au + } else { + return nil, nil, aErr + } + } else { + auth = a + } + } + cloneOptions := &git.CloneOptions{ URL: remote.Url, Depth: depth, ShallowSubmodules: true, Tags: git.AllTags, InsecureSkipTLS: insecureSkipTLS, + Auth: auth, } if remote.CloneSubmodules { cloneOptions.RecurseSubmodules = git.DefaultSubmoduleRecursionDepth } - repo, err = git.PlainCloneContext(ctx, dir, false, cloneOptions) - if err != nil { + repo, eErr = git.PlainCloneContext(ctx, dir, false, cloneOptions) + if eErr != nil { if fallbackUrl != "" { cloneOptions.URL = fallbackUrl - repo, err = git.PlainCloneContext(ctx, dir, false, cloneOptions) - if err != nil { - return nil, nil, fmt.Errorf("cloning repo with fall back url: %w", err) + repo, eErr = git.PlainCloneContext(ctx, dir, false, cloneOptions) + if eErr != nil { + return nil, nil, fmt.Errorf("cloning repo with fall back url: %w", eErr) } } - return nil, nil, fmt.Errorf("cloning repo: %w", err) + return nil, nil, fmt.Errorf("cloning repo: %w", eErr) } } else { return nil, nil, fmt.Errorf("opening repo at %s %w", dir, err) @@ -269,3 +305,99 @@ func checkoutCommitOrRef(ctx context.Context, wt *git.Worktree, ref string) erro return nil } + +func getKeyfileAbsPath(relativePath string) (string, error) { + var absPath string + if strings.HasPrefix(relativePath, "~/") { + usr, err := user.Current() + if err != nil { + return "", err + } + keyFileAbs, err := filepath.Abs(filepath.Join(usr.HomeDir, relativePath[2:])) + if err != nil { + return "", err + } + absPath = keyFileAbs + } else { + keyFileAbs, err := filepath.Abs(relativePath) + if err != nil { + return "", err + } + absPath = keyFileAbs + } + return absPath, nil +} + +func getSSHKeyAuth(configPath, host, user string) (transport.AuthMethod, error) { + f, err := os.Open(configPath) + if err != nil { + return nil, err + } + + conf, err := ssh_config.Decode(f) + if err != nil { + return nil, err + } + + keyFileRelativePath, err := conf.Get(host, "IdentityFile") + if err != nil { + return nil, err + } + + // no key specified in config, find the default key + if keyFileRelativePath == "" { + homeDir, hErr := getHomeDir() + if hErr != nil { + return nil, hErr + } + // from `man ssh` on Mac OpenSSH_9.7p1, LibreSSL 3.3.6 + keyFiles := []string{ + "id_rsa", + "id_ecdsa", + "id_ecdsa_sk", + "id_ed25519", + "id_ed25519_sk", + "id_dsa", + } + for _, file := range keyFiles { + path := filepath.Join(homeDir, ".ssh", file) + if _, sErr := os.Stat(path); sErr == nil { + keyFileRelativePath = path + break + } + } + if keyFileRelativePath == "" { + return nil, fmt.Errorf("private key not speficied for %s. could not find default key", host) + } + } + + absPath, err := getKeyfileAbsPath(keyFileRelativePath) + if err != nil { + return nil, err + } + + auth, err := ssh.NewPublicKeysFromFile(user, absPath, "") + if err != nil { + return nil, err + } + return auth, nil +} + +func getSSHConfigAbsPath() (string, error) { + homeDir, err := getHomeDir() + if err != nil { + return "", err + } + return filepath.Abs(filepath.Join(homeDir, ".ssh/config")) +} + +func getHomeDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err + } + if homeDir == "" { + return "", fmt.Errorf("user does not have the home direcotry") + } + return homeDir, nil +} diff --git a/pkg/util/git_repository_test.go b/pkg/util/git_repository_test.go index 6e224657..b815017c 100644 --- a/pkg/util/git_repository_test.go +++ b/pkg/util/git_repository_test.go @@ -2,6 +2,11 @@ package util import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" "os" "path/filepath" "strings" @@ -11,6 +16,7 @@ import ( "github.com/go-git/go-billy/v5" "github.com/go-git/go-billy/v5/memfs" "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/transport/ssh" "github.com/go-git/go-git/v5/storage/memory" "github.com/stretchr/testify/assert" ) @@ -114,3 +120,85 @@ func TestGetWorktreeYamlFiles(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, 0, len(paths)) } + +func TestGetKeyfileAbsPath(t *testing.T) { + homeDir, _ := getHomeDir() + cwd, _ := os.Getwd() + tests := []struct { + name string + input string + expected string + hasError bool + }{ + {"Relative path", "testkey", filepath.Join(cwd, "testkey"), false}, + {"Home directory", "~/testkey", filepath.Join(homeDir, "testkey"), false}, + {"Absolute path", "/tmp/testkey", "/tmp/testkey", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := getKeyfileAbsPath(tt.input) + if tt.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestGetSSHKeyAuth(t *testing.T) { + // Create a temporary SSH config file + sshConfFile, err := os.CreateTemp("", "sshconfig") + assert.NoError(t, err) + defer os.Remove(sshConfFile.Name()) + + keyPath, err := createTestPrivateKey() + assert.NoError(t, err) + defer os.Remove(keyPath) + + _, _ = sshConfFile.Write([]byte(fmt.Sprintf("Host testhost\nIdentityFile %s", keyPath))) + sshConfFile.Close() + + auth, err := getSSHKeyAuth(sshConfFile.Name(), "testhost", "git") + assert.NoError(t, err) + assert.IsType(t, &ssh.PublicKeys{}, auth) + + _, err = getSSHKeyAuth("/nonexistent/path", "testhost", "git") + assert.Error(t, err) + + _, err = getSSHKeyAuth(sshConfFile.Name(), "not-in-config", "git") + assert.Error(t, err) +} + +func TestGetSSHConfigAbsPath(t *testing.T) { + expected, err := filepath.Abs(filepath.Join(os.Getenv("HOME"), ".ssh/config")) + assert.NoError(t, err) + + result, err := getSSHConfigAbsPath() + assert.NoError(t, err) + assert.True(t, filepath.IsAbs(result)) + assert.Equal(t, expected, result) +} + +func createTestPrivateKey() (string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", err + } + + privKeyPEM := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + keyfile, err := os.CreateTemp("", "key") + if err != nil { + return "", err + } + defer keyfile.Close() + + pem.Encode(keyfile, privKeyPEM) + return keyfile.Name(), nil +}