Skip to content

Commit

Permalink
Fix several issues in failure cases
Browse files Browse the repository at this point in the history
Fixed improper locking, paths that could lead to double-resolution, and more.
  • Loading branch information
jaredoconnell committed Jul 20, 2024
1 parent 44b07de commit d346d80
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 61 deletions.
34 changes: 25 additions & 9 deletions internal/step/foreach/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package foreach
import (
"context"
"fmt"
"go.arcalot.io/dgraph"
"reflect"
"sync"

Expand Down Expand Up @@ -46,9 +47,9 @@ var executeLifecycleStage = step.LifecycleStage{
"items": {},
"wait_for": {},
},
NextStages: []string{
string(StageIDOutputs),
string(StageIDFailed),
NextStages: map[string]dgraph.DependencyType{
string(StageIDOutputs): dgraph.AndDependency,
string(StageIDFailed): dgraph.CompletionAndDependency,
},
Fatal: false,
}
Expand All @@ -58,7 +59,7 @@ var outputLifecycleStage = step.LifecycleStage{
RunningName: "output",
FinishedName: "output",
InputFields: map[string]struct{}{},
NextStages: []string{},
NextStages: map[string]dgraph.DependencyType{},
Fatal: false,
}
var errorLifecycleStage = step.LifecycleStage{
Expand All @@ -67,7 +68,7 @@ var errorLifecycleStage = step.LifecycleStage{
RunningName: "processing error",
FinishedName: "error",
InputFields: map[string]struct{}{},
NextStages: []string{},
NextStages: map[string]dgraph.DependencyType{},
Fatal: true,
}

Expand Down Expand Up @@ -359,6 +360,7 @@ type runningStep struct {
inputAvailable bool
inputData chan []any
ctx context.Context
closed bool
wg sync.WaitGroup
cancel context.CancelFunc
stageChangeHandler step.StageChangeHandler
Expand All @@ -368,6 +370,11 @@ type runningStep struct {

func (r *runningStep) ProvideStageInput(stage string, input map[string]any) error {
r.lock.Lock()
if r.closed {
r.logger.Debugf("exiting foreach ProvideStageInput due to step being closed")
r.lock.Unlock()
return nil
}
switch stage {
case string(StageIDExecute):
items := input["items"]
Expand All @@ -390,8 +397,8 @@ func (r *runningStep) ProvideStageInput(stage string, input map[string]any) erro
r.currentState = step.RunningStepStateRunning
}
r.inputAvailable = true
r.inputData <- input // Send before unlock to ensure that it never gets closed before sending.
r.lock.Unlock()
r.inputData <- input
return nil
case string(StageIDOutputs):
r.lock.Unlock()
Expand All @@ -412,14 +419,19 @@ func (r *runningStep) CurrentStage() string {
}

func (r *runningStep) State() step.RunningStepState {
r.lock.Lock()
r.lock.Lock() // TODO: Determine why this gets stuck.
defer r.lock.Unlock()
return r.currentState
}

// Close shuts the running foreach step down and releases its input channel.
//
// It marks the step as closed under the lock (ProvideStageInput checks that
// flag and returns early), cancels the step context, waits on the step's wait
// group, and finally closes inputData.
//
// Close is idempotent: only the first call performs the shutdown. Closing an
// already-closed channel panics, so later calls return nil without touching
// inputData. Note that a later call does not wait for the first call's
// shutdown to finish.
func (r *runningStep) Close() error {
	r.lock.Lock()
	if r.closed {
		// A previous Close owns the shutdown; a second close(r.inputData)
		// would panic.
		r.lock.Unlock()
		return nil
	}
	r.closed = true
	r.lock.Unlock()

	// Cancel first so anything selecting on the context can exit, then wait
	// for the wait group before closing the channel.
	r.cancel()
	r.wg.Wait()
	r.logger.Debugf("Closing inputData channel in foreach step provider")
	close(r.inputData)
	return nil
}

Expand All @@ -431,7 +443,7 @@ func (r *runningStep) ForceClose() error {
func (r *runningStep) run() {

Check failure on line 443 in internal/step/foreach/provider.go

View workflow job for this annotation

GitHub Actions / lint and test / golangci-lint

cognitive complexity 31 of func `(*runningStep).run` is high (> 30) (gocognit)
r.wg.Add(1)
defer func() {
close(r.inputData)
r.logger.Debugf("foreach run function done")
r.wg.Done()
}()
waitingForInput := false
Expand Down Expand Up @@ -471,7 +483,10 @@ func (r *runningStep) run() {
input := input
go func() {
defer func() {
<-sem
select {
case <-sem:
case <-r.ctx.Done(): // Must not deadlock if closed early.
}
wg.Done()
}()
r.logger.Debugf("Queuing item %d...", i)
Expand Down Expand Up @@ -540,6 +555,7 @@ func (r *runningStep) run() {
r.lock.Unlock()
r.stageChangeHandler.OnStepComplete(r, previousStage, &outputID, &outputData, &r.wg)
case <-r.ctx.Done():
r.logger.Debugf("context done")
return
}

Expand Down
14 changes: 7 additions & 7 deletions internal/step/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ func (l Lifecycle[StageType]) DAG() (dgraph.DirectedGraph[StageType], error) {
}
}
for _, stage := range l.Stages {
node := lang.Must2(dag.GetNodeByID(stage.Identifier()))
for _, nextStage := range stage.NextStageIDs() {
if err := node.Connect(nextStage); err != nil {
return nil, fmt.Errorf("failed to connect lifecycle stage %s to %s (%w)", node.ID(), nextStage, err)
for nextStage, dependencyType := range stage.NextStageIDs() {
nextStageNode := lang.Must2(dag.GetNodeByID(nextStage))
if err := nextStageNode.ConnectDependency(stage.Identifier(), dependencyType); err != nil {
return nil, fmt.Errorf("failed to connect lifecycle stage %s to %s (%w)", stage.Identifier(), nextStage, err)
}
}
}
Expand All @@ -52,7 +52,7 @@ type lifecycleStage interface {
// Identifier returns the ID of the stage.
Identifier() string
// NextStageIDs returns the next stage identifiers.
NextStageIDs() []string
NextStageIDs() map[string]dgraph.DependencyType
}

// LifecycleStage is the description of a single stage within a step lifecycle.
Expand All @@ -72,7 +72,7 @@ type LifecycleStage struct {
// will pause if there is no input available.
// It will automatically create a DAG connection from the current stage to each described next stage,
// using the given dependency type, to ensure that the stages run in order.
NextStages []string
NextStages map[string]dgraph.DependencyType
// Fatal indicates that this stage should be treated as fatal unless handled by the workflow.
Fatal bool
}
Expand All @@ -83,7 +83,7 @@ func (l LifecycleStage) Identifier() string {
}

// NextStageIDs is a helper function that returns the next possible stages.
func (l LifecycleStage) NextStageIDs() []string {
func (l LifecycleStage) NextStageIDs() map[string]dgraph.DependencyType {
return l.NextStages
}

Expand Down
90 changes: 64 additions & 26 deletions internal/step/plugin/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package plugin
import (
"context"
"fmt"
"go.arcalot.io/dgraph"
"go.flow.arcalot.io/pluginsdk/plugin"
"reflect"
"strings"
Expand Down Expand Up @@ -168,8 +169,9 @@ var deployingLifecycleStage = step.LifecycleStage{
RunningName: "deploying",
FinishedName: "deployed",
InputFields: map[string]struct{}{string(StageIDDeploy): {}},
NextStages: []string{
string(StageIDStarting), string(StageIDDeployFailed),
NextStages: map[string]dgraph.DependencyType{
string(StageIDStarting): dgraph.AndDependency,
string(StageIDDeployFailed): dgraph.CompletionAndDependency,
},
Fatal: false,
}
Expand All @@ -189,8 +191,10 @@ var enablingLifecycleStage = step.LifecycleStage{
InputFields: map[string]struct{}{
"enabled": {},
},
NextStages: []string{
string(StageIDStarting), string(StageIDDisabled), string(StageIDCrashed),
NextStages: map[string]dgraph.DependencyType{
string(StageIDStarting): dgraph.AndDependency,
string(StageIDDisabled): dgraph.AndDependency,
string(StageIDCrashed): dgraph.CompletionAndDependency,
},
}

Expand All @@ -203,8 +207,9 @@ var startingLifecycleStage = step.LifecycleStage{
"input": {},
"wait_for": {},
},
NextStages: []string{
string(StageIDRunning), string(StageIDCrashed),
NextStages: map[string]dgraph.DependencyType{
string(StageIDRunning): dgraph.AndDependency,
string(StageIDCrashed): dgraph.CompletionAndDependency,
},
}

Expand All @@ -214,8 +219,9 @@ var runningLifecycleStage = step.LifecycleStage{
RunningName: "running",
FinishedName: "completed",
InputFields: map[string]struct{}{},
NextStages: []string{
string(StageIDOutput), string(StageIDCrashed),
NextStages: map[string]dgraph.DependencyType{
string(StageIDOutput): dgraph.AndDependency,
string(StageIDCrashed): dgraph.CompletionAndDependency,
},
}

Expand All @@ -227,8 +233,10 @@ var cancelledLifecycleStage = step.LifecycleStage{
InputFields: map[string]struct{}{
"stop_if": {},
},
NextStages: []string{
string(StageIDOutput), string(StageIDCrashed), string(StageIDDeployFailed),
NextStages: map[string]dgraph.DependencyType{
string(StageIDOutput): dgraph.AndDependency,
string(StageIDCrashed): dgraph.CompletionAndDependency,
string(StageIDDeployFailed): dgraph.CompletionAndDependency,
},
}

Expand Down Expand Up @@ -902,6 +910,7 @@ func (r *runningStep) closeComponents(closeATP bool) error {
r.cancel()
r.lock.Lock()
if r.closed {
r.lock.Unlock()
return nil // Already closed
}
var atpErr error
Expand Down Expand Up @@ -941,7 +950,7 @@ func (r *runningStep) run() {
r.logger.Warningf("failed to remove deployed container for step %s/%s", r.runID, r.pluginStepID)
}
r.lock.Unlock()
r.transitionToCancelled()
r.cancelledEarly()
return
default:
r.container = container
Expand All @@ -960,9 +969,6 @@ func (r *runningStep) run() {
return
}

// It's enabled, so the disabled stage will not occur.
r.stageChangeHandler.OnStepStageFailure(r, string(StageIDDisabled), &r.wg, err)

if err := r.startStage(container); err != nil {
r.startFailed(err)
return
Expand Down Expand Up @@ -1050,12 +1056,20 @@ func (r *runningStep) enableStage() (bool, error) {
&r.wg,
)

var enabled bool
select {
case enabled := <-r.enabledInput:
return enabled, nil
case enabled = <-r.enabledInput:
case <-r.ctx.Done():
return false, fmt.Errorf("step closed while determining enablement status")
}

if enabled {
r.lock.Lock()
// It's enabled, so the disabled stage will not occur.
r.stageChangeHandler.OnStepStageFailure(r, string(StageIDDisabled), &r.wg, fmt.Errorf("step enabled; cannot be disabled anymore"))
r.lock.Unlock()
}
return enabled, nil
}

func (r *runningStep) startStage(container deployer.Plugin) error {
Expand Down Expand Up @@ -1150,9 +1164,6 @@ func (r *runningStep) runStage() error {
var result atp.ExecutionResult
select {
case result = <-r.executionChannel:
if result.Error != nil {
return result.Error
}
case <-r.ctx.Done():
// In this case, it is being instructed to stop. A signal should have been sent.
// Shutdown (with sigterm) the container, then wait for the output (valid or error).
Expand All @@ -1170,6 +1181,10 @@ func (r *runningStep) runStage() error {

}

if result.Error != nil {
return result.Error
}

// Execution complete, move to state running stage outputs, then to state finished stage.
r.transitionStage(StageIDOutput, step.RunningStepStateRunning)
r.completeStep(r.currentStage, step.RunningStepStateFinished, &result.OutputID, &result.OutputData)
Expand All @@ -1178,6 +1193,8 @@ func (r *runningStep) runStage() error {
}

func (r *runningStep) markStageFailures(firstStage StageID, err error) {
r.lock.Lock()
defer r.lock.Unlock()
switch firstStage {
case StageIDEnabling:
r.stageChangeHandler.OnStepStageFailure(r, string(StageIDEnabling), &r.wg, err)
Expand Down Expand Up @@ -1214,15 +1231,16 @@ func (r *runningStep) deployFailed(err error) {
r.markStageFailures(StageIDEnabling, err)
}

func (r *runningStep) transitionToCancelled() {
func (r *runningStep) cancelledEarly() {
r.logger.Infof("Step %s/%s cancelled", r.runID, r.pluginStepID)
// Follow the convention of transitioning to running then finished.
r.transitionStage(StageIDCancelled, step.RunningStepStateRunning)
// Cancelled currently has no output.
r.transitionStage(StageIDCancelled, step.RunningStepStateFinished)

// This is called after deployment. So everything after deployment cannot occur.
err := fmt.Errorf("step %s/%s cancelled", r.runID, r.pluginStepID)
// Cancelled currently has no output.
// Set it as unresolvable since it's cancelled, and to prevent conflicts with its inputs.
r.transitionFromFailedStage(StageIDCancelled, step.RunningStepStateFinished, err)

// Note: This function is only called if it's cancelled during the deployment phase.
// If that changes, the stage IDs marked as failed need to be changed.
r.markStageFailures(StageIDEnabling, err)
Expand All @@ -1248,12 +1266,11 @@ func (r *runningStep) transitionToDisabled() {

err := fmt.Errorf("step %s/%s disabled", r.runID, r.pluginStepID)
r.markStageFailures(StageIDStarting, err)

}

func (r *runningStep) startFailed(err error) {
r.logger.Debugf("Start failed stage for step %s/%s", r.runID, r.pluginStepID)
r.transitionStage(StageIDCrashed, step.RunningStepStateRunning)
r.transitionFromFailedStage(StageIDCrashed, step.RunningStepStateRunning, err)
r.logger.Warningf("Plugin step %s/%s start failed. %v", r.runID, r.pluginStepID, err)

// Now it's done.
Expand Down Expand Up @@ -1288,8 +1305,29 @@ func (r *runningStep) transitionStage(newStage StageID, state step.RunningStepSt
r.transitionStageWithOutput(newStage, state, nil, nil)
}

// transitionFromFailedStage moves the step to newStage with the given state
// and reports the stage being left as failed, passing err to the stage change
// handler. The whole transition, including the handler call, runs while
// holding r.lock.
func (r *runningStep) transitionFromFailedStage(newStage StageID, state step.RunningStepState, err error) {
	r.lock.Lock()
	defer r.lock.Unlock()

	// Remember which stage failed before overwriting it.
	failedStage := string(r.currentStage)

	// The state must be updated together with the stage. States cannot be
	// skipped: a stage is first running, then finished.
	r.currentStage, r.state = newStage, state

	r.stageChangeHandler.OnStepStageFailure(r, failedStage, &r.wg, err)
}

// TransitionStage transitions the stage to the specified stage, and the state to the specified state.
func (r *runningStep) transitionStageWithOutput(newStage StageID, state step.RunningStepState, outputID *string, previousStageOutput *any) {
func (r *runningStep) transitionStageWithOutput(
newStage StageID,
state step.RunningStepState,
outputID *string,
previousStageOutput *any,
) {
// A current lack of observability into the atp client prevents
// non-fragile testing of this function.
r.lock.Lock()
Expand Down
2 changes: 2 additions & 0 deletions internal/step/plugin/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,5 @@ func TestProvider_StartFail(t *testing.T) {
assert.NoError(t, running.Close())
})
}

// TODO: Add more tests here for the current functions and code paths.
8 changes: 6 additions & 2 deletions workflow/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,12 @@ func (e *executor) connectStepDependencies(
if err != nil {
return fmt.Errorf("bug: node for current stage not found (%w)", err)
}
for _, nextStage := range stage.NextStages {
if err := currentStageNode.Connect(GetStageNodeID(stepID, nextStage)); err != nil {
for nextStage, dependencyType := range stage.NextStages {
nextStageNode, err := dag.GetNodeByID(GetStageNodeID(stepID, nextStage))
if err != nil {
return fmt.Errorf("bug: node for next stage not found (%w)", err)
}
if err := nextStageNode.ConnectDependency(currentStageNode.ID(), dependencyType); err != nil {
return fmt.Errorf("bug: cannot connect nodes (%w)", err)
}
}
Expand Down
Loading

0 comments on commit d346d80

Please sign in to comment.