Skip to content

Commit

Permalink
Merge pull request #140 from castai/eks-node-lifecycle-discovery-config
Browse files Browse the repository at this point in the history
[castai-agent] feat: conditional EKS node lifecycle discovery and default lifecycle
  • Loading branch information
laimonasr authored Oct 20, 2023
2 parents 680eecb + a496730 commit 8abd2de
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 52 deletions.
13 changes: 9 additions & 4 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"sync"
"time"

"github.com/samber/lo"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
Expand Down Expand Up @@ -64,10 +65,11 @@ type API struct {
}

type EKS struct {
AccountID string `mapstructure:"account_id"`
Region string `mapstructure:"region"`
ClusterName string `mapstructure:"cluster_name"`
APITimeout time.Duration `mapstructure:"api_timeout"`
AccountID string `mapstructure:"account_id"`
Region string `mapstructure:"region"`
ClusterName string `mapstructure:"cluster_name"`
APITimeout time.Duration `mapstructure:"api_timeout"`
APINodeLifecycleDiscoveryEnabled *bool `mapstructure:"api_node_lifecycle_discovery_enabled"`
}

type GKE struct {
Expand Down Expand Up @@ -193,6 +195,9 @@ func Get() Config {
if cfg.EKS.APITimeout <= 0 {
cfg.EKS.APITimeout = 120 * time.Second
}
if cfg.EKS.APINodeLifecycleDiscoveryEnabled == nil {
cfg.EKS.APINodeLifecycleDiscoveryEnabled = lo.ToPtr(true)
}
}

if cfg.KOPS != nil {
Expand Down
40 changes: 22 additions & 18 deletions internal/services/providers/eks/eks.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const (
)

// New configures and returns an EKS provider.
func New(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) {
func New(ctx context.Context, log logrus.FieldLogger, apiNodeLifecycleDiscoveryEnabled bool) (types.Provider, error) {
var opts []client.Opt

if cfg := config.Get().EKS; cfg != nil {
Expand All @@ -50,16 +50,18 @@ func New(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) {
}

return &Provider{
log: log,
awsClient: awsClient,
spotCache: map[string]bool{},
log: log,
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: apiNodeLifecycleDiscoveryEnabled,
spotCache: map[string]bool{},
}, nil
}

// Provider is the EKS implementation of the providers.Provider interface.
type Provider struct {
log logrus.FieldLogger
awsClient client.Client
log logrus.FieldLogger
awsClient client.Client
apiNodeLifecycleDiscoveryEnabled bool

spotCache map[string]bool
}
Expand Down Expand Up @@ -128,22 +130,24 @@ func (p *Provider) FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node
nodesByInstanceID[instanceID] = node
}

if len(instanceIDs) > 0 {
instances, err := p.awsClient.GetInstancesByInstanceIDs(ctx, instanceIDs)
if err != nil {
return nil, fmt.Errorf("getting instances by instance IDs: %w", err)
}
if len(instanceIDs) == 0 || !p.apiNodeLifecycleDiscoveryEnabled {
return ret, nil
}

for _, instance := range instances {
isSpot := instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot"
instanceID := *instance.InstanceId
instances, err := p.awsClient.GetInstancesByInstanceIDs(ctx, instanceIDs)
if err != nil {
return nil, fmt.Errorf("getting instances by instance IDs: %w", err)
}

if isSpot {
ret = append(ret, nodesByInstanceID[instanceID])
}
for _, instance := range instances {
isSpot := instance.InstanceLifecycle != nil && *instance.InstanceLifecycle == "spot"
instanceID := *instance.InstanceId

p.spotCache[instanceID] = isSpot
if isSpot {
ret = append(ret, nodesByInstanceID[instanceID])
}

p.spotCache[instanceID] = isSpot
}

return ret, nil
Expand Down
94 changes: 65 additions & 29 deletions internal/services/providers/eks/eks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
)

func TestProvider_RegisterCluster(t *testing.T) {
r := require.New(t)
ctx := context.Background()
mockctrl := gomock.NewController(t)
awsClient := mock_client.NewMockClient(mockctrl)
Expand Down Expand Up @@ -56,18 +57,20 @@ func TestProvider_RegisterCluster(t *testing.T) {

got, err := p.RegisterCluster(ctx, castClient)

require.NoError(t, err)
require.Equal(t, expected, got)
r.NoError(err)
r.Equal(expected, got)
}

func TestProvider_IsSpot(t *testing.T) {
t.Run("spot instance capacity label", func(t *testing.T) {
r := require.New(t)
awsClient := mock_client.NewMockClient(gomock.NewController(t))

p := &Provider{
log: logrus.New(),
awsClient: awsClient,
spotCache: map[string]bool{},
log: logrus.New(),
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: true,
spotCache: map[string]bool{},
}

node := &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{
Expand All @@ -76,17 +79,19 @@ func TestProvider_IsSpot(t *testing.T) {

got, err := p.FilterSpot(context.Background(), []*v1.Node{node})

require.NoError(t, err)
require.Equal(t, []*v1.Node{node}, got)
r.NoError(err)
r.Equal([]*v1.Node{node}, got)
})

t.Run("spot instance CAST AI label", func(t *testing.T) {
r := require.New(t)
awsClient := mock_client.NewMockClient(gomock.NewController(t))

p := &Provider{
log: logrus.New(),
awsClient: awsClient,
spotCache: map[string]bool{},
log: logrus.New(),
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: true,
spotCache: map[string]bool{},
}

node := &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{
Expand All @@ -95,17 +100,19 @@ func TestProvider_IsSpot(t *testing.T) {

got, err := p.FilterSpot(context.Background(), []*v1.Node{node})

require.NoError(t, err)
require.Equal(t, []*v1.Node{node}, got)
r.NoError(err)
r.Equal([]*v1.Node{node}, got)
})

t.Run("spot instance lifecycle response", func(t *testing.T) {
r := require.New(t)
awsClient := mock_client.NewMockClient(gomock.NewController(t))

p := &Provider{
log: logrus.New(),
awsClient: awsClient,
spotCache: map[string]bool{},
log: logrus.New(),
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: true,
spotCache: map[string]bool{},
}

awsClient.EXPECT().GetInstancesByInstanceIDs(gomock.Any(), []string{"instanceID"}).Return([]*ec2.Instance{
Expand All @@ -123,22 +130,24 @@ func TestProvider_IsSpot(t *testing.T) {

got, err := p.FilterSpot(context.Background(), []*v1.Node{node})

require.NoError(t, err)
require.Equal(t, []*v1.Node{node}, got)
r.NoError(err)
r.Equal([]*v1.Node{node}, got)

got, err = p.FilterSpot(context.Background(), []*v1.Node{node})

require.NoError(t, err)
require.Equal(t, []*v1.Node{node}, got)
r.NoError(err)
r.Equal([]*v1.Node{node}, got)
})

t.Run("on-demand instance", func(t *testing.T) {
r := require.New(t)
awsClient := mock_client.NewMockClient(gomock.NewController(t))

p := &Provider{
log: logrus.New(),
awsClient: awsClient,
spotCache: map[string]bool{},
log: logrus.New(),
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: true,
spotCache: map[string]bool{},
}

awsClient.EXPECT().GetInstancesByInstanceIDs(gomock.Any(), []string{"instanceID"}).Return([]*ec2.Instance{
Expand All @@ -156,17 +165,19 @@ func TestProvider_IsSpot(t *testing.T) {

got, err := p.FilterSpot(context.Background(), []*v1.Node{node})

require.NoError(t, err)
require.Empty(t, got)
r.NoError(err)
r.Empty(got)
})

t.Run("should not perform call out to AWS API if node types can be determined using labels", func(t *testing.T) {
r := require.New(t)
awsClient := mock_client.NewMockClient(gomock.NewController(t))

p := &Provider{
log: logrus.New(),
awsClient: awsClient,
spotCache: map[string]bool{},
log: logrus.New(),
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: true,
spotCache: map[string]bool{},
}

nodeCastaiSpot := &v1.Node{ObjectMeta: metav1.ObjectMeta{Labels: map[string]string{
Expand All @@ -192,7 +203,32 @@ func TestProvider_IsSpot(t *testing.T) {

got, err := p.FilterSpot(context.Background(), []*v1.Node{nodeCastaiSpot, nodeCastaiSpotFallback, nodeKarpenterSpot, nodeKarpenterOnDemand, nodeEKSSpot, nodeEKSOnDemand})

require.NoError(t, err)
require.Equal(t, []*v1.Node{nodeCastaiSpot, nodeKarpenterSpot, nodeEKSSpot}, got)
r.NoError(err)
r.Equal([]*v1.Node{nodeCastaiSpot, nodeKarpenterSpot, nodeEKSSpot}, got)
})

t.Run("should consider on-demand node lifecycle when node lifecycle could not be discovered using labels and API lifecycle discovery is disabled", func(t *testing.T) {
r := require.New(t)
awsClient := mock_client.NewMockClient(gomock.NewController(t))

p := &Provider{
log: logrus.New(),
awsClient: awsClient,
apiNodeLifecycleDiscoveryEnabled: false,
spotCache: map[string]bool{},
}

awsClient.EXPECT().GetInstancesByInstanceIDs(gomock.Any(), gomock.Any()).Times(0)

node := &v1.Node{
Spec: v1.NodeSpec{
ProviderID: "aws:///eu-west-1a/instanceID",
},
}

got, err := p.FilterSpot(context.Background(), []*v1.Node{node})

r.NoError(err)
r.Empty(got)
})
}
8 changes: 7 additions & 1 deletion internal/services/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ func GetProvider(ctx context.Context, log logrus.FieldLogger, discoveryService d
cfg := config.Get()

if cfg.Provider == eks.Name || cfg.EKS != nil {
return eks.New(ctx, log.WithField("provider", eks.Name))
eksProviderLogger := log.WithField("provider", eks.Name)
apiNodeLifecycleDiscoveryEnabled := *cfg.EKS.APINodeLifecycleDiscoveryEnabled
if !apiNodeLifecycleDiscoveryEnabled {
eksProviderLogger.Info("node lifecycle discovery through AWS API is disabled - all nodes without spot labels will be considered on-demand")
}

return eks.New(ctx, eksProviderLogger, apiNodeLifecycleDiscoveryEnabled)
}

if cfg.Provider == gke.Name || cfg.GKE != nil {
Expand Down
8 changes: 8 additions & 0 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package node

type NodeLifecycle string

const (
NodeLifecycleOnDemand NodeLifecycle = "on-demand"
NodeLifecycleSpot NodeLifecycle = "spot"
)

0 comments on commit 8abd2de

Please sign in to comment.