diff --git a/internal/castai/types.go b/internal/castai/types.go index be884169..46dc94c4 100644 --- a/internal/castai/types.go +++ b/internal/castai/types.go @@ -18,11 +18,19 @@ type GKEParams struct { ClusterName string `json:"clusterName"` } +type KOPSParams struct { + CSP string `json:"cloud"` + Region string `json:"region"` + ClusterName string `json:"clusterName"` + StateStore string `json:"stateStore"` +} + type RegisterClusterRequest struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - EKS *EKSParams `json:"eks"` - GKE *GKEParams `json:"gke"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + EKS *EKSParams `json:"eks"` + GKE *GKEParams `json:"gke"` + KOPS *KOPSParams `json:"kops"` } type Cluster struct { diff --git a/internal/config/config.go b/internal/config/config.go index 392943c7..9156cb02 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,6 +15,7 @@ type Config struct { CASTAI *CASTAI EKS *EKS GKE *GKE + KOPS *KOPS } type Log struct { @@ -43,6 +44,13 @@ type GKE struct { ClusterName string } +type KOPS struct { + CSP string + Region string + ClusterName string + StateStore string +} + var cfg *Config // Get configuration bound to environment variables. @@ -71,6 +79,11 @@ func Get() Config { _ = viper.BindEnv("gke.projectid", "GKE_PROJECT_ID") _ = viper.BindEnv("gke.clustername", "GKE_CLUSTER_NAME") + _ = viper.BindEnv("kops.csp", "KOPS_CSP") + _ = viper.BindEnv("kops.region", "KOPS_REGION") + _ = viper.BindEnv("kops.clustername", "KOPS_CLUSTER_NAME") + _ = viper.BindEnv("kops.statestore", "KOPS_STATE_STORE") + cfg = &Config{} if err := viper.Unmarshal(&cfg); err != nil { panic(fmt.Errorf("parsing configuration: %v", err)) @@ -89,34 +102,49 @@ func Get() Config { if cfg.CASTAI != nil { if cfg.CASTAI.ClusterID == "" { - requiredDiscoveryDisabled("CASTAI_CLUSTER_ID") + requiredWhenDiscoveryDisabled("CASTAI_CLUSTER_ID") } if cfg.CASTAI.OrganizationID == "" { - requiredDiscoveryDisabled("CASTAI_ORGANIZATION_ID") + requiredWhenDiscoveryDisabled("CASTAI_ORGANIZATION_ID") } } if cfg.EKS != nil { if cfg.EKS.AccountID == "" { - requiredDiscoveryDisabled("EKS_ACCOUNT_ID") + requiredWhenDiscoveryDisabled("EKS_ACCOUNT_ID") } if cfg.EKS.Region == "" { - requiredDiscoveryDisabled("EKS_REGION") + requiredWhenDiscoveryDisabled("EKS_REGION") } if cfg.EKS.ClusterName == "" { - requiredDiscoveryDisabled("EKS_CLUSTER_NAME") + requiredWhenDiscoveryDisabled("EKS_CLUSTER_NAME") } } if cfg.GKE != nil { if cfg.GKE.Region == "" { - requiredDiscoveryDisabled("GKE_REGION") + requiredWhenDiscoveryDisabled("GKE_REGION") } if cfg.GKE.ProjectID == "" { - requiredDiscoveryDisabled("GKE_PROJECT_ID") + requiredWhenDiscoveryDisabled("GKE_PROJECT_ID") } if cfg.GKE.ClusterName == "" { - requiredDiscoveryDisabled("GKE_CLUSTER_NAME") + requiredWhenDiscoveryDisabled("GKE_CLUSTER_NAME") + } + } + + if cfg.KOPS != nil { + if cfg.KOPS.CSP == "" { + requiredWhenDiscoveryDisabled("KOPS_CSP") + } + if cfg.KOPS.Region == "" { + requiredWhenDiscoveryDisabled("KOPS_REGION") + } + if cfg.KOPS.ClusterName == "" { + requiredWhenDiscoveryDisabled("KOPS_CLUSTER_NAME") + } + if cfg.KOPS.StateStore == "" { + requiredWhenDiscoveryDisabled("KOPS_STATE_STORE") } } @@ -132,6 +160,6 @@ func required(variable string) { panic(fmt.Errorf("env variable %s is required", variable)) } -func requiredDiscoveryDisabled(variable string) { +func requiredWhenDiscoveryDisabled(variable string) { panic(fmt.Errorf("env variable %s is required when discovery is disabled", variable)) } diff --git a/internal/services/providers/gke/gke.go b/internal/services/providers/gke/gke.go index 4b28d735..3fea2bca 100644 --- a/internal/services/providers/gke/gke.go +++ b/internal/services/providers/gke/gke.go @@ -16,10 +16,10 @@ import ( const ( Name = "gke" - labelPreemptible = "cloud.google.com/gke-preemptible" + LabelPreemptible = "cloud.google.com/gke-preemptible" ) -func New(_ context.Context, log logrus.FieldLogger) (types.Provider, error) { +func New(log logrus.FieldLogger) (types.Provider, error) { return &Provider{log: log}, nil } @@ -53,7 +53,7 @@ func (p *Provider) IsSpot(_ context.Context, node *corev1.Node) (bool, error) { return true, nil } - if val, ok := node.Labels[labelPreemptible]; ok && val == "true" { + if val, ok := node.Labels[LabelPreemptible]; ok && val == "true" { return true, nil } diff --git a/internal/services/providers/gke/gke_test.go b/internal/services/providers/gke/gke_test.go index 20253395..03c3db41 100644 --- a/internal/services/providers/gke/gke_test.go +++ b/internal/services/providers/gke/gke_test.go @@ -66,7 +66,7 @@ func TestProvider_IsSpot(t *testing.T) { }, { name: "gke spot node", - node: &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{labelPreemptible: "true"}}}, + node: &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{LabelPreemptible: "true"}}}, expected: true, }, { diff --git a/internal/services/providers/kops/kops.go b/internal/services/providers/kops/kops.go new file mode 100644 index 00000000..e4a936e7 --- /dev/null +++ b/internal/services/providers/kops/kops.go @@ -0,0 +1,327 @@ +package kops + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + + "castai-agent/internal/castai" + "castai-agent/internal/config" + awsclient "castai-agent/internal/services/providers/eks/client" + "castai-agent/internal/services/providers/gke" + "castai-agent/internal/services/providers/types" + "castai-agent/pkg/labels" +) + +const Name = "kops" + +func New(log logrus.FieldLogger, clientset kubernetes.Interface) (types.Provider, error) { + return &Provider{ + log: log, + clientset: clientset, + }, nil +} + +type Provider struct { + log logrus.FieldLogger + clientset kubernetes.Interface + awsClient awsclient.Client + csp string +} + +func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (*types.ClusterRegistration, error) { + ns, err := p.clientset.CoreV1().Namespaces().Get(ctx, metav1.NamespaceSystem, metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("getting namespace %q: %w", metav1.NamespaceSystem, err) + } + + clusterID, err := getClusterID(ns) + if err != nil { + return nil, fmt.Errorf("getting cluster id: %w", err) + } + + var csp, region, clusterName, stateStore string + + if cfg := config.Get().KOPS; cfg != nil { + csp, region, clusterName, stateStore = cfg.CSP, cfg.Region, cfg.ClusterName, cfg.StateStore + } else { + c, r, err := p.getCSPAndRegion(ctx, "") + if err != nil { + return nil, fmt.Errorf("getting csp and region: %w", err) + } + + n, s, err := p.getClusterNameAndStateStore(ns) + if err != nil { + return nil, fmt.Errorf("getting cluster name and state store: %w", err) + } + + csp, region, clusterName, stateStore = *c, *r, *n, *s + } + + p.log.Infof( + "discovered kops cluster properties csp=%s region=%s cluster_name=%s state_store=%s", + csp, + region, + clusterName, + stateStore, + ) + + p.csp = csp + + if p.csp == "aws" { + opts := []awsclient.Opt{ + awsclient.WithMetadata("", region, clusterName), + awsclient.WithEC2Client(), + } + c, err := awsclient.New(ctx, p.log, opts...) + if err != nil { + p.log.Errorf( + "failed initializing aws client, spot functionality for savings estimation will be reduced: %v", + err, + ) + } else { + p.awsClient = c + } + } + + resp, err := client.RegisterCluster(ctx, &castai.RegisterClusterRequest{ + ID: *clusterID, + Name: clusterName, + KOPS: &castai.KOPSParams{ + CSP: csp, + Region: region, + ClusterName: clusterName, + StateStore: stateStore, + }, + }) + if err != nil { + return nil, err + } + + return &types.ClusterRegistration{ + ClusterID: resp.ID, + OrganizationID: resp.OrganizationID, + }, nil +} + +func (p *Provider) IsSpot(ctx context.Context, node *v1.Node) (bool, error) { + if val, ok := node.Labels[labels.Spot]; ok && val == "true" { + return true, nil + } + + if p.csp == "aws" && p.awsClient != nil { + hostname, ok := node.Labels[v1.LabelHostname] + if !ok { + return false, fmt.Errorf("label %s not found on node %s", v1.LabelHostname, node.Name) + } + + instances, err := p.awsClient.GetInstancesByPrivateDNS(ctx, []string{hostname}) + if err != nil { + return false, fmt.Errorf("getting instances by hostname: %w", err) + } + + for _, instance := range instances { + if instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot" { + return true, nil + } + } + } + + if p.csp == "gcp" { + if val, ok := node.Labels[gke.LabelPreemptible]; ok && val == "true" { + return true, nil + } + } + + return false, nil +} + +func (p *Provider) Name() string { + return Name +} + +// getClusterNameAndStateStore discovers the cluster name and kOps state store bucket from the kube-system namespace +// annotation. kOps annotates the kube-system namespace with annotations such as this: +// * addons.k8s.io/core.addons.k8s.io: '{"version":"1.4.0","channel":"s3://bucket/cluster-name/addons/bootstrap-channel.yaml","manifestHash":"hash"}' +// We can retrieve the state store bucket name and the cluster name from the "channel" property of the annotation value. +func (p *Provider) getClusterNameAndStateStore(ns *v1.Namespace) (clusterName, stateStore *string, reterr error) { + for k, v := range ns.Annotations { + manifest, ok := kopsAddonAnnotation(p.log, k, v) + if !ok { + continue + } + + path := manifest.Channel.Path + if path[0] == '/' { + path = path[1:] + } + + name := strings.Split(path, "/")[0] + store := fmt.Sprintf("%s://%s", manifest.Channel.Scheme, manifest.Channel.Host) + + return &name, &store, nil + } + + return nil, nil, errors.New("failed discovering cluster properties: cluster name, state store") +} + +// getCSPAndRegion discovers the cluster cloud service provider (CSP) and the region the cluster is deployed in by +// listing the cluster nodes and inspecting their labels. CSP is retrieved by parsing the Node.Spec.ProviderID property. +// Whereas the region is read from the well-known node region labels. +func (p *Provider) getCSPAndRegion(ctx context.Context, next string) (csp, region *string, reterr error) { + nodes, err := p.clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{Limit: 10, Continue: next}) + if err != nil { + return nil, nil, fmt.Errorf("listing nodes: %w", err) + } + + for i, n := range nodes.Items { + ready := false + + for _, cond := range n.Status.Conditions { + if cond.Type == v1.NodeReady && cond.Status == v1.ConditionTrue { + ready = true + break + } + } + + if !ready { + continue + } + + nodeCSP, ok := getCSP(&nodes.Items[i]) + if ok { + csp = &nodeCSP + } + + nodeRegion, ok := getRegion(&nodes.Items[i]) + if ok { + region = &nodeRegion + } + + if csp != nil && region != nil { + return csp, region, nil + } + } + + if nodes.Continue != "" { + return p.getCSPAndRegion(ctx, nodes.Continue) + } + + var properties []string + if csp == nil { + properties = append(properties, "csp") + } + if region == nil { + properties = append(properties, "region") + } + + return nil, nil, fmt.Errorf("failed discovering properties: %s", strings.Join(properties, ", ")) +} + +type kopsAddonManifest struct { + Version string + Channel url.URL + ID string + ManifestHash string +} + +func kopsAddonAnnotation(log logrus.FieldLogger, k, v string) (*kopsAddonManifest, bool) { + if !strings.HasPrefix(k, "addons.k8s.io/") { + return nil, false + } + + manifest := map[string]interface{}{} + if err := json.Unmarshal([]byte(v), &manifest); err != nil { + log.Debugf("failed unmarshalling %q namespace annotation %q value %s: %v", metav1.NamespaceSystem, k, v, err) + return nil, false + } + + channel := manifest["channel"].(string) + + if len(channel) == 0 { + log.Debugf(`%q namespace annotation %q value %s does not have the "channel" property`, metav1.NamespaceSystem, k, v) + return nil, false + } + + uri, err := url.Parse(channel) + if err != nil { + log.Debugf("%q namespace annotation %q channel value %s is not a valid uri: %v", metav1.NamespaceSystem, k, channel, err) + return nil, false + } + + if len(uri.Scheme) == 0 { + log.Debugf("%q namespace annotation %q channel scheme %s is empty", metav1.NamespaceSystem, k, channel) + return nil, false + } + + if len(uri.Host) == 0 { + log.Debugf("%q namespace annotation %q channel host %s is empty", metav1.NamespaceSystem, k, channel) + return nil, false + } + + if len(uri.Path) == 0 { + log.Debugf("%q namespace annotation %q channel path %s is empty", metav1.NamespaceSystem, k, channel) + return nil, false + } + + var version, id, hash string + if val, ok := manifest["version"]; ok { + version = val.(string) + } + if val, ok := manifest["id"]; ok { + id = val.(string) + } + if val, ok := manifest["manifestHash"]; ok { + hash = val.(string) + } + + return &kopsAddonManifest{ + Channel: *uri, + Version: version, + ID: id, + ManifestHash: hash, + }, true +} + +func getClusterID(ns *v1.Namespace) (*uuid.UUID, error) { + clusterID, err := uuid.Parse(string(ns.UID)) + if err != nil { + return nil, fmt.Errorf("parsing namespace %q uid: %w", metav1.NamespaceSystem, err) + } + return &clusterID, nil +} + +func getRegion(n *v1.Node) (string, bool) { + if val, ok := n.Labels[v1.LabelTopologyRegion]; ok { + return val, true + } + + if val, ok := n.Labels[v1.LabelFailureDomainBetaRegion]; ok { + return val, true + } + + return "", false +} + +func getCSP(n *v1.Node) (string, bool) { + providerID := n.Spec.ProviderID + + if strings.HasPrefix(providerID, "gce://") { + return "gcp", true + } + + if strings.HasPrefix(providerID, "aws://") { + return "aws", true + } + + return "", false +} diff --git a/internal/services/providers/kops/kops_test.go b/internal/services/providers/kops/kops_test.go new file mode 100644 index 00000000..0a6fd1ae --- /dev/null +++ b/internal/services/providers/kops/kops_test.go @@ -0,0 +1,260 @@ +package kops + +import ( + "context" + "os" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + k8stypes "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/utils/pointer" + + "castai-agent/internal/castai" + mock_castai "castai-agent/internal/castai/mock" + "castai-agent/internal/config" + mock_awsclient "castai-agent/internal/services/providers/eks/client/mock" + "castai-agent/internal/services/providers/gke" + "castai-agent/internal/services/providers/types" + "castai-agent/pkg/labels" +) + +func TestProvider_RegisterCluster(t *testing.T) { + t.Run("autodiscover cluster properties", func(t *testing.T) { + require.NoError(t, os.Setenv("API_KEY", "123")) + require.NoError(t, os.Setenv("API_URL", "test")) + + t.Cleanup(config.Reset) + t.Cleanup(os.Clearenv) + + var objects []runtime.Object + + namespaceID := uuid.New() + + objects = append(objects, &v1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + UID: k8stypes.UID(namespaceID.String()), + Name: metav1.NamespaceSystem, + Annotations: map[string]string{ + "addons.k8s.io/core.addons.k8s.io": `{"version":"1.4.0","channel":"s3://test-kops/test.k8s.local/addons/bootstrap-channel.yaml","manifestHash":"3ffe9ac576f9eec72e2bdfbd2ea17d56d9b17b90"}`, + }, + }, + }) + + // Simulate a large cluster with broken nodes. + for i := 0; i < 100; i++ { + objects = append(objects, &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "broken-" + strconv.Itoa(i), + Labels: map[string]string{}, + }, + }) + } + + objects = append(objects, &v1.Node{ + TypeMeta: metav1.TypeMeta{}, + ObjectMeta: metav1.ObjectMeta{ + Name: "normal", + Labels: map[string]string{ + v1.LabelTopologyRegion: "us-east-1", + }, + }, + Spec: v1.NodeSpec{ + ProviderID: "aws://us-east-1a/i-abcdefgh", + }, + Status: v1.NodeStatus{ + Conditions: []v1.NodeCondition{ + { + Type: v1.NodeReady, + Status: v1.ConditionTrue, + }, + }, + }, + }) + + clientset := fake.NewSimpleClientset(objects...) + + p, err := New(logrus.New(), clientset) + require.NoError(t, err) + + castaiclient := mock_castai.NewMockClient(gomock.NewController(t)) + + registrationResp := &types.ClusterRegistration{ + ClusterID: namespaceID.String(), + OrganizationID: uuid.New().String(), + } + + castaiclient.EXPECT().RegisterCluster(gomock.Any(), &castai.RegisterClusterRequest{ + ID: namespaceID, + Name: "test.k8s.local", + KOPS: &castai.KOPSParams{ + CSP: "aws", + Region: "us-east-1", + ClusterName: "test.k8s.local", + StateStore: "s3://test-kops", + }, + }).Return(&castai.RegisterClusterResponse{Cluster: castai.Cluster{ + ID: registrationResp.ClusterID, + OrganizationID: registrationResp.OrganizationID, + }}, nil) + + got, err := p.RegisterCluster(context.Background(), castaiclient) + + require.NoError(t, err) + require.Equal(t, registrationResp, got) + }) + + t.Run("override properties from config", func(t *testing.T) { + require.NoError(t, os.Setenv("API_KEY", "123")) + require.NoError(t, os.Setenv("API_URL", "test")) + require.NoError(t, os.Setenv("KOPS_CSP", "aws")) + require.NoError(t, os.Setenv("KOPS_REGION", "us-east-1")) + require.NoError(t, os.Setenv("KOPS_CLUSTER_NAME", "test.k8s.local")) + require.NoError(t, os.Setenv("KOPS_STATE_STORE", "s3://test-kops")) + + t.Cleanup(config.Reset) + t.Cleanup(os.Clearenv) + + namespaceID := uuid.New() + namespace := &v1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + UID: k8stypes.UID(namespaceID.String()), + Name: metav1.NamespaceSystem, + Annotations: map[string]string{ + "addons.k8s.io/core.addons.k8s.io": `{"version":"1.4.0","channel":"s3://test-kops/test.k8s.local/addons/bootstrap-channel.yaml","manifestHash":"3ffe9ac576f9eec72e2bdfbd2ea17d56d9b17b90"}`, + }, + }, + } + + p, err := New(logrus.New(), fake.NewSimpleClientset(namespace)) + require.NoError(t, err) + + castaiclient := mock_castai.NewMockClient(gomock.NewController(t)) + + registrationResp := &types.ClusterRegistration{ + ClusterID: namespaceID.String(), + OrganizationID: uuid.New().String(), + } + + castaiclient.EXPECT().RegisterCluster(gomock.Any(), &castai.RegisterClusterRequest{ + ID: namespaceID, + Name: "test.k8s.local", + KOPS: &castai.KOPSParams{ + CSP: "aws", + Region: "us-east-1", + ClusterName: "test.k8s.local", + StateStore: "s3://test-kops", + }, + }).Return(&castai.RegisterClusterResponse{Cluster: castai.Cluster{ + ID: registrationResp.ClusterID, + OrganizationID: registrationResp.OrganizationID, + }}, nil) + + got, err := p.RegisterCluster(context.Background(), castaiclient) + + require.NoError(t, err) + require.Equal(t, registrationResp, got) + }) +} + +func TestProvider_IsSpot(t *testing.T) { + t.Run("castai managed spot nodes", func(t *testing.T) { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + labels.Spot: "true", + }, + }, + } + + p := &Provider{} + + got, err := p.IsSpot(context.Background(), node) + + require.NoError(t, err) + require.True(t, got) + }) + + t.Run("aws spot nodes", func(t *testing.T) { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + v1.LabelHostname: "hostname", + }, + }, + } + + awsclient := mock_awsclient.NewMockClient(gomock.NewController(t)) + + p := &Provider{ + csp: "aws", + awsClient: awsclient, + } + + awsclient.EXPECT().GetInstancesByPrivateDNS(gomock.Any(), []string{"hostname"}).Return([]*ec2.Instance{ + { + InstanceLifecycle: pointer.StringPtr("spot"), + }, + }, nil) + + got, err := p.IsSpot(context.Background(), node) + + require.NoError(t, err) + require.True(t, got) + }) + + t.Run("gcp spot nodes", func(t *testing.T) { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + gke.LabelPreemptible: "true", + }, + }, + } + + p := &Provider{ + csp: "gcp", + } + + got, err := p.IsSpot(context.Background(), node) + + require.NoError(t, err) + require.True(t, got) + }) + + t.Run("non spot node", func(t *testing.T) { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + v1.LabelHostname: "hostname", + }, + }, + } + + awsclient := mock_awsclient.NewMockClient(gomock.NewController(t)) + + p := &Provider{ + csp: "aws", + awsClient: awsclient, + } + + awsclient.EXPECT().GetInstancesByPrivateDNS(gomock.Any(), []string{"hostname"}).Return([]*ec2.Instance{ + { + InstanceLifecycle: pointer.StringPtr("on-demand"), + }, + }, nil) + + got, err := p.IsSpot(context.Background(), node) + + require.NoError(t, err) + require.False(t, got) + }) +} diff --git a/internal/services/providers/providers.go b/internal/services/providers/providers.go index b5e92ac6..e005054a 100644 --- a/internal/services/providers/providers.go +++ b/internal/services/providers/providers.go @@ -5,15 +5,17 @@ import ( "fmt" "github.com/sirupsen/logrus" + "k8s.io/client-go/kubernetes" "castai-agent/internal/config" "castai-agent/internal/services/providers/castai" "castai-agent/internal/services/providers/eks" "castai-agent/internal/services/providers/gke" + "castai-agent/internal/services/providers/kops" "castai-agent/internal/services/providers/types" ) -func GetProvider(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) { +func GetProvider(ctx context.Context, log logrus.FieldLogger, clientset kubernetes.Interface) (types.Provider, error) { cfg := config.Get() if cfg.Provider == castai.Name || cfg.CASTAI != nil { @@ -25,7 +27,11 @@ func GetProvider(ctx context.Context, log logrus.FieldLogger) (types.Provider, e } if cfg.Provider == gke.Name || cfg.GKE != nil { - return gke.New(ctx, log.WithField("provider", gke.Name)) + return gke.New(log.WithField("provider", gke.Name)) + } + + if cfg.Provider == kops.Name || cfg.KOPS != nil { + return kops.New(log.WithField("provider", kops.Name), clientset) } return nil, fmt.Errorf("unknown provider %q", cfg.Provider) diff --git a/internal/services/providers/providers_test.go b/internal/services/providers/providers_test.go index 579daf00..1f6e20c0 100644 --- a/internal/services/providers/providers_test.go +++ b/internal/services/providers/providers_test.go @@ -12,6 +12,7 @@ import ( "castai-agent/internal/services/providers/castai" "castai-agent/internal/services/providers/eks" "castai-agent/internal/services/providers/gke" + "castai-agent/internal/services/providers/kops" ) func TestGetProvider(t *testing.T) { @@ -23,7 +24,7 @@ func TestGetProvider(t *testing.T) { require.NoError(t, os.Setenv("API_URL", "test")) require.NoError(t, os.Setenv("PROVIDER", "castai")) - got, err := GetProvider(context.Background(), logrus.New()) + got, err := GetProvider(context.Background(), logrus.New(), nil) require.NoError(t, err) require.IsType(t, &castai.Provider{}, got) @@ -40,7 +41,7 @@ func TestGetProvider(t *testing.T) { require.NoError(t, os.Setenv("EKS_ACCOUNT_ID", "accountID")) require.NoError(t, os.Setenv("EKS_REGION", "eu-central-1")) - got, err := GetProvider(context.Background(), logrus.New()) + got, err := GetProvider(context.Background(), logrus.New(), nil) require.NoError(t, err) require.IsType(t, &eks.Provider{}, got) @@ -57,9 +58,23 @@ func TestGetProvider(t *testing.T) { require.NoError(t, os.Setenv("GKE_PROJECT_ID", "projectID")) require.NoError(t, os.Setenv("GKE_REGION", "us-east4")) - got, err := GetProvider(context.Background(), logrus.New()) + got, err := GetProvider(context.Background(), logrus.New(), nil) require.NoError(t, err) require.IsType(t, &gke.Provider{}, got) }) + + t.Run("should return kops", func(t *testing.T) { + t.Cleanup(config.Reset) + t.Cleanup(os.Clearenv) + + require.NoError(t, os.Setenv("API_KEY", "api-key")) + require.NoError(t, os.Setenv("API_URL", "test")) + require.NoError(t, os.Setenv("PROVIDER", "kops")) + + got, err := GetProvider(context.Background(), logrus.New(), nil) + + require.NoError(t, err) + require.IsType(t, &kops.Provider{}, got) + }) } diff --git a/main.go b/main.go index 300a0dd6..579c1817 100644 --- a/main.go +++ b/main.go @@ -70,7 +70,17 @@ func run(ctx context.Context, log logrus.FieldLogger) (reterr error) { log = log.WithFields(fields) log.Infof("running agent version: %v", agentVersion) - provider, err := providers.GetProvider(ctx, log) + restconfig, err := retrieveKubeConfig(log) + if err != nil { + return err + } + + clientset, err := kubernetes.NewForConfig(restconfig) + if err != nil { + return err + } + + provider, err := providers.GetProvider(ctx, log, clientset) if err != nil { return fmt.Errorf("getting provider: %w", err) } @@ -90,16 +100,6 @@ func run(ctx context.Context, log logrus.FieldLogger) (reterr error) { log = log.WithFields(fields) log.Infof("cluster registered: %v", reg) - restconfig, err := retrieveKubeConfig(log) - if err != nil { - return err - } - - clientset, err := kubernetes.NewForConfig(restconfig) - if err != nil { - return err - } - wait.Until(func() { v, err := version.Get(log, clientset) if err != nil {