Skip to content

Commit

Permalink
add failover trigger and versionGetter
Browse files Browse the repository at this point in the history
  • Loading branch information
hugoShaka committed Nov 20, 2024
1 parent 6d5d6d5 commit a07f29f
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 6 deletions.
3 changes: 2 additions & 1 deletion lib/automaticupgrades/maintenance/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
type StaticTrigger struct {
name string
canStart bool
err error
}

// Name returns the StaticTrigger name.
Expand All @@ -38,7 +39,7 @@ func (m StaticTrigger) Name() string {

// CanStart returns the statically defined maintenance approval result.
func (m StaticTrigger) CanStart(_ context.Context, _ client.Object) (bool, error) {
return m.canStart, nil
return m.canStart, m.err
}

// Default returns the default behavior if the trigger fails. This cannot
Expand Down
52 changes: 51 additions & 1 deletion lib/automaticupgrades/maintenance/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package maintenance

import (
"context"
"github.com/gravitational/trace"
"strings"

"sigs.k8s.io/controller-runtime/pkg/client"
ctrllog "sigs.k8s.io/controller-runtime/pkg/log"
Expand Down Expand Up @@ -51,7 +53,10 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
start, err := trigger.CanStart(ctx, object)
if err != nil {
start = trigger.Default()
log.Error(err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue", start)
log.Error(
err, "trigger failed to evaluate, using its default value", "trigger", trigger.Name(), "defaultValue",
start,
)
} else {
log.Info("trigger evaluated", "trigger", trigger.Name(), "result", start)
}
Expand All @@ -62,3 +67,48 @@ func (t Triggers) CanStart(ctx context.Context, object client.Object) bool {
}
return false
}

// FailoverTrigger wraps multiple Triggers and tries them sequentially.
// Any error is considered fatal, except for the trace.NotImplementedErr
// which indicates the trigger is not supported yet and we should
// failover to the next trigger.
type FailoverTrigger []Trigger

// Name implements Trigger
func (f FailoverTrigger) Name() string {
names := make([]string, len(f))
for i, t := range f {
names[i] = t.Name()
}

return strings.Join(names, ", failover ")
}

// CanStart implements Trigger
// Triggers are evaluated sequentially, the result of the first trigger not returning
// trace.NotImplementedErr is used.
func (f FailoverTrigger) CanStart(ctx context.Context, object client.Object) (bool, error) {
for _, trigger := range f {
canStart, err := trigger.CanStart(ctx, object)
switch {
case err == nil:
return canStart, nil
case trace.IsNotImplemented(err):
continue
default:
return false, trace.Wrap(err)
}
}
return false, trace.NotFound("every trigger returned NotImplemented")
}

// Default implements Trigger.
// The default is the logical OR of every Trigger.Default.
func (f FailoverTrigger) Default() bool {
for _, trigger := range f {
if trigger.Default() {
return true
}
}
return false
}
150 changes: 150 additions & 0 deletions lib/automaticupgrades/maintenance/trigger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package maintenance

import (
"context"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"testing"
)

// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc
func checkTraceError(check func(error) bool) require.ErrorAssertionFunc {
return func(t require.TestingT, err error, i ...interface{}) {
require.True(t, check(err), i...)
}
}

func TestFailoverTrigger_CanStart(t *testing.T) {
t.Parallel()

// Test setup
ctx := context.Background()
tests := []struct {
name string
triggers []Trigger
expectResult bool
expectErr require.ErrorAssertionFunc
}{
{
name: "nil",
triggers: nil,
expectResult: false,
expectErr: checkTraceError(trace.IsNotFound),
},
{
name: "empty",
triggers: []Trigger{},
expectResult: false,
expectErr: checkTraceError(trace.IsNotFound),
},
{
name: "first trigger success firing",
triggers: []Trigger{
StaticTrigger{canStart: true},
StaticTrigger{canStart: false},
},
expectResult: true,
expectErr: require.NoError,
},
{
name: "first trigger success not firing",
triggers: []Trigger{
StaticTrigger{canStart: false},
StaticTrigger{canStart: true},
},
expectResult: false,
expectErr: require.NoError,
},
{
name: "first trigger failure",
triggers: []Trigger{
StaticTrigger{err: trace.LimitExceeded("got rate-limited")},
StaticTrigger{canStart: true},
},
expectResult: false,
expectErr: checkTraceError(trace.IsLimitExceeded),
},
{
name: "first trigger skipped, second getter success",
triggers: []Trigger{
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticTrigger{canStart: true},
},
expectResult: true,
expectErr: require.NoError,
},
{
name: "first trigger skipped, second getter failure",
triggers: []Trigger{
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticTrigger{err: trace.LimitExceeded("got rate-limited")},
},
expectResult: false,
expectErr: checkTraceError(trace.IsLimitExceeded),
},
{
name: "first trigger skipped, second getter skipped",
triggers: []Trigger{
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticTrigger{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
},
expectResult: false,
expectErr: checkTraceError(trace.IsNotFound),
},
}
for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
// Test execution
trigger := FailoverTrigger(tt.triggers)
result, err := trigger.CanStart(ctx, nil)
require.Equal(t, tt.expectResult, result)
tt.expectErr(t, err)
},
)
}
}

func TestFailoverTrigger_Name(t *testing.T) {
tests := []struct {
name string
triggers []Trigger
expectResult string
}{
{
name: "nil",
triggers: nil,
expectResult: "",
},
{
name: "empty",
triggers: []Trigger{},
expectResult: "",
},
{
name: "one trigger",
triggers: []Trigger{
StaticTrigger{name: "proxy"},
},
expectResult: "proxy",
},
{
name: "two triggers",
triggers: []Trigger{
StaticTrigger{name: "proxy"},
StaticTrigger{name: "version-server"},
},
expectResult: "proxy, failover version-server",
},
}
for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
// Test execution
trigger := FailoverTrigger(tt.triggers)
result := trigger.Name()
require.Equal(t, tt.expectResult, result)
},
)
}
}
29 changes: 28 additions & 1 deletion lib/automaticupgrades/version/versionget.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,40 @@ type Getter interface {
GetVersion(context.Context) (string, error)
}

// FailoverGetter wraps multiple Getters and tries them sequentially.
// Any error is considered fatal, except for the trace.NotImplementedErr
// which indicates the version getter is not supported yet and we should
// failover to the next version getter.
type FailoverGetter []Getter

// GetVersion implements Getter
// Getters are evaluated sequentially, the result of the first getter not returning
// trace.NotImplementedErr is used.
func (f FailoverGetter) GetVersion(ctx context.Context) (string, error) {
for _, getter := range f {
version, err := getter.GetVersion(ctx)
switch {
case err == nil:
return version, nil
case trace.IsNotImplemented(err):
continue
default:
return "", trace.Wrap(err)
}
}
return "", trace.NotFound("every versionGetter returned NotImplemented")
}

// ValidVersionChange receives the current version and the candidate next version
// and evaluates if the version transition is valid.
func ValidVersionChange(ctx context.Context, current, next string) bool {
log := ctrllog.FromContext(ctx).V(1)
// Cannot upgrade to a non-valid version
if !semver.IsValid(next) {
log.Error(trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion", next)
log.Error(
trace.BadParameter("next version is not following semver"), "version change is invalid", "nextVersion",
next,
)
return false
}
switch semver.Compare(next, current) {
Expand Down
98 changes: 95 additions & 3 deletions lib/automaticupgrades/version/versionget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package version

import (
"context"
"github.com/gravitational/trace"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -66,8 +67,99 @@ func TestValidVersionChange(t *testing.T) {
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next))
})
t.Run(
tt.name, func(t *testing.T) {
require.Equal(t, tt.want, ValidVersionChange(ctx, tt.current, tt.next))
},
)
}
}

// checkTraceError is a test helper that converts trace.IsXXXError into a require.ErrorAssertionFunc
func checkTraceError(check func(error) bool) require.ErrorAssertionFunc {
return func(t require.TestingT, err error, i ...interface{}) {
require.True(t, check(err), i...)
}
}

func TestFailoverGetter_GetVersion(t *testing.T) {
t.Parallel()

// Test setup
ctx := context.Background()
tests := []struct {
name string
getters []Getter
expectResult string
expectErr require.ErrorAssertionFunc
}{
{
name: "nil",
getters: nil,
expectResult: "",
expectErr: checkTraceError(trace.IsNotFound),
},
{
name: "empty",
getters: []Getter{},
expectResult: "",
expectErr: checkTraceError(trace.IsNotFound),
},
{
name: "first getter success",
getters: []Getter{
StaticGetter{version: semverMid},
StaticGetter{version: semverHigh},
},
expectResult: semverMid,
expectErr: require.NoError,
},
{
name: "first getter failure",
getters: []Getter{
StaticGetter{err: trace.LimitExceeded("got rate-limited")},
StaticGetter{version: semverHigh},
},
expectResult: "",
expectErr: checkTraceError(trace.IsLimitExceeded),
},
{
name: "first getter skipped, second getter success",
getters: []Getter{
StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticGetter{version: semverHigh},
},
expectResult: semverHigh,
expectErr: require.NoError,
},
{
name: "first getter skipped, second getter failure",
getters: []Getter{
StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticGetter{err: trace.LimitExceeded("got rate-limited")},
},
expectResult: "",
expectErr: checkTraceError(trace.IsLimitExceeded),
},
{
name: "first getter skipped, second getter skipped",
getters: []Getter{
StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
StaticGetter{err: trace.NotImplemented("proxy does not seem to implement RFD-184")},
},
expectResult: "",
expectErr: checkTraceError(trace.IsNotFound),
},
}
for _, tt := range tests {
t.Run(
tt.name, func(t *testing.T) {
// Test execution
getter := FailoverGetter(tt.getters)
result, err := getter.GetVersion(ctx)
require.Equal(t, tt.expectResult, result)
tt.expectErr(t, err)
},
)
}
}

0 comments on commit a07f29f

Please sign in to comment.