diff --git a/internal/config/config.go b/internal/config/config.go index b037199b..5f299f45 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,6 +43,8 @@ type Mode string const ( ModeAgent Mode = "agent" ModeMonitor Mode = "monitor" + + DefaultAPINodeLifecycleDiscoveryEnabled = true ) type TLS struct { @@ -196,7 +198,7 @@ func Get() Config { cfg.EKS.APITimeout = 120 * time.Second } if cfg.EKS.APINodeLifecycleDiscoveryEnabled == nil { - cfg.EKS.APINodeLifecycleDiscoveryEnabled = lo.ToPtr(true) + cfg.EKS.APINodeLifecycleDiscoveryEnabled = lo.ToPtr(DefaultAPINodeLifecycleDiscoveryEnabled) } } diff --git a/internal/services/providers/providers.go b/internal/services/providers/providers.go index 42c41224..68398e18 100644 --- a/internal/services/providers/providers.go +++ b/internal/services/providers/providers.go @@ -22,7 +22,8 @@ func GetProvider(ctx context.Context, log logrus.FieldLogger, discoveryService d if cfg.Provider == eks.Name || cfg.EKS != nil { eksProviderLogger := log.WithField("provider", eks.Name) - apiNodeLifecycleDiscoveryEnabled := *cfg.EKS.APINodeLifecycleDiscoveryEnabled + apiNodeLifecycleDiscoveryEnabled := isAPINodeLifecycleDiscoveryEnabled(cfg) + if !apiNodeLifecycleDiscoveryEnabled { eksProviderLogger.Info("node lifecycle discovery through AWS API is disabled - all nodes without spot labels will be considered on-demand") } @@ -48,3 +49,11 @@ func GetProvider(ctx context.Context, log logrus.FieldLogger, discoveryService d return nil, fmt.Errorf("unknown provider %q", cfg.Provider) } + +func isAPINodeLifecycleDiscoveryEnabled(cfg config.Config) bool { + if cfg.EKS != nil && cfg.EKS.APINodeLifecycleDiscoveryEnabled != nil { + return *cfg.EKS.APINodeLifecycleDiscoveryEnabled + } + + return config.DefaultAPINodeLifecycleDiscoveryEnabled +} diff --git a/internal/services/providers/providers_test.go b/internal/services/providers/providers_test.go index 71808960..f7579c33 100644 --- a/internal/services/providers/providers_test.go +++ b/internal/services/providers/providers_test.go @@ -5,6 +5,7 @@ import ( "os" "testing" + "github.com/samber/lo" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" @@ -86,3 +87,39 @@ func TestGetProvider(t *testing.T) { r.IsType(&openshift.Provider{}, got) }) } + +func Test_isAPINodeLifecycleDiscoveryEnabled(t *testing.T) { + tests := map[string]struct { + cfg config.Config + want bool + }{ + "should use default node lifecycle discovery value when EKS config is nil": { + cfg: config.Config{}, + want: true, + }, + "should use default node lifecycle discovery value when EKS config is not nil and config value is nil": { + cfg: config.Config{ + EKS: &config.EKS{APINodeLifecycleDiscoveryEnabled: nil}, + }, + want: true, + }, + "should use node lifecycle discovery value from config when it is configured": { + cfg: config.Config{ + EKS: &config.EKS{ + APINodeLifecycleDiscoveryEnabled: lo.ToPtr(false), + }, + }, + want: false, + }, + } + + for testName, tt := range tests { + t.Run(testName, func(t *testing.T) { + r := require.New(t) + + got := isAPINodeLifecycleDiscoveryEnabled(tt.cfg) + + r.Equal(tt.want, got) + }) + } +}