From a07f29ff80e1735b2dadbe7c6200a1faad4522c4 Mon Sep 17 00:00:00 2001 From: hugoShaka Date: Wed, 20 Nov 2024 17:31:50 -0500 Subject: [PATCH] add failover trigger and versionGetter --- lib/automaticupgrades/maintenance/mock.go | 3 +- lib/automaticupgrades/maintenance/trigger.go | 52 +++++- .../maintenance/trigger_test.go | 150 ++++++++++++++++++ lib/automaticupgrades/version/versionget.go | 29 +++- .../version/versionget_test.go | 98 +++++++++++- 5 files changed, 326 insertions(+), 6 deletions(-) create mode 100644 lib/automaticupgrades/maintenance/trigger_test.go diff --git a/lib/automaticupgrades/maintenance/mock.go b/lib/automaticupgrades/maintenance/mock.go index f46b990ee7930..f705bcee71f8b 100644 --- a/lib/automaticupgrades/maintenance/mock.go +++ b/lib/automaticupgrades/maintenance/mock.go @@ -29,6 +29,7 @@ import ( type StaticTrigger struct { name string canStart bool + err error } // Name returns the StaticTrigger name. @@ -38,7 +39,7 @@ func (m StaticTrigger) Name() string { // CanStart returns the statically defined maintenance approval result. func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) { - return m.canStart, nil + return m.canStart, m.err } // Default returns the default behavior if the trigger fails. This cannot diff --git a/lib/automaticupgrades/maintenance/trigger.go b/lib/automaticupgrades/maintenance/trigger.go index 53e12b26cdd4a..bd68ac052b2df 100644 --- a/lib/automaticupgrades/maintenance/trigger.go +++ b/lib/automaticupgrades/maintenance/trigger.go @@ -20,6 +20,8 @@ package maintenance import ( "context" + "github.com/gravitational/trace" + "strings" "sigs.k8s.io/controller-runtime/pkg/client" ctrllog "sigs.k8s.io/controller-runtime/pkg/log" @@ -51,7 +53,10 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool { start, err := trigger.CanStart(ctx, object) if err != nil { start = trigger.Default() - log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start) + log.Error( + err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", + start, + ) } else { log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start) } @@ -62,3 +67,48 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool { } return false } + +// FailoverTrigger wraps multiple Triggers and tries them sequentially. +// Any error is considered fatal, except for the trace.NotImplementedErr +// which indicates the trigger is not supported yet and we should +// failover to the next trigger. +type FailoverTrigger []Trigger + +// Name implements Trigger +func (f FailoverTrigger) Name() string { + names := make([]string, len(f)) + for i, t := range f { + names[i] = t.Name() + } + + return strings.Join(names, ", failover ") +} + +// CanStart implements Trigger +// Triggers are evaluated sequentially, the result of the first trigger not returning +// trace.NotImplementedErr is used. +func (f FailoverTrigger) CanStart(ctx context.Context, object client.Object) (bool, error) { + for _, trigger := range f { + canStart, err := trigger.CanStart(ctx, object) + switch { + case err == nil: + return canStart, nil + case trace.IsNotImplemented(err): + continue + default: + return false, trace.Wrap(err) + } + } + return false, trace.NotFound("every trigger returned NotImplemented") +} + +// Default implements Trigger. +// The default is the logical OR of every Trigger.Default. +func (f FailoverTrigger) Default() bool { + for _, trigger := range f { + if trigger.Default() { + return true + } + } + return false +} diff --git a/lib/automaticupgrades/maintenance/trigger_test.go b/lib/automaticupgrades/maintenance/trigger_test.go new file mode 100644 index 0000000000000..f0738003505d3 --- /dev/null +++ b/lib/automaticupgrades/maintenance/trigger_test.go @@ -0,0 +1,150 @@ +package maintenance + +import ( + "context" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "testing" +) + +// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc +func checkTraceError(check func(error) bool) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, i ...interface{}) { + require.True(t, check(err), i...) + } +} + +func TestFailoverTrigger_CanStart(t *testing.T) { + t.Parallel() + + // Test setup + ctx := context.Background() + tests := []struct { + name string + triggers []Trigger + expectResult bool + expectErr require.ErrorAssertionFunc + }{ + { + name: "nil", + triggers: nil, + expectResult: false, + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "empty", + triggers: []Trigger{}, + expectResult: false, + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "first trigger success firing", + triggers: []Trigger{ + StaticTrigger{canStart: true}, + StaticTrigger{canStart: false}, + }, + expectResult: true, + expectErr: require.NoError, + }, + { + name: "first trigger success not firing", + triggers: []Trigger{ + StaticTrigger{canStart: false}, + StaticTrigger{canStart: true}, + }, + expectResult: false, + expectErr: require.NoError, + }, + { + name: "first trigger failure", + triggers: []Trigger{ + StaticTrigger{err: trace.LimitExceeded("got rate-limited")}, + StaticTrigger{canStart: true}, + }, + expectResult: false, + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first trigger skipped, second getter success", + triggers: []Trigger{ + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticTrigger{canStart: true}, + }, + expectResult: true, + expectErr: require.NoError, + }, + { + name: "first trigger skipped, second getter failure", + triggers: []Trigger{ + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticTrigger{err: trace.LimitExceeded("got rate-limited")}, + }, + expectResult: false, + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first trigger skipped, second getter skipped", + triggers: []Trigger{ + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + }, + expectResult: false, + expectErr: checkTraceError(trace.IsNotFound), + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + // Test execution + trigger := FailoverTrigger(tt.triggers) + result, err := trigger.CanStart(ctx, nil) + require.Equal(t, tt.expectResult, result) + tt.expectErr(t, err) + }, + ) + } +} + +func TestFailoverTrigger_Name(t *testing.T) { + tests := []struct { + name string + triggers []Trigger + expectResult string + }{ + { + name: "nil", + triggers: nil, + expectResult: "", + }, + { + name: "empty", + triggers: []Trigger{}, + expectResult: "", + }, + { + name: "one trigger", + triggers: []Trigger{ + StaticTrigger{name: "proxy"}, + }, + expectResult: "proxy", + }, + { + name: "two triggers", + triggers: []Trigger{ + StaticTrigger{name: "proxy"}, + StaticTrigger{name: "version-server"}, + }, + expectResult: "proxy, failover version-server", + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + // Test execution + trigger := FailoverTrigger(tt.triggers) + result := trigger.Name() + require.Equal(t, tt.expectResult, result) + }, + ) + } +} diff --git a/lib/automaticupgrades/version/versionget.go b/lib/automaticupgrades/version/versionget.go index f1e7723a9a320..f02ca71a754a5 100644 --- a/lib/automaticupgrades/version/versionget.go +++ b/lib/automaticupgrades/version/versionget.go @@ -36,13 +36,40 @@ type Getter interface { GetVersion(context.Context) (string, error) } +// FailoverGetter wraps multiple Getters and tries them sequentially. +// Any error is considered fatal, except for the trace.NotImplementedErr +// which indicates the version getter is not supported yet and we should +// failover to the next version getter. +type FailoverGetter []Getter + +// GetVersion implements Getter +// Getters are evaluated sequentially, the result of the first getter not returning +// trace.NotImplementedErr is used. +func (f FailoverGetter) GetVersion(ctx context.Context) (string, error) { + for _, getter := range f { + version, err := getter.GetVersion(ctx) + switch { + case err == nil: + return version, nil + case trace.IsNotImplemented(err): + continue + default: + return "", trace.Wrap(err) + } + } + return "", trace.NotFound("every versionGetter returned NotImplemented") +} + // ValidVersionChange receives the current version and the candidate next version // and evaluates if the version transition is valid. func ValidVersionChange(ctx context.Context, current, next string) bool { log := ctrllog.FromContext(ctx).V(1) // Cannot upgrade to a non-valid version if !semver.IsValid(next) { - log.Error(trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", next) + log.Error( + trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", + next, + ) return false } switch semver.Compare(next, current) { diff --git a/lib/automaticupgrades/version/versionget_test.go b/lib/automaticupgrades/version/versionget_test.go index 80c2ec767b8fb..32214b27e7985 100644 --- a/lib/automaticupgrades/version/versionget_test.go +++ b/lib/automaticupgrades/version/versionget_test.go @@ -20,6 +20,7 @@ package version import ( "context" + "github.com/gravitational/trace" "testing" "github.com/stretchr/testify/require" @@ -66,8 +67,99 @@ func TestValidVersionChange(t *testing.T) { }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next)) - }) + t.Run( + tt.name, func(t *testing.T) { + require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next)) + }, + ) + } +} + +// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc +func checkTraceError(check func(error) bool) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, i ...interface{}) { + require.True(t, check(err), i...) + } +} + +func TestFailoverGetter_GetVersion(t *testing.T) { + t.Parallel() + + // Test setup + ctx := context.Background() + tests := []struct { + name string + getters []Getter + expectResult string + expectErr require.ErrorAssertionFunc + }{ + { + name: "nil", + getters: nil, + expectResult: "", + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "empty", + getters: []Getter{}, + expectResult: "", + expectErr: checkTraceError(trace.IsNotFound), + }, + { + name: "first getter success", + getters: []Getter{ + StaticGetter{version: semverMid}, + StaticGetter{version: semverHigh}, + }, + expectResult: semverMid, + expectErr: require.NoError, + }, + { + name: "first getter failure", + getters: []Getter{ + StaticGetter{err: trace.LimitExceeded("got rate-limited")}, + StaticGetter{version: semverHigh}, + }, + expectResult: "", + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first getter skipped, second getter success", + getters: []Getter{ + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticGetter{version: semverHigh}, + }, + expectResult: semverHigh, + expectErr: require.NoError, + }, + { + name: "first getter skipped, second getter failure", + getters: []Getter{ + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticGetter{err: trace.LimitExceeded("got rate-limited")}, + }, + expectResult: "", + expectErr: checkTraceError(trace.IsLimitExceeded), + }, + { + name: "first getter skipped, second getter skipped", + getters: []Getter{ + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")}, + }, + expectResult: "", + expectErr: checkTraceError(trace.IsNotFound), + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + // Test execution + getter := FailoverGetter(tt.getters) + result, err := getter.GetVersion(ctx) + require.Equal(t, tt.expectResult, result) + tt.expectErr(t, err) + }, + ) } }