diff --git a/controller/workflow.go b/controller/workflow.go index a053e37..e6976cd 100644 --- a/controller/workflow.go +++ b/controller/workflow.go @@ -2,6 +2,7 @@ package controller import ( "context" + "errors" "sync" "golang.org/x/sync/errgroup" @@ -11,37 +12,50 @@ import ( // Workflow runs an optional precondition reconciliation function, then dispatches the reconciliation event to // a list of concurrent reconciliation tasks, and runs an optional postcondition reconciliation function. +// +// If any of the reconciliation functions returns an error, the error is handled by an optional error handler. +// The error passed to the error handler func is conflated with any ocasional error carried over into the call +// to the workflow in the first place. It is up to the error handler to decide how to handle the error and whether +// to supress it or raise it again. type Workflow struct { Precondition ReconcileFunc Tasks []ReconcileFunc Postcondition ReconcileFunc + ErrorHandler ReconcileFunc } -func (d *Workflow) Run(ctx context.Context, resourceEvents []ResourceEvent, topology *machinery.Topology, err error, state *sync.Map) error { +func (w *Workflow) Run(ctx context.Context, resourceEvents []ResourceEvent, topology *machinery.Topology, err error, state *sync.Map) error { // run precondition reconcile function - if d.Precondition != nil { - if err := d.Precondition(ctx, resourceEvents, topology, err, state); err != nil { - return err + if w.Precondition != nil { + if preconditionErr := w.Precondition(ctx, resourceEvents, topology, err, state); preconditionErr != nil { + return w.handle(ctx, resourceEvents, topology, err, state, preconditionErr) } } // dispatch the event to concurrent tasks g, groupCtx := errgroup.WithContext(ctx) - for _, f := range d.Tasks { + for _, f := range w.Tasks { g.Go(func() error { return f(groupCtx, resourceEvents, topology, err, state) }) } - if err := g.Wait(); err != nil { - return err + if taskErr := g.Wait(); taskErr != nil { + return w.handle(ctx, resourceEvents, topology, err, state, taskErr) } // run precondition reconcile function - if d.Postcondition != nil { - if err := d.Postcondition(ctx, resourceEvents, topology, err, state); err != nil { - return err + if w.Postcondition != nil { + if postconditionErr := w.Postcondition(ctx, resourceEvents, topology, err, state); postconditionErr != nil { + return w.handle(ctx, resourceEvents, topology, err, state, postconditionErr) } } return nil } + +func (w *Workflow) handle(ctx context.Context, resourceEvents []ResourceEvent, topology *machinery.Topology, carryoverErr error, state *sync.Map, workflowErr error) error { + if w.ErrorHandler != nil { + return w.ErrorHandler(ctx, resourceEvents, topology, errors.Join(carryoverErr, workflowErr), state) + } + return workflowErr +} diff --git a/controller/workflow_test.go b/controller/workflow_test.go index da43f77..7642680 100644 --- a/controller/workflow_test.go +++ b/controller/workflow_test.go @@ -13,7 +13,6 @@ import ( ) func TestWorkflow(t *testing.T) { - reconcileFuncFor := func(flag *bool, err error) ReconcileFunc { return func(context.Context, []ResourceEvent, *machinery.Topology, error, *sync.Map) error { *flag = true @@ -21,7 +20,7 @@ func TestWorkflow(t *testing.T) { } } - var preconditionCalled, task1Called, task2Called, postconditionCalled bool + var preconditionCalled, task1Called, task2Called, postconditionCalled, errorHandled bool precondition := reconcileFuncFor(&preconditionCalled, nil) preconditionWithError := reconcileFuncFor(&preconditionCalled, fmt.Errorf("precondition error")) @@ -32,6 +31,16 @@ func TestWorkflow(t *testing.T) { postcondition := reconcileFuncFor(&postconditionCalled, nil) postconditionWithError := reconcileFuncFor(&postconditionCalled, fmt.Errorf("postcondition error")) + handleErrorAndSupress := func(context.Context, []ResourceEvent, *machinery.Topology, error, *sync.Map) error { + errorHandled = true + return nil + } + + handleErrorAndRaise := func(_ context.Context, _ []ResourceEvent, _ *machinery.Topology, err error, _ *sync.Map) error { + errorHandled = true + return err + } + testCases := []struct { name string workflow *Workflow @@ -40,6 +49,7 @@ func TestWorkflow(t *testing.T) { expectedTask2Called bool expectedPostconditionCalled bool possibleErrs []error + expectedErrorHandled bool }{ { name: "empty workflow", @@ -134,6 +144,25 @@ func TestWorkflow(t *testing.T) { expectedPostconditionCalled: true, possibleErrs: []error{fmt.Errorf("postcondition error")}, }, + { + name: "handle error and suppress", + workflow: &Workflow{ + Precondition: preconditionWithError, + ErrorHandler: handleErrorAndSupress, + }, + expectedPreconditionCalled: true, + expectedErrorHandled: true, + }, + { + name: "handle error and raise", + workflow: &Workflow{ + Precondition: preconditionWithError, + ErrorHandler: handleErrorAndRaise, + }, + expectedPreconditionCalled: true, + expectedErrorHandled: true, + possibleErrs: []error{fmt.Errorf("precondition error")}, + }, } for _, tc := range testCases { @@ -143,6 +172,7 @@ func TestWorkflow(t *testing.T) { task1Called = false task2Called = false postconditionCalled = false + errorHandled = false err := tc.workflow.Run(context.Background(), nil, nil, nil, nil) possibleErrs := lo.Map(tc.possibleErrs, func(err error, _ int) string { return err.Error() }) @@ -168,6 +198,9 @@ func TestWorkflow(t *testing.T) { if len(possibleErrs) > 0 && err != nil && !lo.ContainsBy(possibleErrs, func(possibleErr string) bool { return possibleErr == err.Error() }) { t.Errorf("expected error of the following errors (%v), got %v", strings.Join(possibleErrs, " / "), err) } + if tc.expectedErrorHandled != errorHandled { + t.Errorf("expected error handler called: %t, got %t", tc.expectedErrorHandled, errorHandled) + } }) }