From a496730bca30d62ed575d460faafa19df6890303 Mon Sep 17 00:00:00 2001 From: Laimonas Rastenis Date: Fri, 20 Oct 2023 15:22:16 +0300 Subject: [PATCH] [castai-agent] feat: conditional EKS node lifecycle discovery and default lifecycle --- internal/config/config.go | 13 ++- internal/services/providers/eks/eks.go | 40 +++++---- internal/services/providers/eks/eks_test.go | 94 ++++++++++++++------- internal/services/providers/providers.go | 8 +- pkg/node/node.go | 8 ++ 5 files changed, 111 insertions(+), 52 deletions(-) create mode 100644 pkg/node/node.go diff --git a/internal/config/config.go b/internal/config/config.go index 9c845b9b..b037199b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/samber/lo" "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -64,10 +65,11 @@ type API struct { } type EKS struct { - AccountID string `mapstructure:"account_id"` - Region string `mapstructure:"region"` - ClusterName string `mapstructure:"cluster_name"` - APITimeout time.Duration `mapstructure:"api_timeout"` + AccountID string `mapstructure:"account_id"` + Region string `mapstructure:"region"` + ClusterName string `mapstructure:"cluster_name"` + APITimeout time.Duration `mapstructure:"api_timeout"` + APINodeLifecycleDiscoveryEnabled *bool `mapstructure:"api_node_lifecycle_discovery_enabled"` } type GKE struct { @@ -193,6 +195,9 @@ func Get() Config { if cfg.EKS.APITimeout <= 0 { cfg.EKS.APITimeout = 120 * time.Second } + if cfg.EKS.APINodeLifecycleDiscoveryEnabled == nil { + cfg.EKS.APINodeLifecycleDiscoveryEnabled = lo.ToPtr(true) + } } if cfg.KOPS != nil { diff --git a/internal/services/providers/eks/eks.go b/internal/services/providers/eks/eks.go index d45e5376..0ede83ba 100644 --- a/internal/services/providers/eks/eks.go +++ b/internal/services/providers/eks/eks.go @@ -32,7 +32,7 @@ const ( ) // New configures and returns an EKS provider. -func New(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) { +func New(ctx context.Context, log logrus.FieldLogger, apiNodeLifecycleDiscoveryEnabled bool) (types.Provider, error) { var opts []client.Opt if cfg := config.Get().EKS; cfg != nil { @@ -50,16 +50,18 @@ func New(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) { } return &Provider{ - log: log, - awsClient: awsClient, - spotCache: map[string]bool{}, + log: log, + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: apiNodeLifecycleDiscoveryEnabled, + spotCache: map[string]bool{}, }, nil } // Provider is the EKS implementation of the providers.Provider interface. type Provider struct { - log logrus.FieldLogger - awsClient client.Client + log logrus.FieldLogger + awsClient client.Client + apiNodeLifecycleDiscoveryEnabled bool spotCache map[string]bool } @@ -128,22 +130,24 @@ func (p *Provider) FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node nodesByInstanceID[instanceID] = node } - if len(instanceIDs) > 0 { - instances, err := p.awsClient.GetInstancesByInstanceIDs(ctx, instanceIDs) - if err != nil { - return nil, fmt.Errorf("getting instances by instance IDs: %w", err) - } + if len(instanceIDs) == 0 || !p.apiNodeLifecycleDiscoveryEnabled { + return ret, nil + } - for _, instance := range instances { - isSpot := instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot" - instanceID := *instance.InstanceId + instances, err := p.awsClient.GetInstancesByInstanceIDs(ctx, instanceIDs) + if err != nil { + return nil, fmt.Errorf("getting instances by instance IDs: %w", err) + } - if isSpot { - ret = append(ret, nodesByInstanceID[instanceID]) - } + for _, instance := range instances { + isSpot := instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot" + instanceID := *instance.InstanceId - p.spotCache[instanceID] = isSpot + if isSpot { + ret = append(ret, nodesByInstanceID[instanceID]) } + + p.spotCache[instanceID] = isSpot } return ret, nil diff --git a/internal/services/providers/eks/eks_test.go b/internal/services/providers/eks/eks_test.go index a285bc9d..96fd3c1e 100644 --- a/internal/services/providers/eks/eks_test.go +++ b/internal/services/providers/eks/eks_test.go @@ -21,6 +21,7 @@ import ( ) func TestProvider_RegisterCluster(t *testing.T) { + r := require.New(t) ctx := context.Background() mockctrl := gomock.NewController(t) awsClient := mock_client.NewMockClient(mockctrl) @@ -56,18 +57,20 @@ func TestProvider_RegisterCluster(t *testing.T) { got, err := p.RegisterCluster(ctx, castClient) - require.NoError(t, err) - require.Equal(t, expected, got) + r.NoError(err) + r.Equal(expected, got) } func TestProvider_IsSpot(t *testing.T) { t.Run("spot instance capacity label", func(t *testing.T) { + r := require.New(t) awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ - log: logrus.New(), - awsClient: awsClient, - spotCache: map[string]bool{}, + log: logrus.New(), + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: true, + spotCache: map[string]bool{}, } node := &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ @@ -76,17 +79,19 @@ func TestProvider_IsSpot(t *testing.T) { got, err := p.FilterSpot(context.Background(), []*v1.Node{node}) - require.NoError(t, err) - require.Equal(t, []*v1.Node{node}, got) + r.NoError(err) + r.Equal([]*v1.Node{node}, got) }) t.Run("spot instance CAST AI label", func(t *testing.T) { + r := require.New(t) awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ - log: logrus.New(), - awsClient: awsClient, - spotCache: map[string]bool{}, + log: logrus.New(), + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: true, + spotCache: map[string]bool{}, } node := &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ @@ -95,17 +100,19 @@ func TestProvider_IsSpot(t *testing.T) { got, err := p.FilterSpot(context.Background(), []*v1.Node{node}) - require.NoError(t, err) - require.Equal(t, []*v1.Node{node}, got) + r.NoError(err) + r.Equal([]*v1.Node{node}, got) }) t.Run("spot instance lifecycle response", func(t *testing.T) { + r := require.New(t) awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ - log: logrus.New(), - awsClient: awsClient, - spotCache: map[string]bool{}, + log: logrus.New(), + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: true, + spotCache: map[string]bool{}, } awsClient.EXPECT().GetInstancesByInstanceIDs(gomock.Any(), []string{"instanceID"}).Return([]*ec2.Instance{ @@ -123,22 +130,24 @@ func TestProvider_IsSpot(t *testing.T) { got, err := p.FilterSpot(context.Background(), []*v1.Node{node}) - require.NoError(t, err) - require.Equal(t, []*v1.Node{node}, got) + r.NoError(err) + r.Equal([]*v1.Node{node}, got) got, err = p.FilterSpot(context.Background(), []*v1.Node{node}) - require.NoError(t, err) - require.Equal(t, []*v1.Node{node}, got) + r.NoError(err) + r.Equal([]*v1.Node{node}, got) }) t.Run("on-demand instance", func(t *testing.T) { + r := require.New(t) awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ - log: logrus.New(), - awsClient: awsClient, - spotCache: map[string]bool{}, + log: logrus.New(), + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: true, + spotCache: map[string]bool{}, } awsClient.EXPECT().GetInstancesByInstanceIDs(gomock.Any(), []string{"instanceID"}).Return([]*ec2.Instance{ @@ -156,17 +165,19 @@ func TestProvider_IsSpot(t *testing.T) { got, err := p.FilterSpot(context.Background(), []*v1.Node{node}) - require.NoError(t, err) - require.Empty(t, got) + r.NoError(err) + r.Empty(got) }) t.Run("should not perform call out to AWS API if node types can be determined using labels", func(t *testing.T) { + r := require.New(t) awsClient := mock_client.NewMockClient(gomock.NewController(t)) p := &Provider{ - log: logrus.New(), - awsClient: awsClient, - spotCache: map[string]bool{}, + log: logrus.New(), + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: true, + spotCache: map[string]bool{}, } nodeCastaiSpot := &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{ @@ -192,7 +203,32 @@ func TestProvider_IsSpot(t *testing.T) { got, err := p.FilterSpot(context.Background(), []*v1.Node{nodeCastaiSpot, nodeCastaiSpotFallback, nodeKarpenterSpot, nodeKarpenterOnDemand, nodeEKSSpot, nodeEKSOnDemand}) - require.NoError(t, err) - require.Equal(t, []*v1.Node{nodeCastaiSpot, nodeKarpenterSpot, nodeEKSSpot}, got) + r.NoError(err) + r.Equal([]*v1.Node{nodeCastaiSpot, nodeKarpenterSpot, nodeEKSSpot}, got) + }) + + t.Run("should consider on-demand node lifecycle when node lifecycle could not be discovered using labels and API lifecycle discovery is disabled", func(t *testing.T) { + r := require.New(t) + awsClient := mock_client.NewMockClient(gomock.NewController(t)) + + p := &Provider{ + log: logrus.New(), + awsClient: awsClient, + apiNodeLifecycleDiscoveryEnabled: false, + spotCache: map[string]bool{}, + } + + awsClient.EXPECT().GetInstancesByInstanceIDs(gomock.Any(), gomock.Any()).Times(0) + + node := &v1.Node{ + Spec: v1.NodeSpec{ + ProviderID: "aws:///eu-west-1a/instanceID", + }, + } + + got, err := p.FilterSpot(context.Background(), []*v1.Node{node}) + + r.NoError(err) + r.Empty(got) }) } diff --git a/internal/services/providers/providers.go b/internal/services/providers/providers.go index 32f3c6b0..42c41224 100644 --- a/internal/services/providers/providers.go +++ b/internal/services/providers/providers.go @@ -21,7 +21,13 @@ func GetProvider(ctx context.Context, log logrus.FieldLogger, discoveryService d cfg := config.Get() if cfg.Provider == eks.Name || cfg.EKS != nil { - return eks.New(ctx, log.WithField("provider", eks.Name)) + eksProviderLogger := log.WithField("provider", eks.Name) + apiNodeLifecycleDiscoveryEnabled := *cfg.EKS.APINodeLifecycleDiscoveryEnabled + if !apiNodeLifecycleDiscoveryEnabled { + eksProviderLogger.Info("node lifecycle discovery through AWS API is disabled - all nodes without spot labels will be considered on-demand") + } + + return eks.New(ctx, eksProviderLogger, apiNodeLifecycleDiscoveryEnabled) } if cfg.Provider == gke.Name || cfg.GKE != nil { diff --git a/pkg/node/node.go b/pkg/node/node.go new file mode 100644 index 00000000..8a0a6e52 --- /dev/null +++ b/pkg/node/node.go @@ -0,0 +1,8 @@ +package node + +type NodeLifecycle string + +const ( + NodeLifecycleOnDemand NodeLifecycle = "on-demand" + NodeLifecycleSpot NodeLifecycle = "spot" +)