Skip to content

Commit

Permalink
Move parallelism input to stage input
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredoconnell committed Dec 6, 2024
1 parent 0e7a848 commit 136d39f
Showing 1 changed file with 64 additions and 44 deletions.
108 changes: 64 additions & 44 deletions internal/step/foreach/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ var executeLifecycleStage = step.LifecycleStage{
RunningName: "executing",
FinishedName: "finished",
InputFields: map[string]struct{}{
"items": {},
"wait_for": {},
"items": {},
"parallelism": {},
"wait_for": {},
},
NextStages: map[string]dgraph.DependencyType{
string(StageIDOutputs): dgraph.AndDependency,
Expand Down Expand Up @@ -152,24 +153,6 @@ func (l *forEachProvider) ProviderSchema() map[string]*schema.PropertySchema {
nil,
[]string{"\"subworkflow.yaml\""},
),
"parallelism": schema.NewPropertySchema(
schema.NewIntSchema(
schema.PointerTo[int64](1),
nil,
nil,
),
schema.NewDisplayValue(
schema.PointerTo("Parallelism"),
schema.PointerTo("How many subworkflows to run in parallel."),
nil,
),
false,
nil,
nil,
nil,
schema.PointerTo("1"),
nil,
),
}
}

Expand Down Expand Up @@ -212,18 +195,35 @@ func (l *forEachProvider) LoadSchema(inputs map[string]any, workflowContext map[
}

return &runnableStep{
workflow: preparedWorkflow,
parallelism: inputs["parallelism"].(int64),
logger: l.logger,
workflow: preparedWorkflow,
logger: l.logger,
}, nil
}

type runnableStep struct {
workflow workflow.ExecutableWorkflow
parallelism int64
logger log.Logger
workflow workflow.ExecutableWorkflow
logger log.Logger
}

var parallelismSchema = schema.NewPropertySchema(
schema.NewIntSchema(
schema.PointerTo[int64](1),
nil,
nil,
),
schema.NewDisplayValue(
schema.PointerTo("Parallelism"),
schema.PointerTo("How many subworkflows to run in parallel."),
nil,
),
false,
nil,
nil,
nil,
schema.PointerTo("1"),
nil,
)

func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.LifecycleStageWithSchema], error) {
workflowOutput := r.workflow.OutputSchema()

Expand Down Expand Up @@ -265,6 +265,7 @@ func (r *runnableStep) Lifecycle(_ map[string]any) (step.Lifecycle[step.Lifecycl
nil,
nil,
),
"parallelism": parallelismSchema,
},
},
{
Expand Down Expand Up @@ -442,33 +443,36 @@ func (r *runnableStep) Start(_ map[string]any, runID string, stageChangeHandler
lock: &sync.Mutex{},
currentStage: StageIDEnabling,
currentState: step.RunningStepStateStarting,
inputData: make(chan []any, 1),
executeInput: make(chan executeInput, 1),
enabledInput: make(chan bool, 1),
workflow: r.workflow,
stageChangeHandler: stageChangeHandler,
parallelism: r.parallelism,
logger: r.logger,
}
go rs.run()
return rs, nil
}

type executeInput struct {
inputData []any
parallelism int64
}

type runningStep struct {
runID string
workflow workflow.ExecutableWorkflow
currentStage StageID
lock *sync.Mutex
currentState step.RunningStepState
executionInputAvailable bool
inputData chan []any
executeInput chan executeInput
enabledInput chan bool
enabledInputAvailable bool
ctx context.Context
closed atomic.Bool
wg sync.WaitGroup
cancel context.CancelFunc
stageChangeHandler step.StageChangeHandler
parallelism int64
logger log.Logger
}

Expand All @@ -483,23 +487,39 @@ func (r *runningStep) ProvideStageInput(stage string, input map[string]any) erro
case string(StageIDExecute):
items := input["items"]
v := reflect.ValueOf(items)
input := make([]any, v.Len())
subworkflowInputs := make([]any, v.Len())
for i := 0; i < v.Len(); i++ {
item := v.Index(i).Interface()
_, err := r.workflow.Input().Unserialize(item)
if err != nil {
return fmt.Errorf("invalid input item %d for subworkflow (%w) for run/step %s", i, err, r.runID)
}
input[i] = item
subworkflowInputs[i] = item
}
if r.executionInputAvailable {
return fmt.Errorf("input for execute workflow provided twice for run/step %s", r.runID)
}
parallelismInput := input["parallelism"]
var parallelism int64
if parallelismInput != nil {
serializedParallelismInput, err := parallelismSchema.Unserialize(parallelismInput)
if err != nil {
return fmt.Errorf("failed to unserialized parallelism input for run/step %s: %w", r.runID, err)
}
parallelism = serializedParallelismInput.(int64)
} else {
parallelism = int64(1)
}

if r.currentState == step.RunningStepStateWaitingForInput && r.currentStage == StageIDExecute {
r.currentState = step.RunningStepStateRunning
}
r.executionInputAvailable = true
r.inputData <- input // Send before unlock to ensure that it never gets closed before sending.
// Send before unlock to ensure that it never gets closed before sending.
r.executeInput <- executeInput{
inputData: subworkflowInputs,
parallelism: parallelism,
}
return nil
case string(StageIDOutputs):
return nil
Expand Down Expand Up @@ -548,7 +568,7 @@ func (r *runningStep) Close() error {
r.cancel()
r.wg.Wait()
r.logger.Debugf("Closing inputData channel in foreach step provider")
close(r.inputData)
close(r.executeInput)
return nil
}

Expand Down Expand Up @@ -759,7 +779,7 @@ func (r *runningStep) markStageFailures(firstStage StageID, err error) {

func (r *runningStep) runOnInput() {
select {
case loopData, ok := <-r.inputData:
case loopData, ok := <-r.executeInput:
if !ok {
r.logger.Debugf("aborted waiting for result in foreach")
return
Expand All @@ -771,9 +791,9 @@ func (r *runningStep) runOnInput() {
}
}

func (r *runningStep) processInput(inputData []any) {
func (r *runningStep) processInput(input executeInput) {
r.logger.Debugf("Executing subworkflow for step %s...", r.runID)
outputs, errors := r.executeSubWorkflows(inputData)
outputs, errors := r.executeSubWorkflows(input)

r.logger.Debugf("Subworkflow %s complete.", r.runID)
r.lock.Lock()
Expand All @@ -788,7 +808,7 @@ func (r *runningStep) processInput(inputData []any) {
unresolvableStage = StageIDOutputs
unresolvableError = fmt.Errorf("foreach subworkflow failed with errors (%v)", errors)
outputID = "error"
dataMap := make(map[int]any, len(inputData))
dataMap := make(map[int]any, len(input.inputData))
for i, entry := range outputs {
if entry != nil {
dataMap[i] = entry
Expand Down Expand Up @@ -832,13 +852,13 @@ func (r *runningStep) processInput(inputData []any) {
}

// returns true if there is an error.
func (r *runningStep) executeSubWorkflows(inputData []any) ([]any, map[int]string) {
itemOutputs := make([]any, len(inputData))
itemErrors := make(map[int]string, len(inputData))
func (r *runningStep) executeSubWorkflows(input executeInput) ([]any, map[int]string) {
itemOutputs := make([]any, len(input.inputData))
itemErrors := make(map[int]string, len(input.inputData))
wg := &sync.WaitGroup{}
wg.Add(len(inputData))
sem := make(chan struct{}, r.parallelism)
for i, input := range inputData {
wg.Add(len(input.inputData))
sem := make(chan struct{}, input.parallelism)
for i, input := range input.inputData {
i := i
input := input
go func() {
Expand Down

0 comments on commit 136d39f

Please sign in to comment.