diff --git a/cli/internal/cmd/BUILD.bazel b/cli/internal/cmd/BUILD.bazel index d02b5a0884a..b7368313c43 100644 --- a/cli/internal/cmd/BUILD.bazel +++ b/cli/internal/cmd/BUILD.bazel @@ -91,6 +91,7 @@ go_library( "@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions", "@io_k8s_apimachinery//pkg/runtime", "@io_k8s_client_go//tools/clientcmd", + "@io_k8s_client_go//tools/clientcmd/api", "@io_k8s_client_go//tools/clientcmd/api/latest", "@io_k8s_sigs_yaml//:yaml", "@org_golang_google_grpc//:go_default_library", @@ -172,6 +173,8 @@ go_test( "@io_k8s_api//core/v1:core", "@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions", "@io_k8s_apimachinery//pkg/apis/meta/v1:meta", + "@io_k8s_client_go//tools/clientcmd/api", + "@io_k8s_sigs_yaml//:yaml", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index a459d394490..0e71da07b07 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -472,7 +472,23 @@ func (i *initCmd) writeOutput( tw.Flush() fmt.Fprintln(wr) - if err := i.fileHandler.Write(constants.AdminConfFilename, initResp.GetKubeconfig(), file.OptNone); err != nil { + i.log.Debugf("Rewriting cluster server address in kubeconfig to %s", idFile.IP) + kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig()) + if err != nil { + return fmt.Errorf("loading kubeconfig: %w", err) + } + if len(kubeconfig.Clusters) != 1 { + return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters)) + } + for _, cluster := range kubeconfig.Clusters { + cluster.Server = "https://" + net.JoinHostPort(idFile.IP, strconv.Itoa(constants.KubernetesPort)) + } + kubeconfigBytes, err := clientcmd.Write(*kubeconfig) + if err != nil { + return fmt.Errorf("marshaling kubeconfig: %w", err) + } + + if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil { return fmt.Errorf("writing kubeconfig: %w", err) } i.log.Debugf("Kubeconfig written to %s", i.pf.PrefixPrintablePath(constants.AdminConfFilename)) diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 75d39d6d1f5..c95a0c7f2d4 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -44,6 +44,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "k8s.io/client-go/tools/clientcmd" + k8sclientapi "k8s.io/client-go/tools/clientcmd/api" ) func TestInitArgumentValidation(t *testing.T) { @@ -339,12 +341,34 @@ func TestWriteOutput(t *testing.T) { assert := assert.New(t) require := require.New(t) + clusterEndpoint := "cluster-endpoint" + + expectedKubeconfig := k8sclientapi.Config{ + Clusters: map[string]*k8sclientapi.Cluster{ + "cluster": { + Server: "https://" + clusterEndpoint + ":6443", + }, + }, + } + expectedKubeconfigBytes, err := clientcmd.Write(expectedKubeconfig) + require.NoError(err) + + respKubeconfig := k8sclientapi.Config{ + Clusters: map[string]*k8sclientapi.Cluster{ + "cluster": { + Server: "https://192.0.2.1:6443", + }, + }, + } + respKubeconfigBytes, err := clientcmd.Write(respKubeconfig) + require.NoError(err) + resp := &initproto.InitResponse{ Kind: &initproto.InitResponse_InitSuccess{ InitSuccess: &initproto.InitSuccessResponse{ OwnerId: []byte("ownerID"), ClusterId: []byte("clusterID"), - Kubeconfig: []byte("kubeconfig"), + Kubeconfig: respKubeconfigBytes, }, }, } @@ -355,7 +379,7 @@ func TestWriteOutput(t *testing.T) { expectedIDFile := clusterid.File{ ClusterID: clusterID, OwnerID: ownerID, - IP: "cluster-ip", + IP: clusterEndpoint, UID: "test-uid", } @@ -365,10 +389,10 @@ func TestWriteOutput(t *testing.T) { idFile := clusterid.File{ UID: "test-uid", - IP: "cluster-ip", + IP: clusterEndpoint, } i := newInitCmd(nil, fileHandler, &nopSpinner{}, &stubMerger{}, logger.NewTest(t)) - err := i.writeOutput(idFile, resp.GetInitSuccess(), false, &out) + err = i.writeOutput(idFile, resp.GetInitSuccess(), false, &out) require.NoError(err) // assert.Contains(out.String(), ownerID) assert.Contains(out.String(), clusterID) @@ -377,7 +401,8 @@ func TestWriteOutput(t *testing.T) { afs := afero.Afero{Fs: testFs} adminConf, err := afs.ReadFile(constants.AdminConfFilename) assert.NoError(err) - assert.Equal(string(resp.GetInitSuccess().GetKubeconfig()), string(adminConf)) + assert.Contains(string(adminConf), clusterEndpoint) + assert.Equal(string(expectedKubeconfigBytes), string(adminConf)) idsFile, err := afs.ReadFile(constants.ClusterIDsFilename) assert.NoError(err)