diff --git a/embedded/embedded.go b/embedded/embedded.go index a002e92d..a925700a 100644 --- a/embedded/embedded.go +++ b/embedded/embedded.go @@ -218,7 +218,36 @@ func validateWorkflowDefinitionStates(wfd models.WorkflowDefinition, resources m if state.Seconds <= 0 { return errors.Errorf("invalid seconds parameter in wait %s.%s", wfd.Name, stateName) } - case models.SLStateTypeSucceed, models.SLStateTypeFail, models.SLStateTypeParallel: + case models.SLStateTypeMap: + checkNextState = true + if state.Iterator == nil { + return errors.Errorf("required parameter Iterator for Map state not given in state %s.%s", wfd.Name, stateName) + } + if state.MaxConcurrency < 0 { + return errors.Errorf("invalid MaxConcurrency in Map state %s.%s, must, cannot be negative, got %d", wfd.Name, stateName, state.MaxConcurrency) + } + innerWFD := models.WorkflowDefinition{ + Name: fmt.Sprintf("%s__Iterator", stateName), + StateMachine: state.Iterator, + } + if err := validateWorkflowDefinition(innerWFD, resources); err != nil { + return errors.Errorf("inside the Iterator for Map state %s.%s: %w", + wfd.Name, stateName, err, + ) + } + case models.SLStateTypeParallel: + for i, branch := range state.Branches { + innerWFD := models.WorkflowDefinition{ + Name: fmt.Sprintf("%s__Branch[%d]", stateName, i), + StateMachine: branch, + } + if err := validateWorkflowDefinition(innerWFD, resources); err != nil { + return errors.Errorf("inside the Branch[%d] for Parallel state %s.%s: %w", + i, wfd.Name, stateName, + ) + } + } + case models.SLStateTypeSucceed, models.SLStateTypeFail: // no op default: return errors.Errorf("invalid state type '%s' in %s.%s", state.Type, wfd.Name, stateName) @@ -326,6 +355,10 @@ func (e *Embedded) setStateMachineResources(i *models.StartWorkflowRequest, stat state.Branches[idx] = branch } stateMachine.States[stateName] = state + + case models.SLStateTypeMap: + e.setStateMachineResources(i, state.Iterator) + stateMachine.States[stateName] = state } } } diff --git a/embedded/embedded_test.go b/embedded/embedded_test.go index eba3f9df..3669934c 100644 --- a/embedded/embedded_test.go +++ b/embedded/embedded_test.go @@ -352,6 +352,48 @@ var validateWorkflowDefinitionStatesTests = []validateWorkflowDefinitionStatesTe require.NoError(t, err) }, }, + { + description: "validate map state (invalid inner)", + input: models.WorkflowDefinition{ + StateMachine: &models.SLStateMachine{ + States: map[string]models.SLState{ + "map": models.SLState{ + Type: models.SLStateTypeMap, + Iterator: &models.SLStateMachine{}, + End: true, + }, + }, + }, + }, + assertions: func(t *testing.T, err error) { + require.Error(t, err) + }, + }, + { + description: "validate map state (valid inner)", + input: models.WorkflowDefinition{ + StateMachine: &models.SLStateMachine{ + States: map[string]models.SLState{ + "map": models.SLState{ + Type: models.SLStateTypeMap, + Iterator: &models.SLStateMachine{ + States: map[string]models.SLState{ + "pass": models.SLState{ + End: true, + Type: models.SLStateTypePass, + Result: "result", + }, + }, + }, + End: true, + }, + }, + }, + }, + assertions: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, { description: "invalid state type", input: models.WorkflowDefinition{